diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped.cu index c4344c3100..8a8509dc3a 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped.cu @@ -10,7 +10,6 @@ #include #include #include -#include // clang-format off // The fixed ordering of the headers is required for CUTLASS 3.2+ @@ -332,9 +331,10 @@ at::Tensor bf16bf16bf16_grouped_impl( auto stream = at::cuda::getCurrentCUDAStream().stream(); int64_t output_offset = 0; - if (zero_start_index_M.has_value() == true) { - TORCH_CHECK(zero_start_index_M.value().dtype() == torch::kInt32); - } + // If passed, zero_start_index_M must be tensor of int32 + TORCH_CHECK( + !zero_start_index_M.has_value() || + zero_start_index_M->dtype() == at::kInt); // Set arguments for (int i = 0; i < problem_count; ++i) {