-
Notifications
You must be signed in to change notification settings - Fork 434
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Triton] Allow reorderValues to handle downcast with dot_op layout on…
… 16-bit -> 8-bit in the same way it handles 8-bit -> 16-bit. We already needed to do something similar for 16/32 bits previously. PiperOrigin-RevId: 689778145
- Loading branch information
1 parent
7f3fe5e
commit ee02496
Showing
2 changed files
with
37 additions
and
0 deletions.
There are no files selected for viewing
36 changes: 36 additions & 0 deletions
36
third_party/triton/temporary/further_mixed_precision_fix.patch
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
This resolves the issue here b/372630230. The patch is not intended to be | ||
submitted to Triton upstream. This is because OAI historically refused these | ||
similar work-arounds and the proper fixes are considerably more expensive to do. | ||
diff --git a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp | ||
--- a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp | ||
+++ b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp | ||
@@ -55,7 +55,8 @@ SmallVector<Value> reorderValues(const S | ||
} | ||
return ret; | ||
} | ||
- if (inBitWidth == 8 && ouBitWidth == 16) { | ||
+ if ((inBitWidth == 8 && ouBitWidth == 16) || | ||
+ (inBitWidth == 16 && ouBitWidth == 8)) { | ||
SmallVector<Value> ret; | ||
for (unsigned i = 0; i < values.size(); i += 16) { | ||
ret.push_back(values[i + 0]); | ||
diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir | ||
--- a/test/Conversion/tritongpu_to_llvm.mlir | ||
+++ b/test/Conversion/tritongpu_to_llvm.mlir | ||
@@ -1693,3 +1693,16 @@ module attributes {"triton_gpu.num-ctas" | ||
tt.return | ||
} | ||
} | ||
+ | ||
+// ----- | ||
+ | ||
+#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> | ||
+#dot_operand = #triton_gpu.dot_op<{opIdx=0, parent=#mma, kWidth=4}> | ||
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { | ||
+ tt.func @f16_to_f8_dot_operand(%f16_inp: tensor<32x32xf16, #dot_operand>) { | ||
+ // CHECK-LABEL: @f16_to_f8_dot_operand | ||
+ | ||
+ %f8 = tt.fp_to_fp %f16_inp, rounding = rtne : tensor<32x32xf16, #dot_operand> -> tensor<32x32xf8E5M2, #dot_operand> | ||
+ tt.return | ||
+ } | ||
+} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters