From d49bc1760aa72f2f97ca7e13c4771649e6b00b74 Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Wed, 27 Nov 2024 09:39:29 -0800 Subject: [PATCH] Fix CUDA-11.4 build issue Summary: An `#include` directive introduced by D65260109 somehow triggers an internal compiler error (ICE) in NVCC 11.4, so this patch removes it. Reviewed By: xw285cornell Differential Revision: D66512512 --- .../quantize/cutlass_extensions/bf16bf16bf16_grouped.cu | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped.cu index c4344c3100..8a8509dc3a 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped.cu @@ -10,7 +10,6 @@ #include #include #include -#include // clang-format off // The fixed ordering of the headers is required for CUTLASS 3.2+ @@ -332,9 +331,10 @@ at::Tensor bf16bf16bf16_grouped_impl( auto stream = at::cuda::getCurrentCUDAStream().stream(); int64_t output_offset = 0; - if (zero_start_index_M.has_value() == true) { - TORCH_CHECK(zero_start_index_M.value().dtype() == torch::kInt32); - } + // If passed, zero_start_index_M must be tensor of int32 + TORCH_CHECK( + !zero_start_index_M.has_value() || + zero_start_index_M->dtype() == at::kInt); // Set arguments for (int i = 0; i < problem_count; ++i) {