[Triton] Allow reorderValues to handle downcast with dot_op layout on 16-bit -> 8-bit in the same way it handles 8-bit -> 16-bit. We already needed to do something similar for 16/32 bits previously.

PiperOrigin-RevId: 689778145
Moerafaat authored and Google-ML-Automation committed Oct 25, 2024
1 parent 7f3fe5e commit ee02496
Showing 2 changed files with 37 additions and 0 deletions.
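
For context, the affected pattern can be reproduced at the user level with a kernel roughly like the sketch below. This is a hypothetical Triton (Python) example, not part of the commit or its tests; the kernel name, pointer arguments, and block size are illustrative assumptions. Downcasting the f16 operand of tl.dot to f8E5M2 produces the 16-bit -> 8-bit conversion on a dot_op-layout tensor that the patched reorderValues now reorders correctly.

import triton
import triton.language as tl

@triton.jit
def f16_to_f8_dot_kernel(a_ptr, b_ptr, c_ptr, BLOCK: tl.constexpr):
    # Hypothetical repro sketch: a_ptr points at an f16 tile, b_ptr at an
    # f8E5M2 tile, c_ptr at an f32 output tile; all are BLOCK x BLOCK.
    rows = tl.arange(0, BLOCK)[:, None]
    cols = tl.arange(0, BLOCK)[None, :]
    a = tl.load(a_ptr + rows * BLOCK + cols)   # f16 operand
    b = tl.load(b_ptr + rows * BLOCK + cols)   # f8E5M2 operand
    # The 16-bit -> 8-bit downcast on a dot operand is the case the patch adds;
    # it mirrors the existing 8-bit -> 16-bit handling in reorderValues.
    a_f8 = a.to(tl.float8e5, fp_downcast_rounding="rtne")
    acc = tl.dot(a_f8, b)                      # accumulates in f32
    tl.store(c_ptr + rows * BLOCK + cols, acc)

BLOCK would need to be at least 16 for tl.dot, e.g. a launch such as f16_to_f8_dot_kernel[(1,)](a, b, c, BLOCK=32) with contiguous 32x32 tensors.
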
36 changes: 36 additions & 0 deletions third_party/triton/temporary/further_mixed_precision_fix.patch
@@ -0,0 +1,36 @@
This resolves b/372630230. The patch is not intended to be submitted to Triton
upstream, because OAI has historically refused similar work-arounds and the
proper fixes are considerably more expensive.
diff --git a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp
--- a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp
@@ -55,7 +55,8 @@ SmallVector<Value> reorderValues(const S
}
return ret;
}
- if (inBitWidth == 8 && ouBitWidth == 16) {
+ if ((inBitWidth == 8 && ouBitWidth == 16) ||
+ (inBitWidth == 16 && ouBitWidth == 8)) {
SmallVector<Value> ret;
for (unsigned i = 0; i < values.size(); i += 16) {
ret.push_back(values[i + 0]);
diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir
--- a/test/Conversion/tritongpu_to_llvm.mlir
+++ b/test/Conversion/tritongpu_to_llvm.mlir
@@ -1693,3 +1693,16 @@ module attributes {"triton_gpu.num-ctas"
tt.return
}
}
+
+// -----
+
+#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}>
+#dot_operand = #triton_gpu.dot_op<{opIdx=0, parent=#mma, kWidth=4}>
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
+ tt.func @f16_to_f8_dot_operand(%f16_inp: tensor<32x32xf16, #dot_operand>) {
+ // CHECK-LABEL: @f16_to_f8_dot_operand
+
+ %f8 = tt.fp_to_fp %f16_inp, rounding = rtne : tensor<32x32xf16, #dot_operand> -> tensor<32x32xf8E5M2, #dot_operand>
+ tt.return
+ }
+}
1 change: 1 addition & 0 deletions third_party/triton/temporary/series.bzl
@@ -17,5 +17,6 @@ temporary_patch_list = [
"//third_party/triton:temporary/fix_left_shift_overflow.patch",
"//third_party/triton:temporary/prefetch.patch",
"//third_party/triton:temporary/i4_to_bf16.patch",
"//third_party/triton:temporary/further_mixed_precision_fix.patch",
# Add new patches just above this line
]
