diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td index e5c1a53f34bf64..8a1ef94c853a58 100644 --- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td +++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td @@ -552,7 +552,7 @@ def MFMAOutTypes : AnyTypeOf<[F64, VectorOfLengthAndType<[4, 16, 32], [I32]>, VectorOfLengthAndType<[4], [F64]>]>; // wmma -def WMMAInTypes : AnyTypeOf<[VectorOfLengthAndType<[16], [F16, BF16, I8, SI8, UI8]>]>; +def WMMAInTypes : AnyTypeOf<[VectorOfLengthAndType<[8, 16], [F16, BF16, I8, SI8, UI8, F8E4M3FN, F8E5M2]>]>; def WMMAOutTypes : AnyTypeOf<[VectorOfLengthAndType<[4, 8], [F32, I32]>, VectorOfLengthAndType<[8, 16], [F16, BF16]>]>; diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td index e832dfa9d6b80e..35fd8270ca6935 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td @@ -329,13 +329,16 @@ class ROCDL_Wmma_IntrOp overloadedOperands, "$args attr-dict `:` functional-type($args, $res)"; } -// Available on RDNA3 +// Available from gfx11 def ROCDL_wmma_f32_16x16x16_f16 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.f16", [0]>; def ROCDL_wmma_f32_16x16x16_bf16 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.bf16", [0]>; def ROCDL_wmma_f16_16x16x16_f16 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x16.f16", [0]>; def ROCDL_wmma_bf16_16x16x16_bf16 : ROCDL_Wmma_IntrOp<"wmma.bf16.16x16x16.bf16", [0]>; def ROCDL_wmma_i32_16x16x16_iu8 : ROCDL_Wmma_IntrOp<"wmma.i32.16x16x16.iu8", [1]>; def ROCDL_wmma_i32_16x16x16_iu4 : ROCDL_Wmma_IntrOp<"wmma.i32.16x16x16.iu4", [1]>; +// Available from gfx12 +def ROCDL_wmma_f32_16x16x16_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.fp8_fp8", [1]>; +def ROCDL_wmma_f32_16x16x16_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.bf8_bf8", [1]>; //===---------------------------------------------------------------------===// // Operations on raw buffer resources (stride of 0, bounds checks either off or in diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp index 809e9448e80abf..7e407f1ca528d8 100644 --- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp +++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp @@ -385,6 +385,7 @@ static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter, Location loc, const TypeConverter *typeConverter, bool isUnsigned, Value llvmInput, + Value mlirInput, SmallVector &operands) { Type inputType = llvmInput.getType(); auto vectorType = dyn_cast(inputType); @@ -398,23 +399,29 @@ static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter, return; } + // We need to check the type of the input before conversion to properly test + // for int8. This is because, in LLVM, fp8 type is converted to int8, so the + // fp8/int8 information is lost during the conversion process. + auto mlirInputType = cast(mlirInput.getType()); + bool isInputInt8 = mlirInputType.getElementType().isInteger(8); + if (isInputInt8) { + // if element type is 8-bit signed or unsigned, ignore the isUnsigned flag + bool localIsUnsigned = isUnsigned; + if (elemType.isUnsignedInteger(8)) { + localIsUnsigned = true; + } else if (elemType.isSignedInteger(8)) { + localIsUnsigned = false; + } + Value sign = createI1Constant(rewriter, loc, !localIsUnsigned); + operands.push_back(sign); + } + int64_t numBytes = vectorType.getNumElements(); Type i32 = rewriter.getI32Type(); VectorType vectorType32bits = VectorType::get(numBytes * 8 / 32, i32); auto llvmVectorType32bits = typeConverter->convertType(vectorType32bits); - Value result = rewriter.createOrFold( loc, llvmVectorType32bits, llvmInput); - - // if element type is 8-bit signed or unsigned, ignore the isUnsigned flag - bool localIsUnsigned = isUnsigned; - if (elemType.isUnsignedInteger(8)) { - localIsUnsigned = true; - } else if (elemType.isSignedInteger(8)) { - localIsUnsigned = false; - } - Value sign = createI1Constant(rewriter, loc, !localIsUnsigned); - operands.push_back(sign); operands.push_back(result); } @@ -590,18 +597,20 @@ static std::optional wmmaOpToIntrinsic(WMMAOp wmma, auto elemSourceType = sourceVectorType.getElementType(); auto elemDestType = destVectorType.getElementType(); - if (elemSourceType.isF16() && elemDestType.isF32()) { + if (elemSourceType.isF16() && elemDestType.isF32()) return ROCDL::wmma_f32_16x16x16_f16::getOperationName(); - } - if (elemSourceType.isBF16() && elemDestType.isF32()) { + if (elemSourceType.isBF16() && elemDestType.isF32()) return ROCDL::wmma_f32_16x16x16_bf16::getOperationName(); - } else if (elemSourceType.isF16() && elemDestType.isF16()) { + if (elemSourceType.isF16() && elemDestType.isF16()) return ROCDL::wmma_f16_16x16x16_f16::getOperationName(); - } else if (elemSourceType.isBF16() && elemDestType.isBF16()) { + if (elemSourceType.isBF16() && elemDestType.isBF16()) return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName(); - } else if (elemSourceType.isInteger(8) && elemDestType.isInteger(32)) { + if (elemSourceType.isInteger(8) && elemDestType.isInteger(32)) return ROCDL::wmma_i32_16x16x16_iu8::getOperationName(); - } + if (elemSourceType.isFloat8E4M3FN() && elemDestType.isF32()) + return ROCDL::wmma_f32_16x16x16_fp8::getOperationName(); + if (elemSourceType.isFloat8E5M2() && elemDestType.isF32()) + return ROCDL::wmma_f32_16x16x16_bf8::getOperationName(); return std::nullopt; } @@ -662,8 +671,8 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern { Location loc = op.getLoc(); Type outType = typeConverter->convertType(op.getDestD().getType()); - if (chipset.majorVersion != 11) - return op->emitOpError("WMMA only supported on gfx11"); + if (chipset.majorVersion != 11 && chipset.majorVersion != 12) + return op->emitOpError("WMMA only supported on gfx11 and gfx12"); std::optional maybeIntrinsic = wmmaOpToIntrinsic(op, chipset); @@ -675,9 +684,9 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern { SmallVector operands; wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedA(), - adaptor.getSourceA(), operands); + adaptor.getSourceA(), op.getSourceA(), operands); wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedB(), - adaptor.getSourceB(), operands); + adaptor.getSourceB(), op.getSourceB(), operands); wmmaPushOutputOperand(rewriter, loc, typeConverter, adaptor.getDestC(), op.getSubwordOffset(), op.getClamp(), operands); diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp index 3943696364950f..63447baa31eb0c 100644 --- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp +++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp @@ -234,9 +234,10 @@ LogicalResult WMMAOp::verify() { Type sourceAElemType = sourceVectorAType.getElementType(); Type destElemType = destVectorType.getElementType(); - bool isDestFloat = - (destElemType.isF32() || destElemType.isF16() || destElemType.isBF16()); - bool isSrcFloat = (sourceAElemType.isF16() || sourceAElemType.isBF16()); + bool isDestFloat = isa(destElemType); + bool isSrcFloat = + isa( + sourceAElemType); if (isDestFloat && !isSrcFloat) { return emitOpError("Expected float sources with float destination"); diff --git a/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir new file mode 100644 index 00000000000000..7b2b524d4af426 --- /dev/null +++ b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir @@ -0,0 +1,9 @@ +// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx1200 --allow-unregistered-dialect | FileCheck %s +func.func @mfma_to_rocdl(%arg0 : vector<8xf8E4M3FN>, %arg1 : vector<8xf8E5M2>, %arg2 : vector<8xf32>) { + // CHECK: rocdl.wmma.f32.16x16x16.fp8{{.*}}: (vector<2xi32>, vector<2xi32>, vector<8xf32>) -> vector<8xf32> + amdgpu.wmma %arg0 * %arg0 + %arg2: vector<8xf8E4M3FN>, vector<8xf8E4M3FN>, vector<8xf32> + + // CHECK: rocdl.wmma.f32.16x16x16.bf8{{.*}}: (vector<2xi32>, vector<2xi32>, vector<8xf32>) -> vector<8xf32> + amdgpu.wmma %arg1 * %arg1 + %arg2: vector<8xf8E5M2>, vector<8xf8E5M2>, vector<8xf32> + func.return +} diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir index d902a82eeb9ea2..97b505746fc751 100644 --- a/mlir/test/Target/LLVMIR/rocdl.mlir +++ b/mlir/test/Target/LLVMIR/rocdl.mlir @@ -377,6 +377,16 @@ llvm.func @rocdl.make.buffer.rsrc(%ptr : !llvm.ptr, llvm.return %rsrc : !llvm.ptr<8> } +llvm.func @rocdl.wmma.fp8(%arg0 : vector<2 x i32>, %arg1 : vector<8xf32>) -> vector<8xf32> { + // CHECK: call <8 x float> @llvm.amdgcn.wmma.f32.16x16x16.fp8.fp8.v8f32.v2i32(<2 x i32> %{{.*}}, <2 x i32> %{{.*}}, <8 x float> %{{.*}}) + %r0 = rocdl.wmma.f32.16x16x16.fp8_fp8 %arg0, %arg0, %arg1: (vector<2xi32>, vector<2xi32>, vector<8xf32>) -> vector<8xf32> + + // CHECK: call <8 x float> @llvm.amdgcn.wmma.f32.16x16x16.bf8.bf8.v8f32.v2i32(<2 x i32> %{{.*}}, <2 x i32> %{{.*}}, <8 x float> %{{.*}}) + %r1 = rocdl.wmma.f32.16x16x16.bf8_bf8 %arg0, %arg0, %arg1: (vector<2xi32>, vector<2xi32>, vector<8xf32>) -> vector<8xf32> + + llvm.return %r0 : vector<8 x f32> +} + llvm.func @rocdl.raw.ptr.buffer(%rsrc : !llvm.ptr<8>, %offset : i32, %soffset : i32, %vdata1 : i32,