diff --git a/fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py b/fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py index 701795579..84822dba9 100644 --- a/fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py +++ b/fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py @@ -598,7 +598,7 @@ def quantize_fixed_nk(self, x, w): return ( x, w, - torch.tensor(m_values).to(dtype=torch.int32, device=x[0].device), + torch.tensor(m_values).to(dtype=torch.int64, device=x[0].device), output, ) @@ -622,7 +622,7 @@ def quantize(self, x, w): m_values = None return x, w, m_values, output - def compute(self, x, w, m_values, output): + def compute(self, x, w, m_values, _): return torch.ops.fbgemm.bf16bf16bf16_grouped( x, w, @@ -642,7 +642,7 @@ def name(self) -> str: @property def hip(self) -> bool: - return False + return True @property def cuda(self) -> bool: diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/bf16_grouped_gemm.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/bf16_grouped_gemm.hip new file mode 100644 index 000000000..8e5a18920 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/bf16_grouped_gemm.hip @@ -0,0 +1,376 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp" +#include "kernels/bf16_grouped_kernel_manifest.h" + +namespace fbgemm_gpu { + +// Define useful types that are needed for various kernels. +using KernelArguments = + ck::tensor_operation::device::GroupedGemmTileLoopKernelArguments<0>; +using ADataType = ck::bhalf_t; +using BDataType = ck::bhalf_t; +using CDataType = ck::bhalf_t; + +GroupedKernel grouped_heuristic_dispatch(int M, int N, int K) { + // We use shape heuristics to find the best kernel. + // To do this, we divide by the size of M and find the best + // option within that grouping. 
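+  // The branches below bucket by M (<= 16, 32, 64, 128, 256, 512), with N
+  // and K used to refine the choice inside each bucket; anything larger
+  // falls through to the default kernel at the bottom.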
+  if (M <= 16) {
+    if (N < 8192 && K <= 8192) {
+      return bf16_grouped_64x16x16x256_16x16_1x1_16x4x1_16x4x1_1x4x1x16_4x4x1_1x1_intrawave_v1;
+    }
+    if (K <= 8192) {
+      return bf16_grouped_128x16x64x128_16x16_1x2_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2;
+    }
+    return bf16_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2;
+  }
+  if (M <= 32) {
+    if (N < 8192 && K <= 8192) {
+      return bf16_grouped_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2;
+    }
+    if (K <= 8192) {
+      return bf16_grouped_128x32x128x128_32x32_1x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2;
+    }
+    return bf16_grouped_128x32x128x128_32x32_1x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_intrawave_v2;
+  }
+  if (M <= 64) {
+    return bf16_grouped_256x64x64x128_32x32_1x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3;
+  }
+  if (M <= 128) {
+    if (N < 8192 && K <= 8192) {
+      return bf16_grouped_256x128x64x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3;
+    }
+    return bf16_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3;
+  }
+  if (M <= 256) {
+    return bf16_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3;
+  }
+  if (M <= 512) {
+    if (K <= 8192) {
+      return bf16_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_interwave_v1;
+    }
+    return bf16_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3;
+  }
+  // Default kernel for all other shapes.
+  return bf16_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_interwave_v1;
+}
+
+__global__ void set_kernel_args_kernel(
+    KernelArguments* kernel_args,
+    ADataType* A,
+    BDataType* B,
+    CDataType* output,
+    int M,
+    int N,
+    int K) {
+  int idx = blockIdx.x * blockDim.x + threadIdx.x;
+  // Each kernel annoyingly can only set the kernel args for one group.
+  // This could only be avoided with complicated memory management.
+  if (idx == 0) {
+    // Write kernel arguments directly to memory.
+    KernelArguments kernel_group_args = {
+        A, B, {}, output, M, N, K, K, K, {}, N};
+    kernel_args[0] = kernel_group_args;
+  }
+}
+
+void set_static_kernel_args(
+    at::Tensor kernel_args,
+    at::TensorList A,
+    at::TensorList B,
+    std::vector<at::Tensor> output) {
+  // Get the current HIP stream.
+  auto stream = at::cuda::getCurrentHIPStream().stream();
+  int group_count = A.size();
+  // When the group count is large, we can initialize more efficiently by
+  // doing host setup and a memcpy. This is only viable if hip graphs
+  // aren't being used.
+  if (group_count >= 16 && stream == 0) {
+    std::vector<KernelArguments> ggemm_kargs;
+    ggemm_kargs.reserve(group_count);
+
+    // Iterate over inputs and get group information.
+    for (int i = 0; i < group_count; i++) {
+      int M = A[i].size(0);
+      int K = A[i].size(1);
+      int N = B[i].size(0);
+      KernelArguments group_args = {
+          reinterpret_cast<ADataType*>(A[i].data_ptr()),
+          reinterpret_cast<BDataType*>(B[i].data_ptr()),
+          {},
+          reinterpret_cast<CDataType*>(output[i].data_ptr()),
+          M,
+          N,
+          K,
+          K,
+          K,
+          {},
+          N};
+      ggemm_kargs.push_back(group_args);
+    }
+    // Copy data onto device.
+    hipMemcpy(
+        kernel_args.data_ptr(), // Destination
+        ggemm_kargs.data(), // Source
+        sizeof(KernelArguments) * group_count, // Number of bytes
+        hipMemcpyHostToDevice); // Copy Type
+  } else {
+    // We use the smallest reasonable block size since we effectively need only
+    // 1 thread.
+    int blockSize = 32;
+    int numBlocks = 1;
+    // Launch a kernel for each group to set kernel memory on device.
+    // Using multiple kernels this way allows us to support arbitrary M,N,K.
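+    // Each launch writes a single KernelArguments slot directly in device
+    // memory, so unlike the hipMemcpy path above this also works while hip
+    // graphs are in use.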
+    // For some reason, this approach is faster than using hipMemcpy.
+    for (int i = 0; i < group_count; i++) {
+      int M = A[i].size(0);
+      int K = A[i].size(1);
+      int N = B[i].size(0);
+      // Launch kernel to set kernel arguments.
+      set_kernel_args_kernel<<<numBlocks, blockSize, 0, stream>>>(
+          reinterpret_cast<KernelArguments*>(
+              reinterpret_cast<char*>(kernel_args.data_ptr()) +
+              (i * sizeof(KernelArguments))),
+          reinterpret_cast<ADataType*>(A[i].data_ptr()),
+          reinterpret_cast<BDataType*>(B[i].data_ptr()),
+          reinterpret_cast<CDataType*>(output[i].data_ptr()),
+          M,
+          N,
+          K);
+    }
+  }
+}
+
+__global__ void set_kernel_args_fixed_nk_kernel(
+    KernelArguments* kernel_args,
+    ADataType* A,
+    BDataType* B,
+    CDataType* output,
+    int64_t* prepad_M,
+    int M,
+    int N,
+    int K,
+    int group_count) {
+  int group_idx = blockIdx.x * blockDim.x + threadIdx.x;
+  // Each thread is responsible for setting up the arguments for one group.
+  if (group_idx < group_count) {
+    // Compute offsets for this group.
+    int group_M = prepad_M[group_idx];
+    KernelArguments kernel_group_args = {
+        A + (group_idx * M * K),
+        B + (group_idx * N * K),
+        {},
+        output + (group_idx * M * N),
+        group_M,
+        N,
+        K,
+        K,
+        K,
+        {},
+        N};
+    // Write kernel args to memory.
+    kernel_args[group_idx] = kernel_group_args;
+  }
+}
+
+void set_dynamic_kernel_args(
+    at::Tensor kernel_args,
+    at::TensorList A,
+    at::TensorList B,
+    std::vector<at::Tensor> output,
+    at::Tensor zero_start_index_M) {
+  // Get the current HIP stream.
+  auto stream = at::cuda::getCurrentHIPStream().stream();
+  int group_count = A.size();
+  // Confirm M is on the proper device.
+  TORCH_CHECK(
+      A[0].device() == zero_start_index_M.device(),
+      "zero_start_index_M and inputs must be on the same device.");
+  TORCH_CHECK(
+      zero_start_index_M.size(0) == group_count,
+      "zero_start_index_M must have an entry for each group.");
+  TORCH_CHECK(
+      zero_start_index_M.dtype() == at::kLong,
+      "zero_start_index_M must be int64.");
+
+  // We assume that M, N, and K are fixed across groups.
+  // The actual per-group M values are stored in zero_start_index_M.
+  int M = A[0].size(0);
+  int K = A[0].size(1);
+  int N = B[0].size(0);
+
+  // Make sure that inputs are allocated in sequential memory as required by
+  // this mode.
+  for (int i = 1; i < group_count; i++) {
+    // Check that each input is allocated directly after the preceding one.
+    TORCH_CHECK(
+        A[i].data_ptr() ==
+            (reinterpret_cast<ADataType*>(A[i - 1].data_ptr()) + (M * K)),
+        "Inputs must be sequential in memory to support dynamic M, but A is not.");
+    TORCH_CHECK(
+        B[i].data_ptr() ==
+            (reinterpret_cast<BDataType*>(B[i - 1].data_ptr()) + (N * K)),
+        "Inputs must be sequential in memory to support dynamic M, but B is not.");
+    TORCH_CHECK(
+        output[i].data_ptr() ==
+            (reinterpret_cast<CDataType*>(output[i - 1].data_ptr()) + (M * N)),
+        "Inputs must be sequential in memory to support dynamic M, but output is not.");
+  }
+
+  // Launch a kernel that sets kernel argument memory.
+  int const blockSize = std::min(1024, group_count);
+  int const numBlocks = (group_count + blockSize - 1) / blockSize;
+  set_kernel_args_fixed_nk_kernel<<<numBlocks, blockSize, 0, stream>>>(
+      reinterpret_cast<KernelArguments*>(kernel_args.data_ptr()),
+      reinterpret_cast<ADataType*>(A[0].data_ptr()),
+      reinterpret_cast<BDataType*>(B[0].data_ptr()),
+      reinterpret_cast<CDataType*>(output[0].data_ptr()),
+      reinterpret_cast<int64_t*>(zero_start_index_M.data_ptr()),
+      M,
+      N,
+      K,
+      group_count);
+}
+
+at::Tensor get_grouped_kernel_args(
+    at::TensorList A,
+    at::TensorList B,
+    std::optional<at::Tensor> zero_start_index_M,
+    std::vector<at::Tensor> output) {
+  int group_count = A.size();
+  // Get space on device for the kernel argument tensor.
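+  // This is a raw byte buffer holding one KernelArguments struct per group;
+  // the helpers above fill in each group's slot on the device.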
+  at::Tensor kernel_args = at::empty(
+      {static_cast<int64_t>(group_count * sizeof(KernelArguments))},
+      A[0].options().dtype(at::kByte));
+
+  // There are two different modes for this kernel.
+  // When zero_start_index_M is provided, we assume that data is sequential and
+  // that N and K are constant across groups. This allows a more efficient
+  // kernel launch and is best suited to MoE use cases where M is truly dynamic.
+  // When zero_start_index_M is not provided, we assume M, N, and K can all vary
+  // and set them for each group. Note that this mode does not work well with
+  // hip graphs or runtime dynamism, so we recommend using zero_start_index_M
+  // when possible.
+
+  if (zero_start_index_M.has_value()) {
+    set_dynamic_kernel_args(
+        kernel_args,
+        A,
+        B,
+        output,
+        zero_start_index_M.value());
+  } else {
+    set_static_kernel_args(kernel_args, A, B, output);
+  }
+  return kernel_args;
+}
+
+std::vector<at::Tensor> bf16bf16bf16_grouped(
+    at::TensorList A,
+    at::TensorList B,
+    std::optional<at::Tensor> zero_start_index_M = std::nullopt,
+    std::optional<std::vector<at::Tensor>> output = std::nullopt) {
+  // Check that input datatypes are valid.
+  // First confirm that there are the same number of groups in all inputs.
+  TORCH_CHECK(
+      A.size() == B.size(),
+      "A and B must have the same number of groups.");
+  int group_count = A.size();
+  // Iterate over inputs and check they are valid.
+  for (at::Tensor a : A) {
+    TORCH_CHECK(a.is_cuda() && a.is_contiguous());
+    TORCH_CHECK(a.dim() == 2, "Inputs must be 2D.");
+    TORCH_CHECK(
+        a.dtype() == at::kBFloat16,
+        "Inputs must be type bfloat16.");
+  }
+  for (at::Tensor b : B) {
+    TORCH_CHECK(b.is_cuda() && b.is_contiguous());
+    TORCH_CHECK(b.dim() == 2, "Inputs must be 2D.");
+    TORCH_CHECK(
+        b.dtype() == at::kBFloat16,
+        "Inputs must be type bfloat16.");
+  }
+
+  std::vector<at::Tensor> Y;
+  if (output.has_value()) {
+    Y = output.value();
+    TORCH_CHECK(
+        Y.size() == group_count,
+        "Output and input must have the same number of groups.");
+    // Check that output shapes are correct.
+    for (int i = 0; i < group_count; i++) {
+      int M = A[i].size(0);
+      int N = B[i].size(0);
+      int out_M = Y[i].size(0);
+      int out_N = Y[i].size(1);
+      TORCH_CHECK(
+          M == out_M && N == out_N,
+          "Output tensors do not have the expected shape.");
+      TORCH_CHECK(
+          Y[i].dtype() == at::kBFloat16, "Output dtype must be bfloat16.");
+    }
+  } else {
+    // Two modes for allocating output. When zero_start_index_M is provided,
+    // we need the output tensor to be contiguous and can assume M, N, and K
+    // are the same across groups. Otherwise, we can allocate each output
+    // separately.
+    if (zero_start_index_M.has_value()) {
+      int M = A[0].size(0);
+      int N = B[0].size(0);
+      // Fill output with zeros to simplify integration. This prevents NaNs
+      // from showing up in the tensor.
+      at::Tensor Y_full =
+          at::zeros({group_count, M, N}, A[0].options().dtype(at::kBFloat16));
+      // Split the output into groups.
+      Y = at::unbind(Y_full, 0);
+    } else {
+      for (int i = 0; i < group_count; i++) {
+        int M = A[i].size(0);
+        int N = B[i].size(0);
+        Y.push_back(at::empty({M, N}, A[i].options().dtype(at::kBFloat16)));
+      }
+    }
+  }
+
+  // Prepare kernel arguments by copying them to the proper device location.
+  at::Tensor kernel_args = get_grouped_kernel_args(
+      A, B, zero_start_index_M, Y);
+
+  // Perform shape lookup to find the best kernel.
+  // We use the largest of each shape for the heuristic.
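+  // (A single kernel instance is launched for all groups, so the largest M,
+  // N, and K across groups serve as the proxy shape for dispatch.)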
+ int MaxM = 0; + int MaxN = 0; + int MaxK = 0; + for (int i = 0; i < group_count; i++) { + MaxM = max(MaxM, A[i].size(0)); + MaxN = max(MaxN, B[i].size(0)); + MaxK = max(MaxK, A[i].size(1)); + } + GroupedKernel selected_kernel = + grouped_heuristic_dispatch(MaxM, MaxN, MaxK); + return selected_kernel(A, B, kernel_args, Y); +} + +} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2.hip new file mode 100644 index 000000000..f69f46195 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2.hip @@ -0,0 +1,70 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "bf16_grouped_common.h" + +std::vector +bf16_grouped_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2( + at::TensorList A, + at::TensorList B, + at::Tensor kernel_args, + std::vector Y) { + // Check if this input needs to be padded. + bool pad = false; + for (int i = 0; i < A.size(); i++) { + int K = A[i].size(1); + if (K % 128 != 0) { + pad = true; + } + } + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< + 128, + 128, + 16, + 128, + 16, + 16, + 4, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<2, 2, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; + // Run kernel instance. + return bf16_grouped_impl( + A, B, kernel_args, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + 128, + 128, + 16, + 128, + 16, + 16, + 4, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<2, 2, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. + return bf16_grouped_impl( + A, B, kernel_args, Y); + } +} diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v2.hip new file mode 100644 index 000000000..9bc34a4ea --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v2.hip @@ -0,0 +1,70 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "bf16_grouped_common.h" + +std::vector +bf16_grouped_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v2( + at::TensorList A, + at::TensorList B, + at::Tensor kernel_args, + std::vector Y) { + // Check if this input needs to be padded. 
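+  // The KPadding specialization below is selected whenever any group's K is
+  // not a multiple of this kernel's KPerBlock tile (128 here); otherwise the
+  // Default specialization is used.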
+ bool pad = false; + for (int i = 0; i < A.size(); i++) { + int K = A[i].size(1); + if (K % 128 != 0) { + pad = true; + } + } + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< + 128, + 128, + 16, + 128, + 16, + 16, + 4, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<2, 2, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; + // Run kernel instance. + return bf16_grouped_impl( + A, B, kernel_args, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + 128, + 128, + 16, + 128, + 16, + 16, + 4, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<2, 2, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. + return bf16_grouped_impl( + A, B, kernel_args, Y); + } +} diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2.hip new file mode 100644 index 000000000..dfc9e2332 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2.hip @@ -0,0 +1,38 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "bf16_grouped_common.h" + +std::vector +bf16_grouped_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2( + at::TensorList A, + at::TensorList B, + at::Tensor kernel_args, + std::vector Y) { + // A kernel that works well on small but not super tiny shapes. + using DeviceGemmInstance = DeviceGemmHelper< + 128, + 128, + 16, + 128, + 16, + 16, + 4, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<2, 2, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2>; + // Run kernel instance. + return bf16_grouped_impl( + A, B, kernel_args, Y); +} diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_128x128x32x128_32x32_2x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_128x128x32x128_32x32_2x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip new file mode 100644 index 000000000..0a2e7a402 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_128x128x32x128_32x32_2x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip @@ -0,0 +1,70 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "bf16_grouped_common.h" + +std::vector +bf16_grouped_128x128x32x128_32x32_2x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2( + at::TensorList A, + at::TensorList B, + at::Tensor kernel_args, + std::vector Y) { + // Check if this input needs to be padded. 
+ bool pad = false; + for (int i = 0; i < A.size(); i++) { + int K = A[i].size(1); + if (K % 128 != 0) { + pad = true; + } + } + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< + 128, + 128, + 32, + 128, + 32, + 32, + 2, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; + // Run kernel instance. + return bf16_grouped_impl( + A, B, kernel_args, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + 128, + 128, + 32, + 128, + 32, + 32, + 2, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. + return bf16_grouped_impl( + A, B, kernel_args, Y); + } +} diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_128x128x32x128_32x32_2x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_128x128x32x128_32x32_2x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip new file mode 100644 index 000000000..568ccedda --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_128x128x32x128_32x32_2x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip @@ -0,0 +1,70 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "bf16_grouped_common.h" + +std::vector +bf16_grouped_128x128x32x128_32x32_2x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2( + at::TensorList A, + at::TensorList B, + at::Tensor kernel_args, + std::vector Y) { + // Check if this input needs to be padded. + bool pad = false; + for (int i = 0; i < A.size(); i++) { + int K = A[i].size(1); + if (K % 128 != 0) { + pad = true; + } + } + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< + 128, + 128, + 32, + 128, + 32, + 32, + 2, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; + // Run kernel instance. + return bf16_grouped_impl( + A, B, kernel_args, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + 128, + 128, + 32, + 128, + 32, + 32, + 2, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. 
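+    // (Same tile configuration as the interwave variant above; only the
+    // BlockGemmPipelineScheduler template argument differs.)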
+ return bf16_grouped_impl( + A, B, kernel_args, Y); + } +} diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_128x16x128x128_16x16_1x4_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_128x16x128x128_16x16_1x4_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip new file mode 100644 index 000000000..544786751 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_128x16x128x128_16x16_1x4_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip @@ -0,0 +1,70 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "bf16_grouped_common.h" + +std::vector +bf16_grouped_128x16x128x128_16x16_1x4_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2( + at::TensorList A, + at::TensorList B, + at::Tensor kernel_args, + std::vector Y) { + // Check if this input needs to be padded. + bool pad = false; + for (int i = 0; i < A.size(); i++) { + int K = A[i].size(1); + if (K % 128 != 0) { + pad = true; + } + } + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< + 128, + 16, + 128, + 128, + 16, + 16, + 1, + 4, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; + // Run kernel instance. + return bf16_grouped_impl( + A, B, kernel_args, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + 128, + 16, + 128, + 128, + 16, + 16, + 1, + 4, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. + return bf16_grouped_impl( + A, B, kernel_args, Y); + } +} diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_128x16x128x128_16x16_1x4_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_128x16x128x128_16x16_1x4_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip new file mode 100644 index 000000000..3eda5cf17 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_128x16x128x128_16x16_1x4_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip @@ -0,0 +1,70 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "bf16_grouped_common.h" + +std::vector +bf16_grouped_128x16x128x128_16x16_1x4_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2( + at::TensorList A, + at::TensorList B, + at::Tensor kernel_args, + std::vector Y) { + // Check if this input needs to be padded. 
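+  // (Note: this 128x16x128x128 configuration is not currently returned by
+  // grouped_heuristic_dispatch; presumably it is kept for tuning and
+  // benchmarking.)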
+ bool pad = false; + for (int i = 0; i < A.size(); i++) { + int K = A[i].size(1); + if (K % 128 != 0) { + pad = true; + } + } + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< + 128, + 16, + 128, + 128, + 16, + 16, + 1, + 4, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; + // Run kernel instance. + return bf16_grouped_impl( + A, B, kernel_args, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + 128, + 16, + 128, + 128, + 16, + 16, + 1, + 4, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. + return bf16_grouped_impl( + A, B, kernel_args, Y); + } +} diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v1.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v1.hip new file mode 100644 index 000000000..aaf18bb3d --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v1.hip @@ -0,0 +1,70 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "bf16_grouped_common.h" + +std::vector +bf16_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v1( + at::TensorList A, + at::TensorList B, + at::Tensor kernel_args, + std::vector Y) { + // Check if this input needs to be padded. + bool pad = false; + for (int i = 0; i < A.size(); i++) { + int K = A[i].size(1); + if (K % 128 != 0) { + pad = true; + } + } + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< + 128, + 16, + 32, + 128, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::KPadding>; + // Run kernel instance. + return bf16_grouped_impl( + A, B, kernel_args, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + 128, + 16, + 32, + 128, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. 
+ return bf16_grouped_impl( + A, B, kernel_args, Y); + } +} diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip new file mode 100644 index 000000000..9c4c226a2 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip @@ -0,0 +1,70 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "bf16_grouped_common.h" + +std::vector +bf16_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2( + at::TensorList A, + at::TensorList B, + at::Tensor kernel_args, + std::vector Y) { + // Check if this input needs to be padded. + bool pad = false; + for (int i = 0; i < A.size(); i++) { + int K = A[i].size(1); + if (K % 128 != 0) { + pad = true; + } + } + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< + 128, + 16, + 32, + 128, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; + // Run kernel instance. + return bf16_grouped_impl( + A, B, kernel_args, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + 128, + 16, + 32, + 128, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. + return bf16_grouped_impl( + A, B, kernel_args, Y); + } +} diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v1.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v1.hip new file mode 100644 index 000000000..23e4ab31f --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v1.hip @@ -0,0 +1,70 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "bf16_grouped_common.h" + +std::vector +bf16_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v1( + at::TensorList A, + at::TensorList B, + at::Tensor kernel_args, + std::vector Y) { + // Check if this input needs to be padded. 
+ bool pad = false; + for (int i = 0; i < A.size(); i++) { + int K = A[i].size(1); + if (K % 128 != 0) { + pad = true; + } + } + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< + 128, + 16, + 32, + 128, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::KPadding>; + // Run kernel instance. + return bf16_grouped_impl( + A, B, kernel_args, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + 128, + 16, + 32, + 128, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. + return bf16_grouped_impl( + A, B, kernel_args, Y); + } +} diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip new file mode 100644 index 000000000..5f053385b --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip @@ -0,0 +1,70 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "bf16_grouped_common.h" + +std::vector +bf16_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2( + at::TensorList A, + at::TensorList B, + at::Tensor kernel_args, + std::vector Y) { + // Check if this input needs to be padded. + bool pad = false; + for (int i = 0; i < A.size(); i++) { + int K = A[i].size(1); + if (K % 128 != 0) { + pad = true; + } + } + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< + 128, + 16, + 32, + 128, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; + // Run kernel instance. + return bf16_grouped_impl( + A, B, kernel_args, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + 128, + 16, + 32, + 128, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. 
+ return bf16_grouped_impl( + A, B, kernel_args, Y); + } +} diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v1.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v1.hip new file mode 100644 index 000000000..7df2b7230 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v1.hip @@ -0,0 +1,70 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "bf16_grouped_common.h" + +std::vector +bf16_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v1( + at::TensorList A, + at::TensorList B, + at::Tensor kernel_args, + std::vector Y) { + // Check if this input needs to be padded. + bool pad = false; + for (int i = 0; i < A.size(); i++) { + int K = A[i].size(1); + if (K % 256 != 0) { + pad = true; + } + } + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< + 128, + 16, + 32, + 256, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::KPadding>; + // Run kernel instance. + return bf16_grouped_impl( + A, B, kernel_args, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + 128, + 16, + 32, + 256, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. + return bf16_grouped_impl( + A, B, kernel_args, Y); + } +} diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip new file mode 100644 index 000000000..8494eecbc --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip @@ -0,0 +1,70 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "bf16_grouped_common.h" + +std::vector +bf16_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2( + at::TensorList A, + at::TensorList B, + at::Tensor kernel_args, + std::vector Y) { + // Check if this input needs to be padded. 
+ bool pad = false; + for (int i = 0; i < A.size(); i++) { + int K = A[i].size(1); + if (K % 256 != 0) { + pad = true; + } + } + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< + 128, + 16, + 32, + 256, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; + // Run kernel instance. + return bf16_grouped_impl( + A, B, kernel_args, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + 128, + 16, + 32, + 256, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. + return bf16_grouped_impl( + A, B, kernel_args, Y); + } +} diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v1.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v1.hip new file mode 100644 index 000000000..52f98714d --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v1.hip @@ -0,0 +1,70 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "bf16_grouped_common.h" + +std::vector +bf16_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v1( + at::TensorList A, + at::TensorList B, + at::Tensor kernel_args, + std::vector Y) { + // Check if this input needs to be padded. + bool pad = false; + for (int i = 0; i < A.size(); i++) { + int K = A[i].size(1); + if (K % 256 != 0) { + pad = true; + } + } + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< + 128, + 16, + 32, + 256, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::KPadding>; + // Run kernel instance. + return bf16_grouped_impl( + A, B, kernel_args, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + 128, + 16, + 32, + 256, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. 
+ return bf16_grouped_impl( + A, B, kernel_args, Y); + } +} diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip new file mode 100644 index 000000000..61e827e72 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip @@ -0,0 +1,70 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "bf16_grouped_common.h" + +std::vector +bf16_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2( + at::TensorList A, + at::TensorList B, + at::Tensor kernel_args, + std::vector Y) { + // Check if this input needs to be padded. + bool pad = false; + for (int i = 0; i < A.size(); i++) { + int K = A[i].size(1); + if (K % 256 != 0) { + pad = true; + } + } + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< + 128, + 16, + 32, + 256, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; + // Run kernel instance. + return bf16_grouped_impl( + A, B, kernel_args, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + 128, + 16, + 32, + 256, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. + return bf16_grouped_impl( + A, B, kernel_args, Y); + } +} diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v1.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v1.hip new file mode 100644 index 000000000..1938e7a17 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v1.hip @@ -0,0 +1,70 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "bf16_grouped_common.h" + +std::vector +bf16_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v1( + at::TensorList A, + at::TensorList B, + at::Tensor kernel_args, + std::vector Y) { + // Check if this input needs to be padded. 
+ bool pad = false; + for (int i = 0; i < A.size(); i++) { + int K = A[i].size(1); + if (K % 512 != 0) { + pad = true; + } + } + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< + 128, + 16, + 32, + 512, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::KPadding>; + // Run kernel instance. + return bf16_grouped_impl( + A, B, kernel_args, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + 128, + 16, + 32, + 512, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. + return bf16_grouped_impl( + A, B, kernel_args, Y); + } +} diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip new file mode 100644 index 000000000..05e130f80 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip @@ -0,0 +1,70 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "bf16_grouped_common.h" + +std::vector +bf16_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2( + at::TensorList A, + at::TensorList B, + at::Tensor kernel_args, + std::vector Y) { + // Check if this input needs to be padded. + bool pad = false; + for (int i = 0; i < A.size(); i++) { + int K = A[i].size(1); + if (K % 512 != 0) { + pad = true; + } + } + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< + 128, + 16, + 32, + 512, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; + // Run kernel instance. + return bf16_grouped_impl( + A, B, kernel_args, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + 128, + 16, + 32, + 512, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. 
+ return bf16_grouped_impl( + A, B, kernel_args, Y); + } +} diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v1.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v1.hip new file mode 100644 index 000000000..6cb887c9e --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v1.hip @@ -0,0 +1,70 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "bf16_grouped_common.h" + +std::vector +bf16_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v1( + at::TensorList A, + at::TensorList B, + at::Tensor kernel_args, + std::vector Y) { + // Check if this input needs to be padded. + bool pad = false; + for (int i = 0; i < A.size(); i++) { + int K = A[i].size(1); + if (K % 512 != 0) { + pad = true; + } + } + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< + 128, + 16, + 32, + 512, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::KPadding>; + // Run kernel instance. + return bf16_grouped_impl( + A, B, kernel_args, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + 128, + 16, + 32, + 512, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. + return bf16_grouped_impl( + A, B, kernel_args, Y); + } +} diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip new file mode 100644 index 000000000..c2c6c4a16 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip @@ -0,0 +1,70 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "bf16_grouped_common.h" + +std::vector +bf16_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2( + at::TensorList A, + at::TensorList B, + at::Tensor kernel_args, + std::vector Y) { + // Check if this input needs to be padded. 
+ bool pad = false; + for (int i = 0; i < A.size(); i++) { + int K = A[i].size(1); + if (K % 512 != 0) { + pad = true; + } + } + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< + 128, + 16, + 32, + 512, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; + // Run kernel instance. + return bf16_grouped_impl( + A, B, kernel_args, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + 128, + 16, + 32, + 512, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. + return bf16_grouped_impl( + A, B, kernel_args, Y); + } +} diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_128x16x64x128_16x16_1x2_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_128x16x64x128_16x16_1x2_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip new file mode 100644 index 000000000..1065b7c8d --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_128x16x64x128_16x16_1x2_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip @@ -0,0 +1,70 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "bf16_grouped_common.h" + +std::vector +bf16_grouped_128x16x64x128_16x16_1x2_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2( + at::TensorList A, + at::TensorList B, + at::Tensor kernel_args, + std::vector Y) { + // Check if this input needs to be padded. + bool pad = false; + for (int i = 0; i < A.size(); i++) { + int K = A[i].size(1); + if (K % 128 != 0) { + pad = true; + } + } + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< + 128, + 16, + 64, + 128, + 16, + 16, + 1, + 2, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; + // Run kernel instance. + return bf16_grouped_impl( + A, B, kernel_args, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + 128, + 16, + 64, + 128, + 16, + 16, + 1, + 2, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. 
+ return bf16_grouped_impl( + A, B, kernel_args, Y); + } +} diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_128x16x64x128_16x16_1x2_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_128x16x64x128_16x16_1x2_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip new file mode 100644 index 000000000..69f3e7755 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_128x16x64x128_16x16_1x2_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip @@ -0,0 +1,70 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "bf16_grouped_common.h" + +std::vector +bf16_grouped_128x16x64x128_16x16_1x2_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2( + at::TensorList A, + at::TensorList B, + at::Tensor kernel_args, + std::vector Y) { + // Check if this input needs to be padded. + bool pad = false; + for (int i = 0; i < A.size(); i++) { + int K = A[i].size(1); + if (K % 128 != 0) { + pad = true; + } + } + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< + 128, + 16, + 64, + 128, + 16, + 16, + 1, + 2, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; + // Run kernel instance. + return bf16_grouped_impl( + A, B, kernel_args, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + 128, + 16, + 64, + 128, + 16, + 16, + 1, + 2, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. + return bf16_grouped_impl( + A, B, kernel_args, Y); + } +} diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_128x32x128x128_32x32_1x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_128x32x128x128_32x32_1x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2.hip new file mode 100644 index 000000000..64da7384a --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_128x32x128x128_32x32_1x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2.hip @@ -0,0 +1,70 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "bf16_grouped_common.h" + +std::vector +bf16_grouped_128x32x128x128_32x32_1x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2( + at::TensorList A, + at::TensorList B, + at::Tensor kernel_args, + std::vector Y) { + // Check if this input needs to be padded. 
+ bool pad = false; + for (int i = 0; i < A.size(); i++) { + int K = A[i].size(1); + if (K % 128 != 0) { + pad = true; + } + } + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< + 128, + 32, + 128, + 128, + 32, + 32, + 1, + 2, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; + // Run kernel instance. + return bf16_grouped_impl( + A, B, kernel_args, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + 128, + 32, + 128, + 128, + 32, + 32, + 1, + 2, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. + return bf16_grouped_impl( + A, B, kernel_args, Y); + } +} diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_128x32x128x128_32x32_1x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_intrawave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_128x32x128x128_32x32_1x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_intrawave_v2.hip new file mode 100644 index 000000000..f34400ad4 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_128x32x128x128_32x32_1x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_intrawave_v2.hip @@ -0,0 +1,70 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "bf16_grouped_common.h" + +std::vector +bf16_grouped_128x32x128x128_32x32_1x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_intrawave_v2( + at::TensorList A, + at::TensorList B, + at::Tensor kernel_args, + std::vector Y) { + // Check if this input needs to be padded. + bool pad = false; + for (int i = 0; i < A.size(); i++) { + int K = A[i].size(1); + if (K % 128 != 0) { + pad = true; + } + } + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< + 128, + 32, + 128, + 128, + 32, + 32, + 1, + 2, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; + // Run kernel instance. + return bf16_grouped_impl( + A, B, kernel_args, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + 128, + 32, + 128, + 128, + 32, + 32, + 1, + 2, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. 
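+    // (This is the instance grouped_heuristic_dispatch selects for M <= 32
+    // when K > 8192.)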
+ return bf16_grouped_impl( + A, B, kernel_args, Y); + } +} diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v1.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v1.hip new file mode 100644 index 000000000..99ec733c4 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v1.hip @@ -0,0 +1,70 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "bf16_grouped_common.h" + +std::vector +bf16_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v1( + at::TensorList A, + at::TensorList B, + at::Tensor kernel_args, + std::vector Y) { + // Check if this input needs to be padded. + bool pad = false; + for (int i = 0; i < A.size(); i++) { + int K = A[i].size(1); + if (K % 128 != 0) { + pad = true; + } + } + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< + 128, + 32, + 16, + 128, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<2, 2, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::KPadding>; + // Run kernel instance. + return bf16_grouped_impl( + A, B, kernel_args, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + 128, + 32, + 16, + 128, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<2, 2, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. + return bf16_grouped_impl( + A, B, kernel_args, Y); + } +} diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2.hip new file mode 100644 index 000000000..81cc8455c --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2.hip @@ -0,0 +1,70 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "bf16_grouped_common.h" + +std::vector +bf16_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2( + at::TensorList A, + at::TensorList B, + at::Tensor kernel_args, + std::vector Y) { + // Check if this input needs to be padded. 
+ bool pad = false; + for (int i = 0; i < A.size(); i++) { + int K = A[i].size(1); + if (K % 128 != 0) { + pad = true; + } + } + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< + 128, + 32, + 16, + 128, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<2, 2, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; + // Run kernel instance. + return bf16_grouped_impl( + A, B, kernel_args, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + 128, + 32, + 16, + 128, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<2, 2, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. + return bf16_grouped_impl( + A, B, kernel_args, Y); + } +} diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v1.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v1.hip new file mode 100644 index 000000000..65225133e --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v1.hip @@ -0,0 +1,70 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "bf16_grouped_common.h" + +std::vector +bf16_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v1( + at::TensorList A, + at::TensorList B, + at::Tensor kernel_args, + std::vector Y) { + // Check if this input needs to be padded. + bool pad = false; + for (int i = 0; i < A.size(); i++) { + int K = A[i].size(1); + if (K % 128 != 0) { + pad = true; + } + } + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< + 128, + 32, + 16, + 128, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<2, 2, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::KPadding>; + // Run kernel instance. + return bf16_grouped_impl( + A, B, kernel_args, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + 128, + 32, + 16, + 128, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<2, 2, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. 
+ return bf16_grouped_impl( + A, B, kernel_args, Y); + } +} diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v2.hip new file mode 100644 index 000000000..12f1e93c6 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v2.hip @@ -0,0 +1,70 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "bf16_grouped_common.h" + +std::vector +bf16_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v2( + at::TensorList A, + at::TensorList B, + at::Tensor kernel_args, + std::vector Y) { + // Check if this input needs to be padded. + bool pad = false; + for (int i = 0; i < A.size(); i++) { + int K = A[i].size(1); + if (K % 128 != 0) { + pad = true; + } + } + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< + 128, + 32, + 16, + 128, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<2, 2, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; + // Run kernel instance. + return bf16_grouped_impl( + A, B, kernel_args, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + 128, + 32, + 16, + 128, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<2, 2, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. + return bf16_grouped_impl( + A, B, kernel_args, Y); + } +} diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2.hip new file mode 100644 index 000000000..897d0f878 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2.hip @@ -0,0 +1,70 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "bf16_grouped_common.h" + +std::vector +bf16_grouped_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2( + at::TensorList A, + at::TensorList B, + at::Tensor kernel_args, + std::vector Y) { + // Check if this input needs to be padded. 
+ bool pad = false; + for (int i = 0; i < A.size(); i++) { + int K = A[i].size(1); + if (K % 128 != 0) { + pad = true; + } + } + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< + 128, + 32, + 64, + 128, + 32, + 32, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; + // Run kernel instance. + return bf16_grouped_impl( + A, B, kernel_args, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + 128, + 32, + 64, + 128, + 32, + 32, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. + return bf16_grouped_impl( + A, B, kernel_args, Y); + } +} diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_intrawave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_intrawave_v2.hip new file mode 100644 index 000000000..a45da6f71 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_intrawave_v2.hip @@ -0,0 +1,70 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "bf16_grouped_common.h" + +std::vector +bf16_grouped_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_intrawave_v2( + at::TensorList A, + at::TensorList B, + at::Tensor kernel_args, + std::vector Y) { + // Check if this input needs to be padded. + bool pad = false; + for (int i = 0; i < A.size(); i++) { + int K = A[i].size(1); + if (K % 128 != 0) { + pad = true; + } + } + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< + 128, + 32, + 64, + 128, + 32, + 32, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; + // Run kernel instance. + return bf16_grouped_impl( + A, B, kernel_args, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + 128, + 32, + 64, + 128, + 32, + 32, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. 
+ return bf16_grouped_impl( + A, B, kernel_args, Y); + } +} diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_128x64x16x128_16x16_2x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_128x64x16x128_16x16_2x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2.hip new file mode 100644 index 000000000..7aca24055 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_128x64x16x128_16x16_2x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2.hip @@ -0,0 +1,70 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "bf16_grouped_common.h" + +std::vector +bf16_grouped_128x64x16x128_16x16_2x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2( + at::TensorList A, + at::TensorList B, + at::Tensor kernel_args, + std::vector Y) { + // Check if this input needs to be padded. + bool pad = false; + for (int i = 0; i < A.size(); i++) { + int K = A[i].size(1); + if (K % 128 != 0) { + pad = true; + } + } + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< + 128, + 64, + 16, + 128, + 16, + 16, + 2, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<2, 2, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; + // Run kernel instance. + return bf16_grouped_impl( + A, B, kernel_args, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + 128, + 64, + 16, + 128, + 16, + 16, + 2, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<2, 2, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. + return bf16_grouped_impl( + A, B, kernel_args, Y); + } +} diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_128x64x16x128_16x16_2x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_128x64x16x128_16x16_2x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v2.hip new file mode 100644 index 000000000..a7248db0a --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_128x64x16x128_16x16_2x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v2.hip @@ -0,0 +1,70 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "bf16_grouped_common.h" + +std::vector +bf16_grouped_128x64x16x128_16x16_2x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v2( + at::TensorList A, + at::TensorList B, + at::Tensor kernel_args, + std::vector Y) { + // Check if this input needs to be padded. 
+ bool pad = false; + for (int i = 0; i < A.size(); i++) { + int K = A[i].size(1); + if (K % 128 != 0) { + pad = true; + } + } + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< + 128, + 64, + 16, + 128, + 16, + 16, + 2, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<2, 2, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; + // Run kernel instance. + return bf16_grouped_impl( + A, B, kernel_args, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + 128, + 64, + 16, + 128, + 16, + 16, + 2, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<2, 2, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. + return bf16_grouped_impl( + A, B, kernel_args, Y); + } +} diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_128x64x32x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_128x64x32x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip new file mode 100644 index 000000000..ec4fc3a8a --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_128x64x32x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip @@ -0,0 +1,70 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "bf16_grouped_common.h" + +std::vector +bf16_grouped_128x64x32x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2( + at::TensorList A, + at::TensorList B, + at::Tensor kernel_args, + std::vector Y) { + // Check if this input needs to be padded. + bool pad = false; + for (int i = 0; i < A.size(); i++) { + int K = A[i].size(1); + if (K % 128 != 0) { + pad = true; + } + } + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< + 128, + 64, + 32, + 128, + 32, + 32, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; + // Run kernel instance. + return bf16_grouped_impl( + A, B, kernel_args, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + 128, + 64, + 32, + 128, + 32, + 32, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. 
+ return bf16_grouped_impl( + A, B, kernel_args, Y); + } +} diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_128x64x32x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_128x64x32x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip new file mode 100644 index 000000000..a3393f848 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_128x64x32x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip @@ -0,0 +1,70 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "bf16_grouped_common.h" + +std::vector +bf16_grouped_128x64x32x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2( + at::TensorList A, + at::TensorList B, + at::Tensor kernel_args, + std::vector Y) { + // Check if this input needs to be padded. + bool pad = false; + for (int i = 0; i < A.size(); i++) { + int K = A[i].size(1); + if (K % 128 != 0) { + pad = true; + } + } + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< + 128, + 64, + 32, + 128, + 32, + 32, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; + // Run kernel instance. + return bf16_grouped_impl( + A, B, kernel_args, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + 128, + 64, + 32, + 128, + 32, + 32, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. + return bf16_grouped_impl( + A, B, kernel_args, Y); + } +} diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_interwave_v1.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_interwave_v1.hip new file mode 100644 index 000000000..10f01577f --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_interwave_v1.hip @@ -0,0 +1,70 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "bf16_grouped_common.h" + +std::vector +bf16_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_interwave_v1( + at::TensorList A, + at::TensorList B, + at::Tensor kernel_args, + std::vector Y) { + // Check if this input needs to be padded. 
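+  // Note: the same 256x128x128x128 tile is also instantiated with the Intrawave
+  // scheduler and v3 pipeline in the next file; the two files differ only in
+  // those scheduler/pipeline template arguments.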
+ bool pad = false; + for (int i = 0; i < A.size(); i++) { + int K = A[i].size(1); + if (K % 128 != 0) { + pad = true; + } + } + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< + 256, + 128, + 128, + 128, + 32, + 32, + 2, + 2, + S<8, 32, 1>, + S<8, 32, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::KPadding>; + // Run kernel instance. + return bf16_grouped_impl( + A, B, kernel_args, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + 256, + 128, + 128, + 128, + 32, + 32, + 2, + 2, + S<8, 32, 1>, + S<8, 32, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. + return bf16_grouped_impl( + A, B, kernel_args, Y); + } +} diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip new file mode 100644 index 000000000..793942219 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip @@ -0,0 +1,70 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "bf16_grouped_common.h" + +std::vector +bf16_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3( + at::TensorList A, + at::TensorList B, + at::Tensor kernel_args, + std::vector Y) { + // Check if this input needs to be padded. + bool pad = false; + for (int i = 0; i < A.size(); i++) { + int K = A[i].size(1); + if (K % 128 != 0) { + pad = true; + } + } + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< + 256, + 128, + 128, + 128, + 32, + 32, + 2, + 2, + S<8, 32, 1>, + S<8, 32, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v3, + ck::tensor_operation::device::GemmSpecialization::KPadding>; + // Run kernel instance. + return bf16_grouped_impl( + A, B, kernel_args, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + 256, + 128, + 128, + 128, + 32, + 32, + 2, + 2, + S<8, 32, 1>, + S<8, 32, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v3, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. 
+ return bf16_grouped_impl( + A, B, kernel_args, Y); + } +} diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_256x128x128x64_32x32_2x2_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_intrawave_v4.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_256x128x128x64_32x32_2x2_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_intrawave_v4.hip new file mode 100644 index 000000000..d3e582f2d --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_256x128x128x64_32x32_2x2_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_intrawave_v4.hip @@ -0,0 +1,70 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "bf16_grouped_common.h" + +std::vector +bf16_grouped_256x128x128x64_32x32_2x2_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_intrawave_v4( + at::TensorList A, + at::TensorList B, + at::Tensor kernel_args, + std::vector Y) { + // Check if this input needs to be padded. + bool pad = false; + for (int i = 0; i < A.size(); i++) { + int K = A[i].size(1); + if (K % 64 != 0) { + pad = true; + } + } + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< + 256, + 128, + 128, + 64, + 32, + 32, + 2, + 2, + S<4, 64, 1>, + S<4, 64, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v4, + ck::tensor_operation::device::GemmSpecialization::KPadding>; + // Run kernel instance. + return bf16_grouped_impl( + A, B, kernel_args, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + 256, + 128, + 128, + 64, + 32, + 32, + 2, + 2, + S<4, 64, 1>, + S<4, 64, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v4, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. + return bf16_grouped_impl( + A, B, kernel_args, Y); + } +} diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_256x128x256x64_32x32_2x4_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_interwave_v1.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_256x128x256x64_32x32_2x4_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_interwave_v1.hip new file mode 100644 index 000000000..a433ebed9 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_256x128x256x64_32x32_2x4_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_interwave_v1.hip @@ -0,0 +1,70 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "bf16_grouped_common.h" + +std::vector +bf16_grouped_256x128x256x64_32x32_2x4_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_interwave_v1( + at::TensorList A, + at::TensorList B, + at::Tensor kernel_args, + std::vector Y) { + // Check if this input needs to be padded. 
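+  // For the K=64 tiles the alignment check below is against 64 rather than 128;
+  // each file tests K against its own K block size.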
+ bool pad = false; + for (int i = 0; i < A.size(); i++) { + int K = A[i].size(1); + if (K % 64 != 0) { + pad = true; + } + } + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< + 256, + 128, + 256, + 64, + 32, + 32, + 2, + 4, + S<4, 64, 1>, + S<4, 64, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::KPadding>; + // Run kernel instance. + return bf16_grouped_impl( + A, B, kernel_args, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + 256, + 128, + 256, + 64, + 32, + 32, + 2, + 4, + S<4, 64, 1>, + S<4, 64, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. + return bf16_grouped_impl( + A, B, kernel_args, Y); + } +} diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_256x128x64x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_256x128x64x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip new file mode 100644 index 000000000..91de4d86d --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_256x128x64x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip @@ -0,0 +1,70 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "bf16_grouped_common.h" + +std::vector +bf16_grouped_256x128x64x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3( + at::TensorList A, + at::TensorList B, + at::Tensor kernel_args, + std::vector Y) { + // Check if this input needs to be padded. + bool pad = false; + for (int i = 0; i < A.size(); i++) { + int K = A[i].size(1); + if (K % 128 != 0) { + pad = true; + } + } + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< + 256, + 128, + 64, + 128, + 32, + 32, + 2, + 1, + S<8, 32, 1>, + S<8, 32, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v3, + ck::tensor_operation::device::GemmSpecialization::KPadding>; + // Run kernel instance. + return bf16_grouped_impl( + A, B, kernel_args, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + 256, + 128, + 64, + 128, + 32, + 32, + 2, + 1, + S<8, 32, 1>, + S<8, 32, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v3, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. 
+ return bf16_grouped_impl( + A, B, kernel_args, Y); + } +} diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_256x256x128x64_32x32_4x2_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_interwave_v1.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_256x256x128x64_32x32_4x2_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_interwave_v1.hip new file mode 100644 index 000000000..4c9bf610c --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_256x256x128x64_32x32_4x2_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_interwave_v1.hip @@ -0,0 +1,70 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "bf16_grouped_common.h" + +std::vector +bf16_grouped_256x256x128x64_32x32_4x2_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_interwave_v1( + at::TensorList A, + at::TensorList B, + at::Tensor kernel_args, + std::vector Y) { + // Check if this input needs to be padded. + bool pad = false; + for (int i = 0; i < A.size(); i++) { + int K = A[i].size(1); + if (K % 64 != 0) { + pad = true; + } + } + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< + 256, + 256, + 128, + 64, + 32, + 32, + 4, + 2, + S<4, 64, 1>, + S<4, 64, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::KPadding>; + // Run kernel instance. + return bf16_grouped_impl( + A, B, kernel_args, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + 256, + 256, + 128, + 64, + 32, + 32, + 4, + 2, + S<4, 64, 1>, + S<4, 64, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. + return bf16_grouped_impl( + A, B, kernel_args, Y); + } +} diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_256x256x256x64_16x16_8x8_4x64x1_4x64x1_1x32x1x8_8x8x1_1x2_intrawave_v3.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_256x256x256x64_16x16_8x8_4x64x1_4x64x1_1x32x1x8_8x8x1_1x2_intrawave_v3.hip new file mode 100644 index 000000000..b75e04cf5 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_256x256x256x64_16x16_8x8_4x64x1_4x64x1_1x32x1x8_8x8x1_1x2_intrawave_v3.hip @@ -0,0 +1,70 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "bf16_grouped_common.h" + +std::vector +bf16_grouped_256x256x256x64_16x16_8x8_4x64x1_4x64x1_1x32x1x8_8x8x1_1x2_intrawave_v3( + at::TensorList A, + at::TensorList B, + at::Tensor kernel_args, + std::vector Y) { + // Check if this input needs to be padded. 
+ bool pad = false; + for (int i = 0; i < A.size(); i++) { + int K = A[i].size(1); + if (K % 64 != 0) { + pad = true; + } + } + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< + 256, + 256, + 256, + 64, + 16, + 16, + 8, + 8, + S<4, 64, 1>, + S<4, 64, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 2, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v3, + ck::tensor_operation::device::GemmSpecialization::KPadding>; + // Run kernel instance. + return bf16_grouped_impl( + A, B, kernel_args, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + 256, + 256, + 256, + 64, + 16, + 16, + 8, + 8, + S<4, 64, 1>, + S<4, 64, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 2, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v3, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. + return bf16_grouped_impl( + A, B, kernel_args, Y); + } +} diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_256x64x64x128_32x32_1x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_256x64x64x128_32x32_1x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip new file mode 100644 index 000000000..2a4c54cdb --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_256x64x64x128_32x32_1x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip @@ -0,0 +1,70 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "bf16_grouped_common.h" + +std::vector +bf16_grouped_256x64x64x128_32x32_1x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3( + at::TensorList A, + at::TensorList B, + at::Tensor kernel_args, + std::vector Y) { + // Check if this input needs to be padded. + bool pad = false; + for (int i = 0; i < A.size(); i++) { + int K = A[i].size(1); + if (K % 128 != 0) { + pad = true; + } + } + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< + 256, + 64, + 64, + 128, + 32, + 32, + 1, + 1, + S<8, 32, 1>, + S<8, 32, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v3, + ck::tensor_operation::device::GemmSpecialization::KPadding>; + // Run kernel instance. + return bf16_grouped_impl( + A, B, kernel_args, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + 256, + 64, + 64, + 128, + 32, + 32, + 1, + 1, + S<8, 32, 1>, + S<8, 32, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v3, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. 
+ return bf16_grouped_impl( + A, B, kernel_args, Y); + } +} diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v1.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v1.hip new file mode 100644 index 000000000..dbbb3ab47 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v1.hip @@ -0,0 +1,70 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "bf16_grouped_common.h" + +std::vector +bf16_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v1( + at::TensorList A, + at::TensorList B, + at::Tensor kernel_args, + std::vector Y) { + // Check if this input needs to be padded. + bool pad = false; + for (int i = 0; i < A.size(); i++) { + int K = A[i].size(1); + if (K % 128 != 0) { + pad = true; + } + } + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< + 64, + 16, + 16, + 128, + 16, + 16, + 1, + 1, + S<8, 8, 1>, + S<8, 8, 1>, + S<1, 16, 1, 4>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::KPadding>; + // Run kernel instance. + return bf16_grouped_impl( + A, B, kernel_args, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + 64, + 16, + 16, + 128, + 16, + 16, + 1, + 1, + S<8, 8, 1>, + S<8, 8, 1>, + S<1, 16, 1, 4>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. + return bf16_grouped_impl( + A, B, kernel_args, Y); + } +} diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2.hip new file mode 100644 index 000000000..4c040b744 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2.hip @@ -0,0 +1,70 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "bf16_grouped_common.h" + +std::vector +bf16_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2( + at::TensorList A, + at::TensorList B, + at::Tensor kernel_args, + std::vector Y) { + // Check if this input needs to be padded. 
+ bool pad = false; + for (int i = 0; i < A.size(); i++) { + int K = A[i].size(1); + if (K % 128 != 0) { + pad = true; + } + } + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< + 64, + 16, + 16, + 128, + 16, + 16, + 1, + 1, + S<8, 8, 1>, + S<8, 8, 1>, + S<1, 16, 1, 4>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; + // Run kernel instance. + return bf16_grouped_impl( + A, B, kernel_args, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + 64, + 16, + 16, + 128, + 16, + 16, + 1, + 1, + S<8, 8, 1>, + S<8, 8, 1>, + S<1, 16, 1, 4>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. + return bf16_grouped_impl( + A, B, kernel_args, Y); + } +} diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v1.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v1.hip new file mode 100644 index 000000000..5e14be33f --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v1.hip @@ -0,0 +1,70 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "bf16_grouped_common.h" + +std::vector +bf16_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v1( + at::TensorList A, + at::TensorList B, + at::Tensor kernel_args, + std::vector Y) { + // Check if this input needs to be padded. + bool pad = false; + for (int i = 0; i < A.size(); i++) { + int K = A[i].size(1); + if (K % 128 != 0) { + pad = true; + } + } + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< + 64, + 16, + 16, + 128, + 16, + 16, + 1, + 1, + S<8, 8, 1>, + S<8, 8, 1>, + S<1, 16, 1, 4>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::KPadding>; + // Run kernel instance. + return bf16_grouped_impl( + A, B, kernel_args, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + 64, + 16, + 16, + 128, + 16, + 16, + 1, + 1, + S<8, 8, 1>, + S<8, 8, 1>, + S<1, 16, 1, 4>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. 
+ return bf16_grouped_impl( + A, B, kernel_args, Y); + } +} diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v2.hip new file mode 100644 index 000000000..b1911b48b --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v2.hip @@ -0,0 +1,70 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "bf16_grouped_common.h" + +std::vector +bf16_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v2( + at::TensorList A, + at::TensorList B, + at::Tensor kernel_args, + std::vector Y) { + // Check if this input needs to be padded. + bool pad = false; + for (int i = 0; i < A.size(); i++) { + int K = A[i].size(1); + if (K % 128 != 0) { + pad = true; + } + } + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< + 64, + 16, + 16, + 128, + 16, + 16, + 1, + 1, + S<8, 8, 1>, + S<8, 8, 1>, + S<1, 16, 1, 4>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; + // Run kernel instance. + return bf16_grouped_impl( + A, B, kernel_args, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + 64, + 16, + 16, + 128, + 16, + 16, + 1, + 1, + S<8, 8, 1>, + S<8, 8, 1>, + S<1, 16, 1, 4>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. + return bf16_grouped_impl( + A, B, kernel_args, Y); + } +} diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_64x16x16x256_16x16_1x1_16x4x1_16x4x1_1x4x1x16_4x4x1_1x1_intrawave_v1.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_64x16x16x256_16x16_1x1_16x4x1_16x4x1_1x4x1x16_4x4x1_1x1_intrawave_v1.hip new file mode 100644 index 000000000..64a71fda9 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_64x16x16x256_16x16_1x1_16x4x1_16x4x1_1x4x1x16_4x4x1_1x1_intrawave_v1.hip @@ -0,0 +1,39 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "bf16_grouped_common.h" + +std::vector +bf16_grouped_64x16x16x256_16x16_1x1_16x4x1_16x4x1_1x4x1x16_4x4x1_1x1_intrawave_v1( + at::TensorList A, + at::TensorList B, + at::Tensor kernel_args, + std::vector Y) { + // Secret kernel that seems good with small M but large N and K. + using DeviceGemmInstance = DeviceGemmHelper< + 64, + 16, + 16, + 256, + 16, + 16, + 1, + 1, + S<16, 4, 1>, + S<16, 4, 1>, + S<1, 16, 1, 4>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. 
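+  // Unlike most files in this directory, this instance has no K-alignment check
+  // and no KPadding fallback; it is always built with the Default
+  // GemmSpecialization.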
+ return bf16_grouped_impl( + A, B, kernel_args, Y); +} diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_64x16x16x256_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v1.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_64x16x16x256_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v1.hip new file mode 100644 index 000000000..038c62f97 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_64x16x16x256_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v1.hip @@ -0,0 +1,70 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "bf16_grouped_common.h" + +std::vector +bf16_grouped_64x16x16x256_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v1( + at::TensorList A, + at::TensorList B, + at::Tensor kernel_args, + std::vector Y) { + // Check if this input needs to be padded. + bool pad = false; + for (int i = 0; i < A.size(); i++) { + int K = A[i].size(1); + if (K % 256 != 0) { + pad = true; + } + } + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< + 64, + 16, + 16, + 256, + 16, + 16, + 1, + 1, + S<8, 8, 1>, + S<8, 8, 1>, + S<1, 16, 1, 4>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::KPadding>; + // Run kernel instance. + return bf16_grouped_impl( + A, B, kernel_args, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + 64, + 16, + 16, + 256, + 16, + 16, + 1, + 1, + S<8, 8, 1>, + S<8, 8, 1>, + S<1, 16, 1, 4>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. + return bf16_grouped_impl( + A, B, kernel_args, Y); + } +} diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_64x16x16x256_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_64x16x16x256_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2.hip new file mode 100644 index 000000000..d1e7a69b1 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_64x16x16x256_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2.hip @@ -0,0 +1,70 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "bf16_grouped_common.h" + +std::vector +bf16_grouped_64x16x16x256_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2( + at::TensorList A, + at::TensorList B, + at::Tensor kernel_args, + std::vector Y) { + // Check if this input needs to be padded. 
+ bool pad = false; + for (int i = 0; i < A.size(); i++) { + int K = A[i].size(1); + if (K % 256 != 0) { + pad = true; + } + } + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< + 64, + 16, + 16, + 256, + 16, + 16, + 1, + 1, + S<8, 8, 1>, + S<8, 8, 1>, + S<1, 16, 1, 4>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; + // Run kernel instance. + return bf16_grouped_impl( + A, B, kernel_args, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + 64, + 16, + 16, + 256, + 16, + 16, + 1, + 1, + S<8, 8, 1>, + S<8, 8, 1>, + S<1, 16, 1, 4>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. + return bf16_grouped_impl( + A, B, kernel_args, Y); + } +} diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_64x16x16x256_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v1.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_64x16x16x256_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v1.hip new file mode 100644 index 000000000..01dde8baf --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_64x16x16x256_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v1.hip @@ -0,0 +1,70 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "bf16_grouped_common.h" + +std::vector +bf16_grouped_64x16x16x256_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v1( + at::TensorList A, + at::TensorList B, + at::Tensor kernel_args, + std::vector Y) { + // Check if this input needs to be padded. + bool pad = false; + for (int i = 0; i < A.size(); i++) { + int K = A[i].size(1); + if (K % 256 != 0) { + pad = true; + } + } + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< + 64, + 16, + 16, + 256, + 16, + 16, + 1, + 1, + S<8, 8, 1>, + S<8, 8, 1>, + S<1, 16, 1, 4>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::KPadding>; + // Run kernel instance. + return bf16_grouped_impl( + A, B, kernel_args, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + 64, + 16, + 16, + 256, + 16, + 16, + 1, + 1, + S<8, 8, 1>, + S<8, 8, 1>, + S<1, 16, 1, 4>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. 
+ return bf16_grouped_impl( + A, B, kernel_args, Y); + } +} diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_64x16x16x256_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_64x16x16x256_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v2.hip new file mode 100644 index 000000000..4c7198eca --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_64x16x16x256_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v2.hip @@ -0,0 +1,70 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "bf16_grouped_common.h" + +std::vector +bf16_grouped_64x16x16x256_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v2( + at::TensorList A, + at::TensorList B, + at::Tensor kernel_args, + std::vector Y) { + // Check if this input needs to be padded. + bool pad = false; + for (int i = 0; i < A.size(); i++) { + int K = A[i].size(1); + if (K % 256 != 0) { + pad = true; + } + } + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< + 64, + 16, + 16, + 256, + 16, + 16, + 1, + 1, + S<8, 8, 1>, + S<8, 8, 1>, + S<1, 16, 1, 4>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; + // Run kernel instance. + return bf16_grouped_impl( + A, B, kernel_args, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + 64, + 16, + 16, + 256, + 16, + 16, + 1, + 1, + S<8, 8, 1>, + S<8, 8, 1>, + S<1, 16, 1, 4>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. + return bf16_grouped_impl( + A, B, kernel_args, Y); + } +} diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_64x16x16x512_16x16_1x1_32x2x1_32x2x1_1x16x1x4_4x4x1_1x1_interwave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_64x16x16x512_16x16_1x1_32x2x1_32x2x1_1x16x1x4_4x4x1_1x1_interwave_v2.hip new file mode 100644 index 000000000..900392a82 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_64x16x16x512_16x16_1x1_32x2x1_32x2x1_1x16x1x4_4x4x1_1x1_interwave_v2.hip @@ -0,0 +1,39 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "bf16_grouped_common.h" + +std::vector +bf16_grouped_64x16x16x512_16x16_1x1_32x2x1_32x2x1_1x16x1x4_4x4x1_1x1_interwave_v2( + at::TensorList A, + at::TensorList B, + at::Tensor kernel_args, + std::vector Y) { + // The smallest kernel we have available. Works well for memory bound shapes. + using DeviceGemmInstance = DeviceGemmHelper< + 64, + 16, + 16, + 512, + 16, + 16, + 1, + 1, + S<32, 2, 1>, + S<32, 2, 1>, + S<1, 16, 1, 4>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. 
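+  // As with the 64x16x16x256 instance above, this kernel is emitted only with
+  // the Default specialization and performs no padding check.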
+ return bf16_grouped_impl( + A, B, kernel_args, Y); +} diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v1.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v1.hip new file mode 100644 index 000000000..8e537a96d --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v1.hip @@ -0,0 +1,70 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "bf16_grouped_common.h" + +std::vector +bf16_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v1( + at::TensorList A, + at::TensorList B, + at::Tensor kernel_args, + std::vector Y) { + // Check if this input needs to be padded. + bool pad = false; + for (int i = 0; i < A.size(); i++) { + int K = A[i].size(1); + if (K % 512 != 0) { + pad = true; + } + } + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< + 64, + 16, + 16, + 512, + 16, + 16, + 1, + 1, + S<8, 8, 1>, + S<8, 8, 1>, + S<1, 16, 1, 4>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::KPadding>; + // Run kernel instance. + return bf16_grouped_impl( + A, B, kernel_args, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + 64, + 16, + 16, + 512, + 16, + 16, + 1, + 1, + S<8, 8, 1>, + S<8, 8, 1>, + S<1, 16, 1, 4>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. + return bf16_grouped_impl( + A, B, kernel_args, Y); + } +} diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2.hip new file mode 100644 index 000000000..ceaa1bd00 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2.hip @@ -0,0 +1,70 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "bf16_grouped_common.h" + +std::vector +bf16_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2( + at::TensorList A, + at::TensorList B, + at::Tensor kernel_args, + std::vector Y) { + // Check if this input needs to be padded. 
+ bool pad = false; + for (int i = 0; i < A.size(); i++) { + int K = A[i].size(1); + if (K % 512 != 0) { + pad = true; + } + } + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< + 64, + 16, + 16, + 512, + 16, + 16, + 1, + 1, + S<8, 8, 1>, + S<8, 8, 1>, + S<1, 16, 1, 4>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; + // Run kernel instance. + return bf16_grouped_impl( + A, B, kernel_args, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + 64, + 16, + 16, + 512, + 16, + 16, + 1, + 1, + S<8, 8, 1>, + S<8, 8, 1>, + S<1, 16, 1, 4>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. + return bf16_grouped_impl( + A, B, kernel_args, Y); + } +} diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v1.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v1.hip new file mode 100644 index 000000000..f29b540e7 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v1.hip @@ -0,0 +1,70 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "bf16_grouped_common.h" + +std::vector +bf16_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v1( + at::TensorList A, + at::TensorList B, + at::Tensor kernel_args, + std::vector Y) { + // Check if this input needs to be padded. + bool pad = false; + for (int i = 0; i < A.size(); i++) { + int K = A[i].size(1); + if (K % 512 != 0) { + pad = true; + } + } + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< + 64, + 16, + 16, + 512, + 16, + 16, + 1, + 1, + S<8, 8, 1>, + S<8, 8, 1>, + S<1, 16, 1, 4>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::KPadding>; + // Run kernel instance. + return bf16_grouped_impl( + A, B, kernel_args, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + 64, + 16, + 16, + 512, + 16, + 16, + 1, + 1, + S<8, 8, 1>, + S<8, 8, 1>, + S<1, 16, 1, 4>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. 
+ return bf16_grouped_impl( + A, B, kernel_args, Y); + } +} diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v2.hip new file mode 100644 index 000000000..8201a3a11 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v2.hip @@ -0,0 +1,70 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "bf16_grouped_common.h" + +std::vector +bf16_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v2( + at::TensorList A, + at::TensorList B, + at::Tensor kernel_args, + std::vector Y) { + // Check if this input needs to be padded. + bool pad = false; + for (int i = 0; i < A.size(); i++) { + int K = A[i].size(1); + if (K % 512 != 0) { + pad = true; + } + } + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< + 64, + 16, + 16, + 512, + 16, + 16, + 1, + 1, + S<8, 8, 1>, + S<8, 8, 1>, + S<1, 16, 1, 4>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; + // Run kernel instance. + return bf16_grouped_impl( + A, B, kernel_args, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + 64, + 16, + 16, + 512, + 16, + 16, + 1, + 1, + S<8, 8, 1>, + S<8, 8, 1>, + S<1, 16, 1, 4>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. + return bf16_grouped_impl( + A, B, kernel_args, Y); + } +} diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_64x16x16x64_16x16_1x1_4x16x1_4x16x1_1x16x1x4_4x4x1_1x1_interwave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_64x16x16x64_16x16_1x1_4x16x1_4x16x1_1x16x1x4_4x4x1_1x1_interwave_v2.hip new file mode 100644 index 000000000..088117342 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_64x16x16x64_16x16_1x1_4x16x1_4x16x1_1x16x1x4_4x4x1_1x1_interwave_v2.hip @@ -0,0 +1,70 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "bf16_grouped_common.h" + +std::vector +bf16_grouped_64x16x16x64_16x16_1x1_4x16x1_4x16x1_1x16x1x4_4x4x1_1x1_interwave_v2( + at::TensorList A, + at::TensorList B, + at::Tensor kernel_args, + std::vector Y) { + // Check if this input needs to be padded. 
+ bool pad = false; + for (int i = 0; i < A.size(); i++) { + int K = A[i].size(1); + if (K % 64 != 0) { + pad = true; + } + } + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< + 64, + 16, + 16, + 64, + 16, + 16, + 1, + 1, + S<4, 16, 1>, + S<4, 16, 1>, + S<1, 16, 1, 4>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; + // Run kernel instance. + return bf16_grouped_impl( + A, B, kernel_args, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + 64, + 16, + 16, + 64, + 16, + 16, + 1, + 1, + S<4, 16, 1>, + S<4, 16, 1>, + S<1, 16, 1, 4>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. + return bf16_grouped_impl( + A, B, kernel_args, Y); + } +} diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_64x16x16x64_16x16_1x1_4x16x1_4x16x1_1x16x1x4_4x4x1_1x1_intrawave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_64x16x16x64_16x16_1x1_4x16x1_4x16x1_1x16x1x4_4x4x1_1x1_intrawave_v2.hip new file mode 100644 index 000000000..e050d0461 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_64x16x16x64_16x16_1x1_4x16x1_4x16x1_1x16x1x4_4x4x1_1x1_intrawave_v2.hip @@ -0,0 +1,70 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "bf16_grouped_common.h" + +std::vector +bf16_grouped_64x16x16x64_16x16_1x1_4x16x1_4x16x1_1x16x1x4_4x4x1_1x1_intrawave_v2( + at::TensorList A, + at::TensorList B, + at::Tensor kernel_args, + std::vector Y) { + // Check if this input needs to be padded. + bool pad = false; + for (int i = 0; i < A.size(); i++) { + int K = A[i].size(1); + if (K % 64 != 0) { + pad = true; + } + } + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< + 64, + 16, + 16, + 64, + 16, + 16, + 1, + 1, + S<4, 16, 1>, + S<4, 16, 1>, + S<1, 16, 1, 4>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; + // Run kernel instance. + return bf16_grouped_impl( + A, B, kernel_args, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + 64, + 16, + 16, + 64, + 16, + 16, + 1, + 1, + S<4, 16, 1>, + S<4, 16, 1>, + S<1, 16, 1, 4>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. + return bf16_grouped_impl( + A, B, kernel_args, Y); + } +} diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_common.h b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_common.h new file mode 100644 index 000000000..05aa80b6c --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_common.h @@ -0,0 +1,182 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. 
+ * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#ifdef USE_ROCM +#include +#else +#include +#endif +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/utility/blkgemmpipe_scheduler.hpp" +#include "ck/utility/data_type.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp" + +// Define commonly used types. +template +using S = ck::Sequence; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = ck::bhalf_t; +using BDataType = ck::bhalf_t; +using DsDataType = ck::Tuple<>; +using CDataType = ck::bhalf_t; +using AccDataType = float; +using CShuffleDataType = float; + +using ALayout = Row; +using BLayout = Col; +using DsLayout = ck::Tuple<>; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = PassThrough; + +using ComputeType = ck::bhalf_t; + +template < + int BLOCK_SIZE, + int MBLOCK, + int NBLOCK, + int KBLOCK, + int WAVE_TILE_M, + int WAVE_TILE_N, + int WAVE_MAP_M, + int WAVE_MAP_N, + typename ABLOCK_TRANSFER, + typename BBLOCK_TRANSFER, + typename CBLOCK_TRANSFER, + typename CBLOCK_SPV, + int CSHUFFLE_MX_PER_WAVE_PERSHUFFLE, + int CSHUFFLE_NX_PER_WAVE_PERSHUFFLE, + ck::BlockGemmPipelineScheduler LOOP_SCHED, + ck::BlockGemmPipelineVersion PIPELINE_VERSION, + ck::tensor_operation::device::GemmSpecialization GEMM_SPEC = + ck::tensor_operation::device::GemmSpecialization::MNPadding> +using DeviceGemmHelper = + ck::tensor_operation::device::DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< + ALayout, + BLayout, + DsLayout, + CLayout, + ADataType, + BDataType, + AccDataType, + CShuffleDataType, + DsDataType, + CDataType, + AElementOp, + BElementOp, + CDEElementOp, + GEMM_SPEC, + 1, // NumGemmK + BLOCK_SIZE, // Block Size + MBLOCK, // M per Block + NBLOCK, // N per Block + KBLOCK, // K per Block + 8, // AK1 + 8, // BK1 + WAVE_TILE_M, // M per Xdl + WAVE_TILE_N, // N per Xdl + WAVE_MAP_M, // Mxdl per Wave + WAVE_MAP_N, // Nxdl per Wave + ABLOCK_TRANSFER, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + 0, + BBLOCK_TRANSFER, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + 0, + CSHUFFLE_MX_PER_WAVE_PERSHUFFLE, + CSHUFFLE_NX_PER_WAVE_PERSHUFFLE, + CBLOCK_TRANSFER, + CBLOCK_SPV, + LOOP_SCHED, + PIPELINE_VERSION, + ComputeType>; + +template +std::vector bf16_grouped_impl( + at::TensorList A, + at::TensorList B, + at::Tensor kernel_args, + std::vector Y) { + // Get input information. + int group_count = A.size(); + using KernelArguments = + ck::tensor_operation::device::GroupedGemmTileLoopKernelArguments<0>; + using GemmDesc = ck::tensor_operation::device::GemmDesc; + // Create gemm shape containers. + std::vector gemm_descs; + // Create container for input arguments. + std::vector A_args; + std::vector B_args; + std::vector C_args; + std::vector> D_args = {}; + // Reserve space in argument arrays. + gemm_descs.reserve(group_count); + A_args.reserve(group_count); + B_args.reserve(group_count); + C_args.reserve(group_count); + // Populate arguments. + for (int i = 0; i < group_count; i++) { + // Set the shape arguments for this gemm. 
+ int M = A[i].size(0); + int K = A[i].size(1); + int N = B[i].size(0); + GemmDesc gemm_desc = {M, N, K, K, K, N, {}}; + gemm_descs.push_back(gemm_desc); + // Set pointers to inputs and outputs. + A_args.push_back(reinterpret_cast(A[i].data_ptr())); + B_args.push_back(reinterpret_cast(B[i].data_ptr())); + C_args.push_back(reinterpret_cast(Y[i].data_ptr())); + } + + // Create gemm launcher and arguments. + auto gemm = DeviceGemmInstance{}; + auto invoker = gemm.MakeInvoker(); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + // Setup Gemm arguments. + auto argument = gemm.MakeArgument( + A_args, + B_args, + D_args, + C_args, + gemm_descs, + a_element_op, + b_element_op, + cde_element_op); + + // Set gemm kernel arguments. + gemm.SetDeviceKernelArgs(argument, kernel_args.data_ptr()); + + // Get hip graph stream if it exists. + auto stream = at::cuda::getCurrentHIPStream().stream(); + invoker.Run(argument, StreamConfig{stream, false}); + + return Y; +} diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_kernel_manifest.h b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_kernel_manifest.h new file mode 100644 index 000000000..8fe55110e --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/kernels/bf16_grouped_kernel_manifest.h @@ -0,0 +1,548 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +#include + +#define KERNEL_NAME_MAP_ENTRY(name) \ + { #name, name } + +using GroupedKernel = std::function( + at::TensorList, + at::TensorList, + at::Tensor, + std::vector)>; + +std::vector +bf16_grouped_64x16x16x256_16x16_1x1_16x4x1_16x4x1_1x4x1x16_4x4x1_1x1_intrawave_v1( + at::TensorList XQ, + at::TensorList WQ, + at::Tensor kernel_args, + std::vector Y); + +std::vector +bf16_grouped_64x16x16x512_16x16_1x1_32x2x1_32x2x1_1x16x1x4_4x4x1_1x1_interwave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::Tensor kernel_args, + std::vector Y); + +std::vector +bf16_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::Tensor kernel_args, + std::vector Y); + +std::vector +bf16_grouped_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::Tensor kernel_args, + std::vector Y); + +std::vector +bf16_grouped_256x128x128x64_32x32_2x2_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_intrawave_v4( + at::TensorList XQ, + at::TensorList WQ, + at::Tensor kernel_args, + std::vector Y); + +std::vector +bf16_grouped_256x256x256x64_16x16_8x8_4x64x1_4x64x1_1x32x1x8_8x8x1_1x2_intrawave_v3( + at::TensorList XQ, + at::TensorList WQ, + at::Tensor kernel_args, + std::vector Y); + +std::vector +bf16_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3( + at::TensorList XQ, + at::TensorList WQ, + at::Tensor kernel_args, + std::vector Y); + +std::vector +bf16_grouped_256x128x256x64_32x32_2x4_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_interwave_v1( + at::TensorList XQ, + at::TensorList WQ, + at::Tensor kernel_args, + std::vector Y); + +std::vector +bf16_grouped_256x256x128x64_32x32_4x2_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_interwave_v1( + at::TensorList XQ, + at::TensorList WQ, + at::Tensor kernel_args, + 
std::vector Y); + +std::vector +bf16_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_interwave_v1( + at::TensorList XQ, + at::TensorList WQ, + at::Tensor kernel_args, + std::vector Y); + +std::vector +bf16_grouped_256x128x64x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3( + at::TensorList XQ, + at::TensorList WQ, + at::Tensor kernel_args, + std::vector Y); + +std::vector +bf16_grouped_256x64x64x128_32x32_1x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3( + at::TensorList XQ, + at::TensorList WQ, + at::Tensor kernel_args, + std::vector Y); + +std::vector +bf16_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v1( + at::TensorList XQ, + at::TensorList WQ, + at::Tensor kernel_args, + std::vector Y); + +std::vector +bf16_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v1( + at::TensorList XQ, + at::TensorList WQ, + at::Tensor kernel_args, + std::vector Y); + +std::vector +bf16_grouped_64x16x16x256_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v1( + at::TensorList XQ, + at::TensorList WQ, + at::Tensor kernel_args, + std::vector Y); + +std::vector +bf16_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v1( + at::TensorList XQ, + at::TensorList WQ, + at::Tensor kernel_args, + std::vector Y); + +std::vector +bf16_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v1( + at::TensorList XQ, + at::TensorList WQ, + at::Tensor kernel_args, + std::vector Y); + +std::vector +bf16_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v1( + at::TensorList XQ, + at::TensorList WQ, + at::Tensor kernel_args, + std::vector Y); + +std::vector +bf16_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v1( + at::TensorList XQ, + at::TensorList WQ, + at::Tensor kernel_args, + std::vector Y); + +std::vector +bf16_grouped_128x128x32x128_32x32_2x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::Tensor kernel_args, + std::vector Y); + +std::vector +bf16_grouped_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::Tensor kernel_args, + std::vector Y); + +std::vector +bf16_grouped_128x64x32x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::Tensor kernel_args, + std::vector Y); + +std::vector +bf16_grouped_128x64x16x128_16x16_2x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::Tensor kernel_args, + std::vector Y); + +std::vector +bf16_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::Tensor kernel_args, + std::vector Y); + +std::vector +bf16_grouped_64x16x16x64_16x16_1x1_4x16x1_4x16x1_1x16x1x4_4x4x1_1x1_intrawave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::Tensor kernel_args, + std::vector Y); + +std::vector +bf16_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::Tensor kernel_args, + std::vector Y); + +std::vector +bf16_grouped_64x16x16x256_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::Tensor kernel_args, + std::vector Y); + +std::vector +bf16_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::Tensor kernel_args, + std::vector Y); + +std::vector 
+bf16_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::Tensor kernel_args, + std::vector Y); + +std::vector +bf16_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::Tensor kernel_args, + std::vector Y); + +std::vector +bf16_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::Tensor kernel_args, + std::vector Y); + +std::vector +bf16_grouped_128x16x64x128_16x16_1x2_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::Tensor kernel_args, + std::vector Y); + +std::vector +bf16_grouped_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_intrawave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::Tensor kernel_args, + std::vector Y); + +std::vector +bf16_grouped_128x16x128x128_16x16_1x4_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::Tensor kernel_args, + std::vector Y); + +std::vector +bf16_grouped_128x32x128x128_32x32_1x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_intrawave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::Tensor kernel_args, + std::vector Y); + +std::vector +bf16_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v1( + at::TensorList XQ, + at::TensorList WQ, + at::Tensor kernel_args, + std::vector Y); + +std::vector +bf16_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v1( + at::TensorList XQ, + at::TensorList WQ, + at::Tensor kernel_args, + std::vector Y); + +std::vector +bf16_grouped_64x16x16x256_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v1( + at::TensorList XQ, + at::TensorList WQ, + at::Tensor kernel_args, + std::vector Y); + +std::vector +bf16_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v1( + at::TensorList XQ, + at::TensorList WQ, + at::Tensor kernel_args, + std::vector Y); + +std::vector +bf16_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v1( + at::TensorList XQ, + at::TensorList WQ, + at::Tensor kernel_args, + std::vector Y); + +std::vector +bf16_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v1( + at::TensorList XQ, + at::TensorList WQ, + at::Tensor kernel_args, + std::vector Y); + +std::vector +bf16_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v1( + at::TensorList XQ, + at::TensorList WQ, + at::Tensor kernel_args, + std::vector Y); + +std::vector +bf16_grouped_128x128x32x128_32x32_2x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::Tensor kernel_args, + std::vector Y); + +std::vector +bf16_grouped_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::Tensor kernel_args, + std::vector Y); + +std::vector +bf16_grouped_128x64x32x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::Tensor kernel_args, + std::vector Y); + +std::vector +bf16_grouped_128x64x16x128_16x16_2x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::Tensor kernel_args, + std::vector Y); + +std::vector +bf16_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::Tensor kernel_args, + std::vector Y); + +std::vector 
+bf16_grouped_64x16x16x64_16x16_1x1_4x16x1_4x16x1_1x16x1x4_4x4x1_1x1_interwave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::Tensor kernel_args, + std::vector Y); + +std::vector +bf16_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::Tensor kernel_args, + std::vector Y); + +std::vector +bf16_grouped_64x16x16x256_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::Tensor kernel_args, + std::vector Y); + +std::vector +bf16_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::Tensor kernel_args, + std::vector Y); + +std::vector +bf16_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::Tensor kernel_args, + std::vector Y); + +std::vector +bf16_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::Tensor kernel_args, + std::vector Y); + +std::vector +bf16_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::Tensor kernel_args, + std::vector Y); + +std::vector +bf16_grouped_128x16x64x128_16x16_1x2_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::Tensor kernel_args, + std::vector Y); + +std::vector +bf16_grouped_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::Tensor kernel_args, + std::vector Y); + +std::vector +bf16_grouped_128x16x128x128_16x16_1x4_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::Tensor kernel_args, + std::vector Y); + +std::vector +bf16_grouped_128x32x128x128_32x32_1x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::Tensor kernel_args, + std::vector Y); + +// Map function for string name to kernel implementation for manual +// specification. 
+static const std::unordered_map kernel_name_map = { + KERNEL_NAME_MAP_ENTRY( + bf16_grouped_64x16x16x256_16x16_1x1_16x4x1_16x4x1_1x4x1x16_4x4x1_1x1_intrawave_v1), + KERNEL_NAME_MAP_ENTRY( + bf16_grouped_64x16x16x512_16x16_1x1_32x2x1_32x2x1_1x16x1x4_4x4x1_1x1_interwave_v2), + KERNEL_NAME_MAP_ENTRY( + bf16_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2), + KERNEL_NAME_MAP_ENTRY( + bf16_grouped_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2), + KERNEL_NAME_MAP_ENTRY( + bf16_grouped_256x128x128x64_32x32_2x2_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_intrawave_v4), + KERNEL_NAME_MAP_ENTRY( + bf16_grouped_256x256x256x64_16x16_8x8_4x64x1_4x64x1_1x32x1x8_8x8x1_1x2_intrawave_v3), + KERNEL_NAME_MAP_ENTRY( + bf16_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3), + KERNEL_NAME_MAP_ENTRY( + bf16_grouped_256x128x256x64_32x32_2x4_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_interwave_v1), + KERNEL_NAME_MAP_ENTRY( + bf16_grouped_256x256x128x64_32x32_4x2_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_interwave_v1), + KERNEL_NAME_MAP_ENTRY( + bf16_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_interwave_v1), + KERNEL_NAME_MAP_ENTRY( + bf16_grouped_256x128x64x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3), + KERNEL_NAME_MAP_ENTRY( + bf16_grouped_256x64x64x128_32x32_1x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3), + KERNEL_NAME_MAP_ENTRY( + bf16_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v1), + KERNEL_NAME_MAP_ENTRY( + bf16_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v1), + KERNEL_NAME_MAP_ENTRY( + bf16_grouped_64x16x16x256_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v1), + KERNEL_NAME_MAP_ENTRY( + bf16_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v1), + KERNEL_NAME_MAP_ENTRY( + bf16_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v1), + KERNEL_NAME_MAP_ENTRY( + bf16_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v1), + KERNEL_NAME_MAP_ENTRY( + bf16_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v1), + KERNEL_NAME_MAP_ENTRY( + bf16_grouped_128x128x32x128_32x32_2x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2), + KERNEL_NAME_MAP_ENTRY( + bf16_grouped_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v2), + KERNEL_NAME_MAP_ENTRY( + bf16_grouped_128x64x32x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2), + KERNEL_NAME_MAP_ENTRY( + bf16_grouped_128x64x16x128_16x16_2x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v2), + KERNEL_NAME_MAP_ENTRY( + bf16_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v2), + KERNEL_NAME_MAP_ENTRY( + bf16_grouped_64x16x16x64_16x16_1x1_4x16x1_4x16x1_1x16x1x4_4x4x1_1x1_intrawave_v2), + KERNEL_NAME_MAP_ENTRY( + bf16_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v2), + KERNEL_NAME_MAP_ENTRY( + bf16_grouped_64x16x16x256_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v2), + KERNEL_NAME_MAP_ENTRY( + bf16_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v2), + KERNEL_NAME_MAP_ENTRY( + bf16_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2), + KERNEL_NAME_MAP_ENTRY( + bf16_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2), + KERNEL_NAME_MAP_ENTRY( + bf16_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2), + KERNEL_NAME_MAP_ENTRY( + 
bf16_grouped_128x16x64x128_16x16_1x2_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2), + KERNEL_NAME_MAP_ENTRY( + bf16_grouped_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_intrawave_v2), + KERNEL_NAME_MAP_ENTRY( + bf16_grouped_128x16x128x128_16x16_1x4_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2), + KERNEL_NAME_MAP_ENTRY( + bf16_grouped_128x32x128x128_32x32_1x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_intrawave_v2), + KERNEL_NAME_MAP_ENTRY( + bf16_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v1), + KERNEL_NAME_MAP_ENTRY( + bf16_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v1), + KERNEL_NAME_MAP_ENTRY( + bf16_grouped_64x16x16x256_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v1), + KERNEL_NAME_MAP_ENTRY( + bf16_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v1), + KERNEL_NAME_MAP_ENTRY( + bf16_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v1), + KERNEL_NAME_MAP_ENTRY( + bf16_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v1), + KERNEL_NAME_MAP_ENTRY( + bf16_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v1), + KERNEL_NAME_MAP_ENTRY( + bf16_grouped_128x128x32x128_32x32_2x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2), + KERNEL_NAME_MAP_ENTRY( + bf16_grouped_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2), + KERNEL_NAME_MAP_ENTRY( + bf16_grouped_128x64x32x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2), + KERNEL_NAME_MAP_ENTRY( + bf16_grouped_128x64x16x128_16x16_2x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2), + KERNEL_NAME_MAP_ENTRY( + bf16_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2), + KERNEL_NAME_MAP_ENTRY( + bf16_grouped_64x16x16x64_16x16_1x1_4x16x1_4x16x1_1x16x1x4_4x4x1_1x1_interwave_v2), + KERNEL_NAME_MAP_ENTRY( + bf16_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2), + KERNEL_NAME_MAP_ENTRY( + bf16_grouped_64x16x16x256_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2), + KERNEL_NAME_MAP_ENTRY( + bf16_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2), + KERNEL_NAME_MAP_ENTRY( + bf16_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2), + KERNEL_NAME_MAP_ENTRY( + bf16_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2), + KERNEL_NAME_MAP_ENTRY( + bf16_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2), + KERNEL_NAME_MAP_ENTRY( + bf16_grouped_128x16x64x128_16x16_1x2_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2), + KERNEL_NAME_MAP_ENTRY( + bf16_grouped_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2), + KERNEL_NAME_MAP_ENTRY( + bf16_grouped_128x16x128x128_16x16_1x4_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2), + KERNEL_NAME_MAP_ENTRY( + bf16_grouped_128x32x128x128_32x32_1x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2), +}; diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped.cu index 012407bba..803681b14 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped.cu @@ -522,7 +522,9 @@ std::vector dispatch_bf16_grouped_kernel( std::vector bf16bf16bf16_grouped( at::TensorList x_group, // BF16 at::TensorList w_group, // 
BF16
-    std::optional<at::Tensor> zero_start_index_M) {
+    std::optional<at::Tensor> zero_start_index_M = std::nullopt,
+    std::optional<std::vector<at::Tensor>> output = std::nullopt) {
+  TORCH_CHECK(!output.has_value(), "Preallocated output not yet supported.");
   return dispatch_bf16_grouped_kernel(x_group, w_group, zero_start_index_M);
 }
@@ -531,7 +533,8 @@ std::vector<at::Tensor> bf16bf16bf16_grouped(
     at::TensorList /* x_group */, // BF16
     at::TensorList /* w_group */, // BF16
-    std::optional<at::Tensor> /* zero_start_index_M */) {
+    std::optional<at::Tensor> /* zero_start_index_M */,
+    std::optional<std::vector<at::Tensor>> /* output */) {
   throw std::runtime_error(
       "CUDA version is older than 12.0"); // requires CUDA>=12
 }
diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp b/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp
index 7262147b2..2039fd091 100644
--- a/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp
+++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp
@@ -64,7 +64,8 @@ std::vector<at::Tensor> f8f8bf16_grouped(
 std::vector<at::Tensor> bf16bf16bf16_grouped(
     at::TensorList X,
     at::TensorList W,
-    std::optional<at::Tensor> zero_start_index_M);
+    std::optional<at::Tensor> zero_start_index_M = std::nullopt,
+    std::optional<std::vector<at::Tensor>> output = std::nullopt);
 at::Tensor f8f8bf16_rowwise(
     at::Tensor XQ,
     at::Tensor WQ,
@@ -177,8 +178,6 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
       "f8i4bf16_rowwise(Tensor XQ, Tensor WQ, Tensor x_scale, Tensor w_scale, Tensor w_zp) -> Tensor");
   m.def(
       "f8f8bf16_grouped(Tensor[] XQ, Tensor[] WQ, Tensor[] scale, Tensor? zero_start_index_M=None, bool use_fast_accum=True) -> Tensor[]");
-  m.def(
-      "bf16bf16bf16_grouped(Tensor[] X, Tensor[] W, Tensor? zero_start_index_M=None) -> Tensor[]");
   m.def(
       "bf16i4bf16_rowwise(Tensor X, Tensor WQ, Tensor w_scale, Tensor w_zp) -> Tensor");
   m.def(
@@ -195,6 +194,8 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
       "get_f8f8bf16_rowwise_grouped_kernels",
       get_f8f8bf16_rowwise_grouped_kernels);
 #endif
+  m.def(
+      "bf16bf16bf16_grouped(Tensor[] X, Tensor[] W, Tensor? zero_start_index_M=None, Tensor[](a!)? output=None) -> Tensor[]");
   m.def(
       "f8f8bf16_blockwise(Tensor XQ, Tensor WQ, Tensor x_scale, Tensor w_scale, int block_m=128, int block_n=128, int block_k=128) -> Tensor");
   m.def(
@@ -246,12 +247,12 @@ TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) {
   m.impl("quantize_fp8_per_tensor", quantize_fp8_per_tensor);
   m.impl("quantize_fp8_per_row", quantize_fp8_per_row);
   m.impl("quantize_fp8_per_col", quantize_fp8_per_col);
+  m.impl("bf16bf16bf16_grouped", bf16bf16bf16_grouped);
 #ifndef USE_ROCM
   m.impl("i8i8bf16", i8i8bf16);
   m.impl("f8f8bf16", f8f8bf16);
   m.impl("f8f8bf16_cublas", f8f8bf16_cublas);
   m.impl("f8f8bf16_grouped", f8f8bf16_grouped);
-  m.impl("bf16bf16bf16_grouped", bf16bf16bf16_grouped);
   m.impl("f8i4bf16_rowwise", f8i4bf16_rowwise);
   m.impl("bf16i4bf16_rowwise_batched", bf16i4bf16_rowwise_batched);
   m.impl("bf16i4bf16_rowwise", bf16i4bf16_rowwise);
@@ -271,12 +272,12 @@ TORCH_LIBRARY_IMPL(fbgemm, CPU, m) {
   m.impl("quantize_fp8_per_tensor", quantize_fp8_per_tensor);
   m.impl("quantize_fp8_per_row", quantize_fp8_per_row);
   m.impl("quantize_fp8_per_col", quantize_fp8_per_col);
+  m.impl("bf16bf16bf16_grouped", bf16bf16bf16_grouped);
 #ifndef USE_ROCM
   m.impl("i8i8bf16", i8i8bf16);
   m.impl("f8f8bf16", f8f8bf16);
   m.impl("f8f8bf16_cublas", f8f8bf16_cublas);
   m.impl("f8f8bf16_grouped", f8f8bf16_grouped);
-  m.impl("bf16bf16bf16_grouped", bf16bf16bf16_grouped);
   m.impl("f8i4bf16_rowwise", f8i4bf16_rowwise);
   m.impl("bf16i4bf16_rowwise_batched", bf16i4bf16_rowwise_batched);
   m.impl("bf16i4bf16_rowwise", bf16i4bf16_rowwise);
@@ -473,7 +474,8 @@ std::vector<at::Tensor> f8f8bf16_grouped_meta(
 std::vector<at::Tensor> bf16bf16bf16_grouped_meta(
     at::TensorList X,
     at::TensorList W,
-    std::optional<at::Tensor> /* zero_start_index_M = std::nullopt */
+    std::optional<at::Tensor> /* zero_start_index_M = std::nullopt */,
+    std::optional<std::vector<at::Tensor>> /* output = std::nullopt */
 ) {
   std::vector<at::Tensor> Y;
   for (int i = 0; i < X.size(); i++) {
@@ -492,6 +494,7 @@ TORCH_LIBRARY_IMPL(fbgemm, Meta, m) {
   m.impl("quantize_fp8_per_tensor", quantize_fp8_per_tensor_meta);
   m.impl("quantize_fp8_per_row", quantize_fp8_per_row_meta);
   m.impl("quantize_fp8_per_col", quantize_fp8_per_col_meta);
+  m.impl("bf16bf16bf16_grouped", bf16bf16bf16_grouped_meta);
 #ifndef USE_ROCM
   m.impl("i8i8bf16", i8i8bf16_meta);
   m.impl("f8f8bf16", f8f8bf16_meta);
@@ -501,7 +504,6 @@
   m.impl("bf16i4bf16_rowwise", bf16i4bf16_rowwise_meta);
   m.impl("bf16i4bf16_rowwise_batched", bf16i4bf16_rowwise_batched_meta);
   m.impl("f8f8bf16_grouped", f8f8bf16_grouped_meta);
-  m.impl("bf16bf16bf16_grouped", bf16bf16bf16_grouped_meta);
 #endif
 }
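
Note on the generated kernel wrappers above: each one repeats the same specialization choice, instantiating the KPadding GemmSpecialization when any group's K dimension is not a multiple of the kernel's K tile, and the Default specialization otherwise. A minimal sketch of that check, pulled out into a hypothetical helper (needs_k_padding is not part of the patch; the generated files inline this loop):

#include <ATen/ATen.h>

// Sketch only: the per-group divisibility test each generated wrapper inlines.
// kblock is the kernel's K tile size (64, 128, 256, or 512 in the files above).
static bool needs_k_padding(at::TensorList A, int64_t kblock) {
  for (size_t i = 0; i < A.size(); i++) {
    // K is the inner dimension of group i's activation matrix.
    if (A[i].size(1) % kblock != 0) {
      return true; // At least one group needs the KPadding specialization.
    }
  }
  return false; // All groups are tile aligned; the Default specialization is safe.
}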
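
The long kernel names encode the DeviceGemmHelper template arguments in order: BLOCK_SIZE, MBLOCK x NBLOCK x KBLOCK, the wave tile (WAVE_TILE_M x WAVE_TILE_N), the wave map (WAVE_MAP_M x WAVE_MAP_N), the A/B/C block-transfer and shuffle sequences, the two CShuffle per-wave counts, and finally the pipeline scheduler and version. As an illustration (this alias is not part of the patch), decoding one of the kernels defined above gives:

#include "bf16_grouped_common.h"

// Illustration only: the name
//   bf16_grouped_64x16x16x64_16x16_1x1_4x16x1_4x16x1_1x16x1x4_4x4x1_1x1_interwave_v2
// maps onto the DeviceGemmHelper arguments below, in the order they appear in
// bf16_grouped_common.h (Default specialization branch).
using ExampleDecodedInstance = DeviceGemmHelper<
    64,             // BLOCK_SIZE
    16,             // MBLOCK
    16,             // NBLOCK
    64,             // KBLOCK
    16,             // WAVE_TILE_M
    16,             // WAVE_TILE_N
    1,              // WAVE_MAP_M
    1,              // WAVE_MAP_N
    S<4, 16, 1>,    // ABLOCK_TRANSFER
    S<4, 16, 1>,    // BBLOCK_TRANSFER
    S<1, 16, 1, 4>, // CBLOCK_TRANSFER
    S<4, 4, 1>,     // CBLOCK_SPV
    1,              // CSHUFFLE_MX_PER_WAVE_PERSHUFFLE
    1,              // CSHUFFLE_NX_PER_WAVE_PERSHUFFLE
    ck::BlockGemmPipelineScheduler::Interwave, // "interwave"
    ck::BlockGemmPipelineVersion::v2,          // "v2"
    ck::tensor_operation::device::GemmSpecialization::Default>;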
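
kernel_name_map in bf16_grouped_kernel_manifest.h exists so a specific tuning can be forced by name instead of relying on the shape heuristic. A sketch of how a by-name dispatch could look; run_named_bf16_grouped_kernel is a hypothetical helper and not something this patch adds:

#include <string>
#include <vector>

#include <ATen/ATen.h>

#include "bf16_grouped_kernel_manifest.h"

// Hypothetical helper: look up a kernel by its manifest name and invoke it.
std::vector<at::Tensor> run_named_bf16_grouped_kernel(
    const std::string& kernel_name,
    at::TensorList A,
    at::TensorList B,
    at::Tensor kernel_args,
    std::vector<at::Tensor> Y) {
  const auto it = kernel_name_map.find(kernel_name);
  TORCH_CHECK(
      it != kernel_name_map.end(),
      "No bf16 grouped kernel named ",
      kernel_name);
  // Every entry shares the GroupedKernel signature (A, B, kernel_args, Y),
  // so any of them can stand in for the heuristic choice.
  return it->second(A, B, kernel_args, Y);
}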
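
From the caller's side the operator keeps its original two-argument form; zero_start_index_M and output are optional and default to std::nullopt (the CUDA path above currently rejects a preallocated output). A hedged sketch of a direct C++ call against the declaration added to quantize.cpp; namespace qualification is omitted and the shapes are illustrative only:

#include <optional>
#include <vector>

#include <ATen/ATen.h>

// Declaration as added to quantize.cpp, repeated here so the sketch is
// self-contained.
std::vector<at::Tensor> bf16bf16bf16_grouped(
    at::TensorList X,
    at::TensorList W,
    std::optional<at::Tensor> zero_start_index_M = std::nullopt,
    std::optional<std::vector<at::Tensor>> output = std::nullopt);

std::vector<at::Tensor> example_grouped_call() {
  // Each group may have its own M/N/K; weights are laid out as [N, K],
  // matching N = B[i].size(0) and K = A[i].size(1) in bf16_grouped_impl.
  std::vector<at::Tensor> X, W;
  const auto opts = at::TensorOptions().dtype(at::kBFloat16).device(at::kCUDA);
  X.push_back(at::randn({64, 512}, opts));    // group 0: M=64,  K=512
  X.push_back(at::randn({128, 256}, opts));   // group 1: M=128, K=256
  W.push_back(at::randn({1024, 512}, opts));  // group 0: N=1024
  W.push_back(at::randn({2048, 256}, opts));  // group 1: N=2048
  // Passing neither optional keeps the default behavior of allocating a fresh
  // BF16 output per group.
  return bf16bf16bf16_grouped(X, W);
}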