diff --git a/csrc/cache.h b/csrc/cache.h index 11c4c5001daaa..b70f0a836ebb5 100644 --- a/csrc/cache.h +++ b/csrc/cache.h @@ -26,7 +26,8 @@ void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value, torch::Tensor& value_cache, torch::Tensor& slot_mapping, const std::string& kv_cache_dtype, - const double k_scale, const double v_scale); + const double k_scale, const double v_scale, + const bool is_NHD = true); // Just for unittest void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache, diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index 1be806bbfa43c..0ae5dd478e9e7 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -207,14 +207,14 @@ template __global__ void reshape_and_cache_flash_kernel( const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size] const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size] - cache_t* __restrict__ key_cache, // [num_blocks, block_size, num_heads, - // head_size] - cache_t* __restrict__ value_cache, // [num_blocks, block_size, num_heads, - // head_size] + cache_t* __restrict__ key_cache, cache_t* __restrict__ value_cache, const int64_t* __restrict__ slot_mapping, // [num_tokens] const int block_stride, const int key_stride, const int value_stride, const int num_heads, const int head_size, const int block_size, - const float k_scale, const float v_scale) { + const float k_scale, const float v_scale, const bool is_NHD) { + // For key/value_cache layout: + // - NHD: [num_blocks, block_size, num_heads, head_size] + // - HND: [num_blocks, num_heads, block_size, head_size] const int64_t token_idx = blockIdx.x; const int64_t slot_idx = slot_mapping[token_idx]; // NOTE: slot_idx can be -1 if the token is padded @@ -229,9 +229,12 @@ __global__ void reshape_and_cache_flash_kernel( const int64_t src_value_idx = token_idx * value_stride + i; const int head_idx = i / head_size; const int head_offset = i % head_size; - const int64_t tgt_key_value_idx = block_idx * block_stride + - block_offset * num_heads * head_size + - head_idx * head_size + head_offset; + const int64_t tgt_key_value_idx = + block_idx * block_stride + + (is_NHD ? block_offset * num_heads * head_size + head_idx * head_size + + head_offset + : head_idx * block_size * head_size + block_offset * head_size + + head_offset); scalar_t tgt_key = key[src_key_idx]; scalar_t tgt_value = value[src_value_idx]; if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) { @@ -302,18 +305,19 @@ void reshape_and_cache( value_stride, num_heads, head_size, block_size, k_scale, v_scale); void reshape_and_cache_flash( - torch::Tensor& key, // [num_tokens, num_heads, head_size] - torch::Tensor& value, // [num_tokens, num_heads, head_size] - torch::Tensor& key_cache, // [num_blocks, block_size, num_heads, head_size] - torch::Tensor& - value_cache, // [num_blocks, block_size, num_heads, head_size] + torch::Tensor& key, // [num_tokens, num_heads, head_size] + torch::Tensor& value, // [num_tokens, num_heads, head_size] + torch::Tensor& key_cache, torch::Tensor& value_cache, torch::Tensor& slot_mapping, // [num_tokens] const std::string& kv_cache_dtype, const double k_scale, - const double v_scale) { + const double v_scale, const bool is_NHD) { + // For key/value_cache layout: + // - NHD: [num_blocks, block_size, num_heads, head_size] + // - HND: [num_blocks, num_heads, block_size, head_size] int num_tokens = key.size(0); int num_heads = key.size(1); int head_size = key.size(2); - int block_size = key_cache.size(1); + int block_size = is_NHD ? key_cache.size(1) : key_cache.size(2); int key_stride = key.stride(0); int value_stride = value.stride(0); diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 971a45d50ffa4..916c72e8b484e 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -382,7 +382,8 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { " Tensor! value_cache," " Tensor slot_mapping," " str kv_cache_dtype," - " float k_scale, float v_scale) -> ()"); + " float k_scale, float v_scale," + " bool is_NHD) -> ()"); cache_ops.impl("reshape_and_cache_flash", torch::kCUDA, &reshape_and_cache_flash); diff --git a/tests/kernels/test_cache.py b/tests/kernels/test_cache.py index 40550ed51e2c7..2fd901ecffaac 100644 --- a/tests/kernels/test_cache.py +++ b/tests/kernels/test_cache.py @@ -15,6 +15,7 @@ NUM_HEADS = [8] # Arbitrary values for testing HEAD_SIZES = [64, 80, 120, 256] BLOCK_SIZES = [8, 16, 32] +CACHE_LAYOUTS = ["NHD", "HND"] # Arbitrary values for testing # don't make it too large. e.g. [1024, 36000] will OOM @@ -211,6 +212,7 @@ def test_reshape_and_cache( @pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE) +@pytest.mark.parametrize("kv_layout", CACHE_LAYOUTS) @torch.inference_mode() def test_reshape_and_cache_flash( kv_cache_factory_flashinfer, @@ -223,6 +225,7 @@ def test_reshape_and_cache_flash( seed: int, device: str, kv_cache_dtype: str, + kv_layout: str, ) -> None: current_platform.seed_everything(seed) torch.set_default_device(device) @@ -233,7 +236,7 @@ def test_reshape_and_cache_flash( slot_mapping = torch.tensor(slot_mapping_lst, dtype=torch.long, device=device) - + is_NHD = kv_layout == "NHD" qkv = torch.randn(num_tokens, 3, num_heads, @@ -252,6 +255,7 @@ def test_reshape_and_cache_flash( kv_cache_dtype, dtype, device=device, + is_NHD=is_NHD, ) key_cache, value_cache = key_caches[0].contiguous( ), value_caches[0].contiguous() @@ -275,10 +279,11 @@ def test_reshape_and_cache_flash( # Call the reshape_and_cache kernel. opcheck(torch.ops._C_cache_ops.reshape_and_cache_flash, (key, value, key_cache, value_cache, slot_mapping, kv_cache_dtype, - k_scale, v_scale), + k_scale, v_scale, is_NHD), cond=(head_size == HEAD_SIZES[0])) ops.reshape_and_cache_flash(key, value, key_cache, value_cache, - slot_mapping, kv_cache_dtype, k_scale, v_scale) + slot_mapping, kv_cache_dtype, k_scale, v_scale, + is_NHD) if kv_cache_dtype == "fp8": result_key_cache = torch.empty_like(key_cache, dtype=torch.float16) @@ -300,8 +305,12 @@ def test_reshape_and_cache_flash( for i in range(num_tokens): block_idx = block_indicies_lst[i] block_offset = block_offsets_lst[i] - cloned_key_cache[block_idx, block_offset, :, :] = key[i] - cloned_value_cache[block_idx, block_offset, :, :] = value[i] + if is_NHD: + cloned_key_cache[block_idx, block_offset, :, :] = key[i] + cloned_value_cache[block_idx, block_offset, :, :] = value[i] + else: + cloned_key_cache[block_idx, :, block_offset, :] = key[i] + cloned_value_cache[block_idx, :, block_offset, :] = value[i] if kv_cache_dtype == "fp8": torch.testing.assert_close(result_key_cache, diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 767d45ede7e87..f45ef1cf98053 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -876,11 +876,12 @@ def reshape_and_cache_flash( kv_cache_dtype: str, k_scale: float, v_scale: float, + is_NHD: bool = True, ) -> None: torch.ops._C_cache_ops.reshape_and_cache_flash(key, value, key_cache, value_cache, slot_mapping, kv_cache_dtype, k_scale, - v_scale) + v_scale, is_NHD) def copy_blocks(key_caches: List[torch.Tensor],