diff --git a/third_party/triton/temporary/further_mixed_precision_fix.patch b/third_party/triton/temporary/further_mixed_precision_fix.patch new file mode 100644 index 0000000000000..6152ab48194c0 --- /dev/null +++ b/third_party/triton/temporary/further_mixed_precision_fix.patch @@ -0,0 +1,36 @@ +This resolves the issue here b/372630230. The patch is not intended to be +submitted to Triton upstream. This is because OAI historically refused these +similar work-arounds and the proper fixes are considerably more expensive to do. +diff --git a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp +--- a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp ++++ b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp +@@ -55,7 +55,8 @@ SmallVector reorderValues(const S + } + return ret; + } +- if (inBitWidth == 8 && ouBitWidth == 16) { ++ if ((inBitWidth == 8 && ouBitWidth == 16) || ++ (inBitWidth == 16 && ouBitWidth == 8)) { + SmallVector ret; + for (unsigned i = 0; i < values.size(); i += 16) { + ret.push_back(values[i + 0]); +diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir +--- a/test/Conversion/tritongpu_to_llvm.mlir ++++ b/test/Conversion/tritongpu_to_llvm.mlir +@@ -1693,3 +1693,16 @@ module attributes {"triton_gpu.num-ctas" + tt.return + } + } ++ ++// ----- ++ ++#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> ++#dot_operand = #triton_gpu.dot_op<{opIdx=0, parent=#mma, kWidth=4}> ++module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { ++ tt.func @f16_to_f8_dot_operand(%f16_inp: tensor<32x32xf16, #dot_operand>) { ++ // CHECK-LABEL: @f16_to_f8_dot_operand ++ ++ %f8 = tt.fp_to_fp %f16_inp, rounding = rtne : tensor<32x32xf16, #dot_operand> -> tensor<32x32xf8E5M2, #dot_operand> ++ tt.return ++ } ++} diff --git a/third_party/triton/temporary/series.bzl b/third_party/triton/temporary/series.bzl index 274faf600e048..d6b1a6a31f783 100644 --- a/third_party/triton/temporary/series.bzl +++ b/third_party/triton/temporary/series.bzl @@ -17,5 +17,6 @@ temporary_patch_list = [ "//third_party/triton:temporary/fix_left_shift_overflow.patch", "//third_party/triton:temporary/prefetch.patch", "//third_party/triton:temporary/i4_to_bf16.patch", + "//third_party/triton:temporary/further_mixed_precision_fix.patch", # Add new patches just above this line ]