From 1cf7b1b31cde8c62611e421becd4648c7284d76c Mon Sep 17 00:00:00 2001 From: Keren Zhou Date: Wed, 6 Nov 2024 09:21:06 -0800 Subject: [PATCH] [BACKEND] Get rid of unpack/pack I32 (#5044) - Removed functions related to unpacking and packing I32 values. - Updated utilities to handle conversion of mxfp4 values without packing/unpacking I32. - Move the register value ordering logic from the element-wise operation lowering to the dot operation lowering. - Use linear layout to handle conversions between almost all distributed layouts. - Clean up data loading and mma computation involving `repN`, `repK`, and `repM`. --- .../TritonGPUToLLVM/ElementwiseOpToLLVMBase.h | 11 -- .../Conversion/TritonGPUToLLVM/Utility.h | 73 +-------- .../Dialect/TritonGPU/IR/TritonGPUAttrDefs.td | 6 +- .../TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp | 27 +--- .../DecomposeUnsupportedConversions.cpp | 4 + .../TritonGPUToLLVM/ElementwiseOpToLLVM.cpp | 140 +----------------- .../TritonGPUToLLVM/MemoryOpToLLVM.cpp | 2 - .../TritonGPUToLLVM/TypeConverter.cpp | 19 +-- lib/Conversion/TritonGPUToLLVM/Utility.cpp | 69 +++++---- lib/Dialect/TritonGPU/IR/Dialect.cpp | 104 ++++++------- .../TritonGPU/Transforms/AccelerateMatmul.cpp | 2 +- python/test/unit/language/test_core.py | 23 +++ test/Conversion/tritongpu_to_llvm.mlir | 18 +-- .../TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp | 13 +- .../TritonAMDGPUToLLVM/UpcastMXFPToLLVM.cpp | 33 +---- .../SharedToDotOperandMMAv2OrV3.cpp | 87 ++++++----- .../DecomposeUnsupportedConversions.cpp | 5 + .../DotOpToLLVM/MMAv2.cpp | 131 ++++++++-------- .../LoadStoreOpToLLVM.cpp | 13 +- .../UpcastMXFPToLLVM.cpp | 57 +------ 20 files changed, 299 insertions(+), 538 deletions(-) diff --git a/include/triton/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVMBase.h b/include/triton/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVMBase.h index 8c7ab98316..c37917a35d 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVMBase.h +++ b/include/triton/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVMBase.h @@ -15,9 +15,6 @@ namespace mlir::triton { namespace gpu { -SmallVector reorderValues(const SmallVector &values, Type inType, - Type ouType); - Type getElementType(Value value); class MultipleOperandsRange @@ -179,8 +176,6 @@ class ElementwiseOpConversionBase : public ConvertOpToLLVMPattern { for (auto operand : adaptor.getOperands()) { auto argTy = op->getOperand(0).getType(); auto subOperands = unpackLLElements(loc, operand, rewriter); - subOperands = unpackI32s(subOperands, argTy, rewriter, loc, - this->getTypeConverter()); allOperands.resize(subOperands.size()); for (auto v : llvm::enumerate(subOperands)) allOperands[v.index()].push_back(v.value()); @@ -201,13 +196,7 @@ class ElementwiseOpConversionBase : public ConvertOpToLLVMPattern { } it += curr.size(); } - if (op->getNumOperands() > 0) { - auto argTy = op->getOperand(0).getType(); - resultVals = reorderValues(resultVals, argTy, resultTy); - } resultVals = maybeDeduplicate(op, resultVals); - resultVals = - packI32s(resultVals, resultTy, rewriter, loc, this->getTypeConverter()); Value view = packLLElements(loc, this->getTypeConverter(), resultVals, rewriter, resultTy); rewriter.replaceOp(op, view); diff --git a/include/triton/Conversion/TritonGPUToLLVM/Utility.h b/include/triton/Conversion/TritonGPUToLLVM/Utility.h index f0f62b8e83..35f1303fa1 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/Utility.h +++ b/include/triton/Conversion/TritonGPUToLLVM/Utility.h @@ -396,10 +396,14 @@ inline Value getSharedMemoryBase(Location loc, RewriterBase 
&rewriter, // MXFP utilities // ----------------------------------------------------------------------- -// Convert one int8, which contain, 2 packed mxfp4 values, into 2 bf16 -// standalone values and returns them as a pair for (high 4 bits, low 4 bits). -std::pair convertMxfp4x2ToBf16x2(RewriterBase &rewriter, - Location loc, Value v); +// Convert each value, which is an int8 containing 2 packed mxfp4 values, +// into 2 standalone bf16 values +SmallVector convertMxfp4x2ToBf16x2(RewriterBase &rewriter, Location loc, + ArrayRef values); + +// Scale a mxfp4 value by a given scale. +Value mxfpScaleBf16(RewriterBase &rewriter, Location loc, Value v, Value scale); + } // namespace LLVM /* ------------------------------------ */ @@ -1397,67 +1401,6 @@ inline Value getStructFromSharedMemoryObject(Location loc, return llvmStruct; } -// For some reasons, LLVM's NVPTX backend inserts unnecessary (?) integer -// instructions to pack & unpack sub-word integers. A workaround is to -// store the results of tensors with dot operand encodings in i32 to -// facilitate instructions such as `ldmatrix`. -// -// TODO: Confirm if the problem is still there. -inline bool requiresI32Conversion(Type type) { - auto tensorTy = dyn_cast(type); - if (!tensorTy) - return false; - auto dotOpEnc = dyn_cast(tensorTy.getEncoding()); - if (!dotOpEnc) - return false; - auto parent = dyn_cast(dotOpEnc.getParent()); - if (!(parent && parent.getVersionMajor() < 3)) - return false; - return true; -} - -inline SmallVector packI32s(const SmallVector &inValues, - Type type, RewriterBase &rewriter, - Location loc, - const LLVMTypeConverter *typeConverter) { - if (!requiresI32Conversion(type)) - return inValues; - Type eltTy = - typeConverter->convertType(cast(type).getElementType()); - - SmallVector outValues; - int vecWidth = 32 / eltTy.getIntOrFloatBitWidth(); - auto vecTy = vec_ty(eltTy, vecWidth); - for (int i = 0; i < inValues.size(); i += vecWidth) { - Value vec = undef(vecTy); - for (int j = 0; j < vecWidth; j++) { - vec = insert_element(vec, inValues[i + j], i32_val(j)); - } - outValues.push_back(bitcast(vec, i32_ty)); - } - return outValues; -} - -inline SmallVector unpackI32s(const SmallVector &inValues, - Type type, RewriterBase &rewriter, - Location loc, - const LLVMTypeConverter *typeConverter) { - if (!requiresI32Conversion(type)) - return inValues; - Type eltTy = - typeConverter->convertType(cast(type).getElementType()); - - SmallVector outValues; - for (auto v : inValues) { - auto vecTy = vec_ty(eltTy, 32 / eltTy.getIntOrFloatBitWidth()); - auto vec = bitcast(v, vecTy); - for (int i = 0; i < 32 / eltTy.getIntOrFloatBitWidth(); i++) { - outValues.push_back(extract_element(vec, i32_val(i))); - } - } - return outValues; -} - inline SmallVector unpackLLElements(Location loc, Value llvmStruct, RewriterBase &rewriter) { assert(bool(llvmStruct) && "can not unpack null values"); diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td index 382bc23182..8f9a1a850f 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td @@ -1199,8 +1199,6 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is: bool isAmpere() const; bool isHopper() const; - unsigned getElemsPerThreadOfOperand(int opIdx, ArrayRef shape) const; - // Get [isARow, isBRow, isAVec4, isBVec4, id] from versionMinor std::tuple decodeVoltaLayoutStates() const; @@ -1217,8 +1215,8 @@ For example, the matrix L 
corresponding to blockTileSize=[32,16] is: SmallVector getMMAv1Rep(int opIdx) const; SmallVector getMMAv1ShapePerWarp(int opIdx) const; int getMMAv1Vec(int opIdx) const; - SmallVector getMMAv2OrV3RepForOperand(ArrayRef shape, - int bitwidth, int kWidth, int opIdx) const; + SmallVector getRepForOperand(ArrayRef shape, + int bitwidth, int opIdx) const; bool supportReduction() const { if (isAmpere() || isHopper()) { diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp index 65ee8cc002..3fcae48978 100644 --- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -328,20 +328,7 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion } else { // Cast 5. The two layouts are equivalent. We should probably remove // these in RemoveLayoutConversion. - auto dstCvt = requiresI32Conversion(dstTy); - auto srcCvt = requiresI32Conversion(srcTy); - if (dstCvt || srcCvt) { - auto inVals = unpackLLElements(op.getLoc(), adaptor.getSrc(), rewriter); - inVals = unpackI32s(inVals, srcTy, rewriter, op.getLoc(), - getTypeConverter()); - inVals = - packI32s(inVals, dstTy, rewriter, op.getLoc(), getTypeConverter()); - auto res = packLLElements(op.getLoc(), getTypeConverter(), inVals, - rewriter, op.getType()); - rewriter.replaceOp(op, res); - } else { - rewriter.replaceOp(op, adaptor.getSrc()); - } + rewriter.replaceOp(op, adaptor.getSrc()); return success(); } } @@ -358,7 +345,6 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion auto srcTy = op.getSrc().getType(); auto dstTy = op.getType(); auto inVals = unpackLLElements(loc, adaptor.getSrc(), rewriter); - inVals = unpackI32s(inVals, srcTy, rewriter, loc, getTypeConverter()); SmallVector outVals(numRegs); for (int i = 0; i < numRegs; i++) { // Remove free masks from the register index @@ -371,7 +357,6 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion : idx; outVals[i] = inVals[srcIdx]; } - outVals = packI32s(outVals, dstTy, rewriter, loc, getTypeConverter()); Value result = packLLElements(loc, getTypeConverter(), outVals, rewriter, op.getType()); rewriter.replaceOp(op, result); @@ -406,11 +391,9 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion if (useLegacyMMAConversion) { return false; } - // FIXME [Dot LL] - // Enabling LL path for buggy kWidth path - bool largeKWidth = - dotOperand.getKWidth() * dstTy.getElementTypeBitWidth() > 64; - return largeKWidth && nvidiaMma.isAmpere(); + if (nvidiaMma.isAmpere()) { + return true; + } } return false; } @@ -454,7 +437,6 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion inVals[it.index()] = ptrtoint(llvmElemTy, it.value()); } } - inVals = unpackI32s(inVals, srcTy, rewriter, loc, getTypeConverter()); // Pretty sure this is the identity function ATM // It'd be better to simply call `quotient({kBlock})` and @@ -474,7 +456,6 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion } } - outVals = packI32s(outVals, dstTy, rewriter, loc, getTypeConverter()); Value result = packLLElements(loc, getTypeConverter(), outVals, rewriter, op.getType()); rewriter.replaceOp(op, result); diff --git a/lib/Conversion/TritonGPUToLLVM/DecomposeUnsupportedConversions.cpp b/lib/Conversion/TritonGPUToLLVM/DecomposeUnsupportedConversions.cpp index 1346cc143e..74b2767f0d 100644 --- a/lib/Conversion/TritonGPUToLLVM/DecomposeUnsupportedConversions.cpp +++ b/lib/Conversion/TritonGPUToLLVM/DecomposeUnsupportedConversions.cpp @@ -90,6 +90,10 @@ void decomposeBlockedToDotLayoutConversion(ModuleOp 
module) { auto dstDotOp = dyn_cast(dstType.getEncoding()); if (srcBlocked && dstDotOp) { + auto dotParent = dyn_cast(dstDotOp.getParent()); + if (dotParent && dotParent.isAmpere()) { + return; + } Attribute sharedMemorySpace = triton::gpu::SharedMemorySpaceAttr::get(srcType.getContext()); auto tmpType = MemDescType::get( diff --git a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp index 1b7088870c..632ccf1084 100644 --- a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp @@ -11,138 +11,23 @@ using namespace mlir::triton::gpu; namespace mlir::triton::gpu { -namespace { - -bool isDotOpTensorAndPacked(Type srcTy) { - auto tensorTy = dyn_cast(srcTy); - if (!tensorTy) - return false; - auto encoding = dyn_cast(tensorTy.getEncoding()); - if (!encoding) - return false; - auto parentEnc = dyn_cast(encoding.getParent()); - // By code convention, values for Hopper's dotOp-encoded tensors are not - // packed - if (!parentEnc || parentEnc.isHopper()) - return false; - return true; -} - -} // namespace - Type getElementType(Value value) { auto type = value.getType(); if (auto tensorType = dyn_cast(type)) return tensorType.getElementType(); return type; } -// MMA encoding has a different order depending on the element's bit width; -// reorder if we're in this case. -SmallVector reorderValues(const SmallVector &values, Type inType, - Type ouType) { - auto inTensorTy = dyn_cast(inType); - auto ouTensorTy = dyn_cast(ouType); - if (!inTensorTy || !ouTensorTy) - return values; - auto inEncoding = dyn_cast(inTensorTy.getEncoding()); - auto ouEncoding = dyn_cast(ouTensorTy.getEncoding()); - assert(inEncoding == ouEncoding); - if (!inEncoding) - return values; - // If the parent of the dot operand is in block encoding, we don't need to - // reorder elements - auto parentEncoding = dyn_cast(ouEncoding.getParent()); - if (!parentEncoding || parentEncoding.isHopper()) - return values; - size_t inBitWidth = inTensorTy.getElementType().getIntOrFloatBitWidth(); - size_t ouBitWidth = ouTensorTy.getElementType().getIntOrFloatBitWidth(); - auto ouEltTy = ouTensorTy.getElementType(); - if (inBitWidth == ouBitWidth) - return values; - if (inBitWidth == 16 && ouBitWidth == 32) { - // Register layout conversion: - // - // [0, 1], [4, 5] ⟶ [0], [1], [4], [5] - // [2, 3], [6, 7] [2], [3], [6], [7] - // - // Original access order: - // - // [0, 1], [2, 3], [4, 5], [6, 7] - // - // Transformed access order: - // - // [0], [2], [1], [3], [4], [6], [5], [7] - SmallVector ret; - for (unsigned i = 0; i < values.size(); i += 8) { - ret.push_back(values[i]); - ret.push_back(values[i + 2]); - ret.push_back(values[i + 1]); - ret.push_back(values[i + 3]); - ret.push_back(values[i + 4]); - ret.push_back(values[i + 6]); - ret.push_back(values[i + 5]); - ret.push_back(values[i + 7]); - } - return ret; - } - if (inBitWidth == 8 && ouBitWidth == 16) { - // Register layout conversion: - // - // [0, 1, 2, 3], [8, 9, 10, 11] ⟶ [0, 1], [2, 3], [8, 9], [10, 11] - // [4, 5, 6, 7], [12, 13, 14, 15] [4, 5], [6, 7], [12, 13], [14, 15] - // - // Original access order: - // - // [0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15] - // - // Transformed access order: - // - // [0, 1], [4, 5], [2, 3], [6, 7], [8, 9], [12, 13], [10, 11], [14, 15] - SmallVector ret; - for (unsigned i = 0; i < values.size(); i += 16) { - ret.push_back(values[i]); - ret.push_back(values[i + 1]); - ret.push_back(values[i + 4]); - 
ret.push_back(values[i + 5]); - ret.push_back(values[i + 2]); - ret.push_back(values[i + 3]); - ret.push_back(values[i + 6]); - ret.push_back(values[i + 7]); - ret.push_back(values[i + 8]); - ret.push_back(values[i + 9]); - ret.push_back(values[i + 12]); - ret.push_back(values[i + 13]); - ret.push_back(values[i + 10]); - ret.push_back(values[i + 11]); - ret.push_back(values[i + 14]); - ret.push_back(values[i + 15]); - } - return ret; - } - llvm_unreachable("unimplemented code path"); -} int getNumElementsPerThreads(Type type, const LLVMTypeConverter *typeConverter) { int numElemsPerThread = 1; - auto tensorTy = dyn_cast(type); - if (!tensorTy) - return numElemsPerThread; - auto structType = - dyn_cast(typeConverter->convertType(type)); - if (structType) { - numElemsPerThread = structType.getBody().size(); + if (auto tensorTy = dyn_cast(type)) { + auto structType = + dyn_cast(typeConverter->convertType(type)); + if (structType) + numElemsPerThread = structType.getBody().size(); } - auto encoding = dyn_cast(tensorTy.getEncoding()); - if (!(encoding && isa(encoding.getParent()))) - return numElemsPerThread; - auto eltType = tensorTy.getElementType(); - assert(eltType.getIntOrFloatBitWidth() <= 32 && - "Only support element type with bit width <= 32 in dot operand mma " - "layout"); - // dot operand data are packed into i32 elements so use the following formula - // to get the number of elements per thread. - return (32 / eltType.getIntOrFloatBitWidth()) * numElemsPerThread; + return numElemsPerThread; } } // namespace mlir::triton::gpu @@ -473,8 +358,7 @@ struct ElementwiseInlineAsmOpConversion for (auto operand : adaptor.getOperands()) { auto argTy = op->getOperand(0).getType(); auto subOperands = unpackLLElements(loc, operand, rewriter); - unpackedOperands.push_back( - unpackI32s(subOperands, argTy, rewriter, loc, getTypeConverter())); + unpackedOperands.push_back(subOperands); } int numElemsPerThread = getNumElementsPerThreads(op->getResult(0).getType(), @@ -527,16 +411,6 @@ struct ElementwiseInlineAsmOpConversion // Reorder and pack the results. SmallVector outs; for (int i = 0; i < unpackedResults.size(); i++) { - // We reordered all the inputs so they match operand 0. Reorder the - // outputs accordingly. 
- if (op->getNumOperands() > 0) { - unpackedResults[i] = reorderValues( - unpackedResults[i], /*inType=*/op->getOperand(0).getType(), - /*ouType=*/op->getResult(i).getType()); - } - auto dstTy = op->getResult(i).getType(); - unpackedResults[i] = packI32s(unpackedResults[i], dstTy, rewriter, loc, - getTypeConverter()); outs.push_back(packLLElements(loc, getTypeConverter(), unpackedResults[i], rewriter, op->getResult(i).getType())); } diff --git a/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp index 38fa1bd623..fbd6248fe7 100644 --- a/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp @@ -173,7 +173,6 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern { auto dstLayout = dstTy.getEncoding(); assert((dstShape.size() <= 2 || isSupportedDotOpLayout(dstLayout)) && "Unexpected rank of ConvertLayout(shared->distributed)"); - auto inOrd = getOrder(srcSharedLayout); auto smemObj = LLVM::getSharedMemoryObjectFromStruct( loc, adaptor.getSrc(), @@ -183,7 +182,6 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern { SmallVector outVals = loadSharedToDistributed( dstTy, srcTy, elemLlvmTy, smemObj, loc, rewriter, targetInfo); - outVals = packI32s(outVals, dstTy, rewriter, loc, typeConverter); Value result = packLLElements(loc, typeConverter, outVals, rewriter, dstTy); rewriter.replaceOp(op, result); diff --git a/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp b/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp index cc6d8875b5..8cac1efbff 100644 --- a/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp +++ b/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp @@ -70,29 +70,12 @@ Type TritonGPUToLLVMTypeConverter::convertTritonPointerType( return LLVM::LLVMPointerType::get(ctx, type.getAddressSpace()); } -Type TritonGPUToLLVMTypeConverter::getElementTypeForStruct( - TensorOrMemDesc type) { - auto ctx = type.getContext(); - Attribute layout = type.getEncoding(); - Type elemTy = convertType(type.getElementType()); - auto dotOpLayout = mlir::dyn_cast(layout); - if (!dotOpLayout) - return elemTy; - auto mmaParent = - mlir::dyn_cast(dotOpLayout.getParent()); - if (!mmaParent || mmaParent.isHopper()) - return elemTy; - int bitwidth = elemTy.getIntOrFloatBitWidth(); - assert(bitwidth <= 32); - return IntegerType::get(ctx, 32); -} - Type TritonGPUToLLVMTypeConverter::convertTritonTensorType( RankedTensorType type, const TargetInfoBase &targetInfo) { auto ctx = type.getContext(); Attribute layout = type.getEncoding(); SmallVector shape(type.getShape().begin(), type.getShape().end()); - Type eltType = getElementTypeForStruct(cast(type)); + Type eltType = convertType(type.getElementType()); if (auto shared_layout = mlir::dyn_cast(layout)) { SmallVector types; diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.cpp b/lib/Conversion/TritonGPUToLLVM/Utility.cpp index fdca492875..b5ab3601ea 100644 --- a/lib/Conversion/TritonGPUToLLVM/Utility.cpp +++ b/lib/Conversion/TritonGPUToLLVM/Utility.cpp @@ -862,32 +862,49 @@ SmallVector getWrappedMultiDimOffset( return multiDimOffsetWrapped; } -std::pair convertMxfp4x2ToBf16x2(RewriterBase &rewriter, - Location loc, Value v) { - auto em0 = and_(v, i8_val(0x70)); - auto em1 = and_(v, i8_val(0x7)); - Value v0 = or_(shl(zext(i16_ty, em0), i16_val(2)), - shl(zext(i16_ty, and_(v, i8_val(0x80))), i16_val(8))); - Value v1 = or_(shl(zext(i16_ty, em1), i16_val(6)), - shl(zext(i16_ty, and_(v, i8_val(0x8))), i16_val(12))); - - // Three cases: - // 1) x is normal and 
non-zero: Correct bias - v0 = select(icmp_ne(and_(em0, i8_val(0x60)), i8_val(0)), - add(v0, i16_val((127 - 1) << 7)), v0); - v1 = select(icmp_ne(and_(em1, i8_val(0x6)), i8_val(0)), - add(v1, i16_val((127 - 1) << 7)), v1); - - // 2) x is subnormal (x == 0bs001 where s is the sign): Map to +-0.5 in - // bf16 - v0 = select(icmp_eq(em0, i8_val(0x10)), - or_(i16_val(16128), and_(v0, i16_val(0x8000))), v0); - v1 = select(icmp_eq(em1, i8_val(0x1)), - or_(i16_val(16128), and_(v1, i16_val(0x8000))), v1); - // 3) x is zero, nothing to do - - return {v0, v1}; -} +SmallVector convertMxfp4x2ToBf16x2(RewriterBase &rewriter, Location loc, + ArrayRef values) { + SmallVector results; + for (auto v : values) { + auto em0 = and_(v, i8_val(0x70)); + auto em1 = and_(v, i8_val(0x7)); + Value v0 = or_(shl(zext(i16_ty, em0), i16_val(2)), + shl(zext(i16_ty, and_(v, i8_val(0x80))), i16_val(8))); + Value v1 = or_(shl(zext(i16_ty, em1), i16_val(6)), + shl(zext(i16_ty, and_(v, i8_val(0x8))), i16_val(12))); + + // Three cases: + // 1) x is normal and non-zero: Correct bias + v0 = select(icmp_ne(and_(em0, i8_val(0x60)), i8_val(0)), + add(v0, i16_val((127 - 1) << 7)), v0); + v1 = select(icmp_ne(and_(em1, i8_val(0x6)), i8_val(0)), + add(v1, i16_val((127 - 1) << 7)), v1); + + // 2) x is subnormal (x == 0bs001 where s is the sign): Map to +-0.5 in + // bf16 + v0 = bitcast(select(icmp_eq(em0, i8_val(0x10)), + or_(i16_val(16128), and_(v0, i16_val(0x8000))), v0), + bf16_ty); + v1 = bitcast(select(icmp_eq(em1, i8_val(0x1)), + or_(i16_val(16128), and_(v1, i16_val(0x8000))), v1), + bf16_ty); + // 3) x is zero, nothing to do + results.push_back(v0); + results.push_back(v1); + } + return results; +} + +Value mxfpScaleBf16(RewriterBase &rewriter, Location loc, Value v, + Value scale) { + Value vBf16 = bitcast(v, bf16_ty); + Value nanBf16 = bitcast(i16_val(0x7fff), bf16_ty); + Value scaleIsNan = icmp_eq(scale, i8_val(0xff)); + Value scaleBf16 = bitcast(shl(zext(i16_ty, scale), i16_val(7)), bf16_ty); + Value scaledBf16 = fmul(vBf16, scaleBf16); + // Account for NaN in the scale as per the mxfp specification. 
+ return select(scaleIsNan, nanBf16, scaledBf16); +}; } // namespace LLVM } // namespace mlir diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index d0365b4cee..8462c24aea 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -898,36 +898,6 @@ NvidiaMmaEncodingAttr::getElemsPerThread(ArrayRef shape, return elemsPerThread; } -unsigned NvidiaMmaEncodingAttr::getElemsPerThreadOfOperand( - int opIdx, ArrayRef shape) const { - size_t rank = shape.size(); - assert(rank == 2 && "Unexpected rank of mma layout"); - auto shapePerCTA = getShapePerCTA(*this, shape); - int res = 0; - if (isVolta()) { - llvm_unreachable( - "getElemsPerThreadOfOperand() not supported for version 1"); - } else if (isAmpere()) { - llvm_unreachable( - "getElemsPerThreadOfOperand() not supported for version 2"); - } else if (isHopper()) { - auto wpt = getWarpsPerCTA(); - auto instrMNK = getInstrShape(); - if (opIdx == 0) { - int repM = ceil(shapePerCTA[0], instrMNK[0] * wpt[0]); - int repK = ceil(shapePerCTA[1], instrMNK[2]); - return 8 * repM * repK; - - } else if (opIdx == 1) { - int repK = ceil(shapePerCTA[0], instrMNK[2]); - int repN = ceil(shapePerCTA[1], instrMNK[1] * wpt[1]); - // benzh@ here need more check - return 4 * std::max(instrMNK[1] / 32, 1) * repK * repN; - } - } - return res; -} - unsigned NvidiaMmaEncodingAttr::getTotalElemsPerThread(ArrayRef shape, Type eltTy) const { return product(getElemsPerThread(shape, eltTy)); @@ -950,25 +920,41 @@ unsigned SharedEncodingAttr::getTotalElemsPerThread(ArrayRef shape, SmallVector DotOperandEncodingAttr::getElemsPerThread(ArrayRef shape, Type eltTy) const { + auto rank = shape.size(); + assert(rank == 2 || rank == 3); - if (auto parent = mlir::dyn_cast(getParent())) { - auto rank = shape.size(); - assert(rank == 2 || rank == 3); - - auto idx = getOpIdx(); - assert(idx == 0 || idx == 1); - - SmallVector elemsPerThread(rank); + auto idx = getOpIdx(); + assert(idx == 0 || idx == 1); - auto kWidth = getKWidth(); - auto rep = parent.getRepForOperand(shape, kWidth, idx); + SmallVector elemsPerThread(rank); + auto parent = getParent(); + auto kWidth = getKWidth(); + if (auto mfma = mlir::dyn_cast(parent)) { + auto rep = mfma.getRepForOperand(shape, kWidth, idx); if (rank == 3) elemsPerThread[0] = rep[0]; elemsPerThread[rank - 2] = (idx == 0) ? rep[1] : rep[1] * kWidth; elemsPerThread[rank - 1] = (idx == 0) ? rep[2] * kWidth : rep[2]; - return elemsPerThread; + } else if (auto mma = mlir::dyn_cast(parent)) { + if (mma.isAmpere()) { + auto bitwidth = getPointeeType(eltTy).getIntOrFloatBitWidth(); + auto rep = mma.getRepForOperand(shape, bitwidth, idx); + auto sizePerThread = getSizePerThread(); + auto elemsPerKRep = 32 / bitwidth * 2; + if (rank == 3) + elemsPerThread[0] = rep[0]; + elemsPerThread[rank - 2] = + (idx == 0) + ? rep[1] * sizePerThread[rank - 2] + : std::max(rep[1] * elemsPerKRep, sizePerThread[rank - 2]); + elemsPerThread[rank - 1] = + (idx == 0) + ? 
std::max(rep[2] * elemsPerKRep, sizePerThread[rank - 1]) + : rep[2] * sizePerThread[rank - 1]; + return elemsPerThread; + } } llvm_unreachable("getElemsPerThread is not supported for dot operand"); @@ -978,6 +964,10 @@ DotOperandEncodingAttr::getElemsPerThread(ArrayRef shape, unsigned DotOperandEncodingAttr::getTotalElemsPerThread(ArrayRef shape, Type eltTy) const { if (auto mmaParent = mlir::dyn_cast(getParent())) { + if (auto nvidiaMmaParent = mlir::dyn_cast(mmaParent); + nvidiaMmaParent && nvidiaMmaParent.isAmpere()) { + return product(getElemsPerThread(shape, eltTy)); + } return mmaParent.getTotalElemsPerThreadForOperand(shape, eltTy, getKWidth(), getOpIdx()); } @@ -2021,9 +2011,9 @@ SmallVector NvidiaMmaEncodingAttr::getMMAv1ShapePerWarp(int opIdx) const { int NvidiaMmaEncodingAttr::getMMAv1Vec(int opIdx) const { return 2 * getMMAv1Rep(opIdx)[opIdx]; } -SmallVector NvidiaMmaEncodingAttr::getMMAv2OrV3RepForOperand( - ArrayRef shape, int bitwidth, int kWidth, int opIdx) const { - assert(isAmpere() || (isHopper() && opIdx == 0)); +SmallVector +NvidiaMmaEncodingAttr::getRepForOperand(ArrayRef shape, int bitwidth, + int opIdx) const { auto rank = shape.size(); auto warpsPerCTA = getWarpsPerCTA(); @@ -2036,17 +2026,18 @@ SmallVector NvidiaMmaEncodingAttr::getMMAv2OrV3RepForOperand( ? std::max(1, shape[0] / (shapePerWarp[0] * warpsPerCTA[0])) : 1; - if (opIdx == 0) + if (opIdx == 0) { return {numRepBatch, - std::max(1, shape[rank - 2] / + std::max(1, /*repM=*/shape[rank - 2] / (shapePerWarp[1] * warpsPerCTA[rank - 2])), - std::max(1, shape[rank - 1] / shapePerWarp[3])}; - else { + std::max(1, /*repK=*/shape[rank - 1] / shapePerWarp[3])}; + } else { assert(opIdx == 1); - return {numRepBatch, - std::max(1, shape[rank - 2] / shapePerWarp[3]), - std::max(1, shape[rank - 1] / (shapePerWarp[2] * - warpsPerCTA[rank - 1]))}; + return { + numRepBatch, + std::max(1, /*repK=*/shape[rank - 2] / shapePerWarp[3]), + std::max(1, /*repN=*/shape[rank - 1] / + (shapePerWarp[2] * warpsPerCTA[rank - 1]))}; } } @@ -2065,15 +2056,6 @@ unsigned NvidiaMmaEncodingAttr::getTotalElemsPerThreadForOperand( // kWidth elements for each quadrant. WGMMA is repeated repM * repK times. return 4 * kWidth * repM * repK; } - // A100 - if (isAmpere()) { - auto rep = getMMAv2OrV3RepForOperand( - shapePerCTA, eltTy.getIntOrFloatBitWidth(), kWidth, opIdx); - if (opIdx == 0) - return 4 * rep[0] * rep[1] * rep[2]; - if (opIdx == 1) - return 4 * rep[0] * rep[1] * std::max(rep[2] / 2, 1); - } // V100 if (isVolta()) { bool isRow = getMMAv1IsRow(opIdx); diff --git a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp index f3a0cb5ada..7af52bc541 100644 --- a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp +++ b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp @@ -189,7 +189,7 @@ class BlockedToMMA : public mlir::OpRewritePattern { // elements distribution to the order of higher precision primitives. As a // result, kwidth can be the bitwidth of the lower precision primitive. // Conversely, in the downcasting scenario, no reordering is performed, - // making it directory use the lower precision primitive. + // making it directly use the lower precision primitive. 
static int computeOrigBitWidth(Value x) { int finalBitWidth = getElementTypeOrSelf(x).getIntOrFloatBitWidth(); int origBitWidth = finalBitWidth; diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 3fe93ddbd2..6e857cdbd1 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -148,6 +148,17 @@ def __str__(self): return f"#{GPU_DIALECT}.nvidia_mma<{{versionMajor={self.version[0]}, versionMinor={self.version[1]}, warpsPerCTA={self.warps_per_cta}, CTAsPerCGA={self.ctas_per_cga}, CTASplitNum={self.cta_split_num}, CTAOrder={self.cta_order}, instrShape={self.instr_shape}}}>" +class DotOperandLayout: + + def __init__(self, parent, op_idx, k_width): + self.parent = parent + self.op_idx = op_idx + self.k_width = k_width + + def __str__(self): + return f"#{GPU_DIALECT}.dot_op<{{parent={self.parent}, opIdx={self.op_idx}, kWidth={self.k_width}}}>" + + class BlockedLayout: def __init__(self, size_per_thread, threads_per_warp, warps_per_cta, order, ctas_per_cga, cta_split_num, cta_order): @@ -5221,6 +5232,14 @@ def kernel(Out): BlockedLayout([4, 1], [8, THREADS_PER_WARP // 8], [2, 2], [0, 1], [1, 1], [1, 1], [0, 1]), BlockedLayout([1, 1], [THREADS_PER_WARP, 1], [2, 2], [0, 1], [1, 1], [1, 1], [0, 1]), BlockedLayout([4, 4], [1, THREADS_PER_WARP], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), + DotOperandLayout(parent=MmaLayout([2, 0], [4, 1], [1, 1], [1, 1], [1, 0], [16, 8]), op_idx=0, k_width=2), + DotOperandLayout(parent=MmaLayout([2, 0], [4, 1], [1, 1], [1, 1], [1, 0], [16, 8]), op_idx=1, k_width=2), + DotOperandLayout(parent=MmaLayout([2, 0], [2, 2], [1, 1], [1, 1], [1, 0], [16, 8]), op_idx=0, k_width=2), + DotOperandLayout(parent=MmaLayout([2, 0], [2, 2], [1, 1], [1, 1], [1, 0], [16, 8]), op_idx=1, k_width=2), + DotOperandLayout(parent=MmaLayout([2, 0], [4, 1], [1, 1], [1, 1], [1, 0], [16, 8]), op_idx=0, k_width=8), + DotOperandLayout(parent=MmaLayout([2, 0], [4, 1], [1, 1], [1, 1], [1, 0], [16, 8]), op_idx=1, k_width=8), + DotOperandLayout(parent=MmaLayout([2, 0], [2, 2], [1, 1], [1, 1], [1, 0], [16, 8]), op_idx=0, k_width=8), + DotOperandLayout(parent=MmaLayout([2, 0], [2, 2], [1, 1], [1, 1], [1, 0], [16, 8]), op_idx=1, k_width=8), MmaLayout([2, 0], [4, 1], [1, 1], [1, 1], [1, 0], [16, 8]), ] @@ -5258,6 +5277,10 @@ def compute_scratch_buffer_shape(src_layout, dst_layout, shape): def test_convert2d(M, N, src_layout, interm_layout, dst_layout, dtype, device, tmp_path: pathlib.Path): if str(src_layout) == str(dst_layout): pytest.skip() + if (isinstance(src_layout, DotOperandLayout) + and isinstance(interm_layout, SharedLayout)) or (isinstance(dst_layout, DotOperandLayout) + and isinstance(interm_layout, SharedLayout)): + pytest.skip("DotOperandLayout <-> SharedLayout conversion is not completely supported") if is_hip(): try: scratch_shape = compute_scratch_buffer_shape(src_layout, dst_layout, (M, N)) diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir index 679a18cd9b..3537b4e670 100644 --- a/test/Conversion/tritongpu_to_llvm.mlir +++ b/test/Conversion/tritongpu_to_llvm.mlir @@ -1398,9 +1398,9 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : tt.func @matmul_tf32_cst_b(%ptr:!tt.ptr {tt.divisibility = 16 : i32}, %a: tensor<32x16xf32, #dot_operand_a>, %c: tensor<32x32xf32, #mma>) { // CHECK: %[[CST:.+]] = llvm.mlir.constant(1.000000e+00 : f32) : f32 - // CHECK: %[[BC:.+]] = llvm.bitcast %[[CST]] : f32 to i32 - // CHECK: %[[SI:.+]] = 
llvm.mlir.undef : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)> - // CHECK: llvm.insertvalue %[[BC]], %[[SI]][0] : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)> + // CHECK: %[[BC:.+]] = llvm.bitcast %[[CST]] : f32 to f32 + // CHECK: %[[SI:.+]] = llvm.mlir.undef : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> + // CHECK: llvm.insertvalue %[[BC]], %[[SI]][0] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> %b_mat = arith.constant dense<1.000000e+00> : tensor<16x32xf32, #dot_operand_b> %28 = tt.dot %a, %b_mat, %c, inputPrecision = tf32 : tensor<32x16xf32, #dot_operand_a> * tensor<16x32xf32, #dot_operand_b> -> tensor<32x32xf32, #mma> %38 = triton_gpu.convert_layout %28 : tensor<32x32xf32, #mma> -> tensor<32x32xf32, #blocked> @@ -1419,16 +1419,12 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK-LABEL: matmul_f16_cst_operands tt.func public @matmul_f16_cst_operands(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma> - // CHECK: %[[C1f:.+]] = llvm.mlir.constant(1.000000e+00 : f16) : f16 - // CHECK: %[[Ci16:.+]] = llvm.bitcast %[[C1f]] : f16 to i16 - // CHECK: %[[U:.+]] = llvm.mlir.undef : vector<2xi16> + // CHECK: %[[U:.+]] = llvm.mlir.undef : vector<2xf16> // CHECK: %[[C0:.+]] = llvm.mlir.constant(0 : i32) : i32 - // CHECK: %[[V0:.+]] = llvm.insertelement %[[Ci16]], %[[U]][%[[C0]] : i32] : vector<2xi16> + // CHECK: %[[V0:.+]] = llvm.insertelement %{{.*}}, %[[U]][%[[C0]] : i32] : vector<2xf16> // CHECK: %[[C1:.+]] = llvm.mlir.constant(1 : i32) : i32 - // CHECK: %[[V1:.+]] = llvm.insertelement %[[Ci16]], %[[V0]][%[[C1]] : i32] : vector<2xi16> - // CHECK: %[[BC:.+]] = llvm.bitcast %[[V1]] : vector<2xi16> to i32 - // CHECK: %[[SU:.+]] = llvm.mlir.undef : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)> - // CHECK: llvm.insertvalue %[[BC]], %[[SU]][0] : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)> + // CHECK: %[[V1:.+]] = llvm.insertelement %{{.*}}, %[[V0]][%[[C1]] : i32] : vector<2xf16> + // CHECK: %[[BC:.+]] = llvm.bitcast %[[V1]] : vector<2xf16> to i32 %cst_0 = arith.constant dense<1.000000e+00> : tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> %cst_1 = arith.constant dense<1.000000e+00> : tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> %cst_2 = arith.constant dense<32> : tensor<32x1xi32, #blocked> diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp index f8edb7f93c..43a334d5cf 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -41,11 +41,22 @@ Value redundantDataMask(Type valueTy, ConversionPatternRewriter &rewriter, auto threadsPerWarp = triton::gpu::getThreadsPerWarp(layout); auto warpsPerCTA = triton::gpu::getWarpsPerCTA(layout); auto threadOrder = triton::gpu::getThreadOrder(layout); - auto warpOrder = triton::gpu::getWarpOrder(layout); + SmallVector warpOrder(rank); + if (auto enc = dyn_cast(layout)) { + warpOrder = + triton::gpu::getMatrixOrder(rank, /*rowMajor=*/enc.getOpIdx() == 1); + } else { + warpOrder = triton::gpu::getWarpOrder(layout); + } auto shapePerCTATile = triton::gpu::getShapePerCTATile(layout, shape); Value warpSize = i32_val(triton::gpu::getWarpSize(layout)); Value laneId = urem(tid, warpSize); Value warpId = udiv(tid, warpSize); + // TODO: [DOT LL] + // 
The delinearize function is not entirely correct for certain layouts, + // such as wgmma. The correct approach is to convert a legacy layout to its + // corresponding linear layout and use the linear layout's + // getFreeVariableMasks to identify redundant elements. SmallVector multiDimWarpId = delinearize(rewriter, loc, warpId, warpsPerCTA, warpOrder); SmallVector multiDimThreadId = diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/UpcastMXFPToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/UpcastMXFPToLLVM.cpp index 07fa634ec7..f8165a7693 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/UpcastMXFPToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/UpcastMXFPToLLVM.cpp @@ -19,17 +19,6 @@ using namespace mlir::triton::gpu; namespace { -Value mxfpScaleBf16(RewriterBase &rewriter, Location loc, Value v, - Value scale) { - Value vBf16 = bitcast(v, bf16_ty); - Value nanBf16 = bitcast(i16_val(0x7fff), bf16_ty); - Value scaleIsNan = icmp_eq(scale, i8_val(0xff)); - Value scaleBf16 = bitcast(shl(zext(i16_ty, scale), i16_val(7)), bf16_ty); - Value scaledBf16 = fmul(vBf16, scaleBf16); - // Account for NaN in the scale as per the mxfp specification. - return select(scaleIsNan, nanBf16, scaledBf16); -}; - class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern { private: const TargetInfoBase &targetInfo; @@ -83,7 +72,7 @@ class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern { Value laneId = urem(tid, warpSize); if (isPacked) - xVals = unpackFP4Elements(loc, rewriter, xVals); + xVals = LLVM::convertMxfp4x2ToBf16x2(rewriter, loc, xVals); // Given that MFMA layout for the A tensor arranges thread in a column-major // manner, for the current tid, it's at row (tid % mDim). When we set up @@ -110,7 +99,8 @@ class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern { for (int j = 0; j < 32; ++j) { int index = 32 * i + j; - xVals[index] = mxfpScaleBf16(rewriter, loc, xVals[index], si[j / 16]); + xVals[index] = + LLVM::mxfpScaleBf16(rewriter, loc, xVals[index], si[j / 16]); } } } else { @@ -132,7 +122,8 @@ class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern { for (int j = 0; j < 32; ++j) { int index = 32 * i + j; - xVals[index] = mxfpScaleBf16(rewriter, loc, xVals[index], si[j / 8]); + xVals[index] = + LLVM::mxfpScaleBf16(rewriter, loc, xVals[index], si[j / 8]); } } } @@ -142,20 +133,6 @@ class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern { rewriter.replaceOp(op, result); return success(); } - -private: - SmallVector unpackFP4Elements(Location loc, RewriterBase &rewriter, - ArrayRef packed) const { - // Split every fp4x2 into 2 bf16 values. 
- llvm::SmallVector unpacked; - unpacked.reserve(packed.size() * 2); - for (Value v : packed) { - auto [e0, e1] = LLVM::convertMxfp4x2ToBf16x2(rewriter, loc, v); - unpacked.push_back(e0); - unpacked.push_back(e1); - } - return unpacked; - } }; } // anonymous namespace diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2OrV3.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2OrV3.cpp index 8f1fcc1f70..4c99a44dff 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2OrV3.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2OrV3.cpp @@ -528,47 +528,59 @@ Type getSharedMemTy(Type argType) { llvm::report_fatal_error("mma16816 data type not supported"); } -std::vector unpackInt(const std::vector &inValues, Type elTy, - ConversionPatternRewriter &rewriter, Location loc, - const LLVMTypeConverter *typeConverter) { - const int inBitWidth = inValues[0].getType().getIntOrFloatBitWidth(); - std::vector outValues; - for (auto v : inValues) { - // cast i32 to appropriate eltType vector and extract elements - auto eltType = typeConverter->convertType(elTy); - auto vecType = - vec_ty(eltType, inBitWidth / eltType.getIntOrFloatBitWidth()); - auto vec = bitcast(v, vecType); - for (int i = 0; i < inBitWidth / eltType.getIntOrFloatBitWidth(); i++) { - outValues.push_back(extract_element(vec, i32_val(i))); - } - } - return outValues; -} - Value composeValuesToDotOperandLayoutStruct( - const ValueTable &vals, int batch, int n0, int n1, + const ValueTable &vals, int batch, int repOuter, int repK, const LLVMTypeConverter *typeConverter, Location loc, - ConversionPatternRewriter &rewriter, Type elTy, bool isHopper) { + ConversionPatternRewriter &rewriter, Type eltTy, int kWidth, bool isHopper, + bool isA) { + auto bitwidth = eltTy.getIntOrFloatBitWidth(); + assert(32 >= bitwidth && "only support 32-bit or less"); + auto numElemsPerVec = 32 / bitwidth; + auto vecTy = vec_ty(eltTy, numElemsPerVec); + // FIXME: Fix the hopper path + // FIXME: [DOT LL] + // `kWidth` specifies the number of contiguous elements each thread will load. + // Loaded elements are packed into a vector of int32, which will then be + // unpacked into individual elements. + // `kIters` specifies the number of contiguous int32 elements each thread + // should load. + auto kIters = isHopper ? 1 : kWidth / (32 / bitwidth); + std::vector elems; - for (int b = 0; b < batch; ++b) - for (int m = 0; m < n0; ++m) - for (int k = 0; k < n1; ++k) { - elems.push_back(vals.at({b, 2 * m, 2 * k})); - elems.push_back(vals.at({b, 2 * m + 1, 2 * k})); - elems.push_back(vals.at({b, 2 * m, 2 * k + 1})); - elems.push_back(vals.at({b, 2 * m + 1, 2 * k + 1})); + auto unpackVec = [&](int b, int m, int k) { + for (auto kIter = 0; kIter < kIters; ++kIter) { + auto val = vals.at({b, m, k + kIter}); + auto vec = bitcast(val, vecTy); + for (auto i = 0; i < numElemsPerVec; ++i) { + elems.push_back(extract_element(eltTy, vec, i32_val(i))); } - assert(!elems.empty()); + } + }; - if (isHopper) { - elems = unpackInt(elems, elTy, rewriter, loc, typeConverter); + // Loading A tile is different from loading B tile since each tile of A is + // 16x16 while B is 16x8. 
+ if (isA) { + for (int b = 0; b < batch; ++b) + for (int m = 0; m < repOuter; ++m) + for (int k = 0; k < std::max(repK / kIters, 1); ++k) { + unpackVec(b, 2 * m, kIters * 2 * k); + unpackVec(b, 2 * m + 1, kIters * 2 * k); + unpackVec(b, 2 * m, kIters * (2 * k + 1)); + unpackVec(b, 2 * m + 1, kIters * (2 * k + 1)); + } + } else { + for (int b = 0; b < batch; ++b) + for (int n = 0; n < repOuter; ++n) + for (int k = 0; k < std::max(repK / kIters, 1); ++k) { + unpackVec(b, n, kIters * 2 * k); + unpackVec(b, n, kIters * (2 * k + 1)); + } } + assert(!elems.empty()); - Type elemTy = elems[0].getType(); - MLIRContext *ctx = elemTy.getContext(); + MLIRContext *ctx = eltTy.getContext(); Type structTy = LLVM::LLVMStructType::getLiteral( - ctx, SmallVector(elems.size(), elemTy)); + ctx, SmallVector(elems.size(), eltTy)); auto result = packLLElements(loc, typeConverter, elems, rewriter, structTy); return result; } @@ -658,8 +670,8 @@ Value loadArg(ConversionPatternRewriter &rewriter, Location loc, int matShapeM = 8, matShapeN = 8, matShapeK = 2 * 64 / mmaBitwidth; int kWidth = encoding.getKWidth(); - auto numRep = mmaLayout.getMMAv2OrV3RepForOperand( - shapePerCTA, mmaBitwidth, kWidth, encoding.getOpIdx()); + auto numRep = + mmaLayout.getRepForOperand(shapePerCTA, mmaBitwidth, encoding.getOpIdx()); auto warpsPerCTA = mmaLayout.getWarpsPerCTA(); auto warpOrder = mmaLayout.getWarpOrder(); @@ -704,9 +716,10 @@ Value loadArg(ConversionPatternRewriter &rewriter, Location loc, loadFn(b, 2 * m, 2 * k); // Format the values to LLVM::Struct to passing to mma codegen. + Type eltTy = typeConverter->convertType(descTy.getElementType()); return composeValuesToDotOperandLayoutStruct( - vals, numRepBatch, numRepOuter, numRepK, typeConverter, loc, rewriter, - descTy.getElementType(), /*unpack=*/isHopper); + vals, numRepBatch, isA ? numRep[1] : numRep[2], numRepK, typeConverter, + loc, rewriter, eltTy, kWidth, isHopper, isA); } template diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DecomposeUnsupportedConversions.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DecomposeUnsupportedConversions.cpp index 6129d77f17..1aa2b516a5 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DecomposeUnsupportedConversions.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DecomposeUnsupportedConversions.cpp @@ -75,6 +75,11 @@ struct DecomposeUnsupportedConversions // we have enabled the new layout conversion for all the cases. 
auto nvidiaShortCutFn = [&](RankedTensorType srcTy,
RankedTensorType dstTy) {
+ auto nvidiaMma = dyn_cast(srcTy.getEncoding());
+ // Supported mma to dot conversion
+ if (nvidiaMma && nvidiaMma.isAmpere())
+ return true;
+ // No need to decompose if shared memory is not needed
return matchMmaV3AndDotOperandLayout(srcTy, dstTy) ||
cvtReordersRegisters(srcTy, dstTy);
};
diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp
index 508f03227c..1c98fc5f88 100644
--- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp
+++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp
@@ -59,56 +59,73 @@ Value loadC(Value tensor, Value llTensor,
ValueTableV2 getValuesFromDotOperandLayoutStruct(
const LLVMTypeConverter *typeConverter, Location loc,
- ConversionPatternRewriter &rewriter, Value value, int batch, int n0, int n1,
- RankedTensorType type) {
+ ConversionPatternRewriter &rewriter, Value value, int batch, int repOuter,
+ int repK, RankedTensorType type) {
auto elems = unpackLLElements(loc, value, rewriter);
+ auto eltTy = type.getElementType();
int offset{};
ValueTableV2 vals;
+ auto bitwidth = eltTy.getIntOrFloatBitWidth();
+ auto numElemsPerVec = 32 / bitwidth;
+ auto vecTy = vec_ty(eltTy, numElemsPerVec);
+
+ auto packVec = [&](std::array dstIdx) {
+ Value vec = undef(vecTy);
+ for (auto i = 0; i < numElemsPerVec; ++i) {
+ vec = insert_element(vec, bitcast(elems[offset + i], eltTy), i32_val(i));
+ }
+ vals[dstIdx] = bitcast(vec, i32_ty);
+ offset += numElemsPerVec;
+ };
- // FIXME [Dot LL]
- // [ez] Generalize the logic below for kWidth * elemBitWidth > 32
auto dot = cast(type.getEncoding());
- auto largeK = dot.getKWidth() == 8 &&
- cast(dot.getParent()).isAmpere();
+ auto kWidth = dot.getKWidth();
+ auto largeK = bitwidth * kWidth > 32;
if (largeK) {
+ // For layouts with a large K dimension, the original register layout needs
+ // to be divided into multiple MMAs, where each MMA has contiguous 32 bits
+ // along the K dimension per thread.
+ // Using kWidth = 8 and bitwidth = 16 as an example,
+ // we split the MMA into 4 sub-MMAs, each with a stride of 4 x 32 bits along
+ // the K dimension.
llvm::SmallVector si;
- // For kWidth = 8, split the mma into 4 mmas with "stride 4" along K
if (dot.getOpIdx() == 0) {
// Original register layout:
//
- // [0, 1, 2, 3], [8, 9, 10, 11]
- // [4, 5, 6, 7], [12, 13, 14, 15]
- //
- // Each element in the layout consists of two bf16 values.
- // For example, the row [0, 1, 2, 3] expands to:
+ // [0, 1, 2, 3, 4, 5, 6, 7], [16, 17, 18, 19, 20, 21, 22, 23]
+ // [8, 9, 10, 11, 12, 13, 14, 15], [24, 25, 26, 27, 28, 29, 30, 31]
//
- // [[0/0, 0/1], [1/0, 1/1], [2/0, 2/1], [3/0, 3/1]]
- //
- // Here, 0/0 refers to the first half of element 0, and 0/1 refers to the
- // second half, matching kWidth = 8.
+ // Each element in the layout is a single bf16. 
// // To derive four independent MMA operations, a stride of 4 is applied to // the original register layout: // - // 1st MMA: [0, 4, 8, 12] - // 2nd MMA: [1, 5, 9, 13] - // 3rd MMA: [2, 6, 10, 14] - // 4th MMA: [3, 7, 11, 15] - si = llvm::SmallVector{0, 4, 8, 12, 1, 5, 9, 13, - 2, 6, 10, 14, 3, 7, 11, 15}; + // 1st MMA: [[0, 1], [8, 9], [16, 17], [24, 25]] + // 2nd MMA: [[2, 3], [10, 11], [18, 19], [26, 27]] + // 3rd MMA: [[4, 5], [12, 13], [20, 21], [28, 29]] + // 4th MMA: [[6, 7], [14, 15], [22, 23], [30, 31]] + for (size_t kRep = 0; kRep < kWidth / numElemsPerVec; ++kRep) + for (size_t tile = 0; tile < 4; ++tile) + for (size_t e = 0; e < numElemsPerVec; ++e) { + si.push_back(kRep * numElemsPerVec + tile * kWidth + e); + } } else { // Original register layout: // - // [0, 1, 2, 3]^T, [4, 5, 6, 7]^T + // [0, 1, 2, 3, 4, 5, 6, 7]^T, [8, 9, 10, 11, 12, 13, 14, 15]^T // // A stride of 4 is applied to derive four independent MMA operations: // - // 1st MMA: [0, 4] - // 2nd MMA: [1, 5] - // 3rd MMA: [2, 6] - // 4th MMA: [3, 7] - si = llvm::SmallVector{0, 4, 1, 5, 2, 6, 3, 7}; + // 1st MMA: [[0, 1], [8, 9]] + // 2nd MMA: [[2, 3], [10, 11]] + // 3rd MMA: [[4, 5], [12, 13]] + // 4th MMA: [[6, 7], [14, 15]] + for (size_t kRep = 0; kRep < kWidth / numElemsPerVec; ++kRep) + for (size_t tile = 0; tile < 2; ++tile) + for (size_t e = 0; e < numElemsPerVec; ++e) { + si.push_back(kRep * numElemsPerVec + tile * kWidth + e); + } } auto step = si.size(); @@ -119,30 +136,25 @@ ValueTableV2 getValuesFromDotOperandLayoutStruct( } std::copy(perm.begin(), perm.end(), elems.begin() + i * step); } - - if (dot.getOpIdx() == 1) { - int elemsInTile = dot.getKWidth(); - // n0 is unrolled in the legacy path, which makes no sense - n0 *= 2; - for (auto b = 0; b < batch; ++b) - for (auto i = 0; i < n0; ++i) - for (auto j = 0; j < n1; ++j) { - vals[{b, i, 2 * j}] = elems[offset++]; - vals[{b, i, 2 * j + 1}] = elems[offset++]; - } - return vals; - } } - for (auto b = 0; b < batch; ++b) - for (auto i = 0; i < n0; ++i) { - for (auto j = 0; j < n1; j++) { - vals[{b, 2 * i, 2 * j}] = elems[offset++]; - vals[{b, 2 * i + 1, 2 * j}] = elems[offset++]; - vals[{b, 2 * i, 2 * j + 1}] = elems[offset++]; - vals[{b, 2 * i + 1, 2 * j + 1}] = elems[offset++]; - } - } + if (dot.getOpIdx() == 0) { + for (auto b = 0; b < batch; ++b) + for (auto m = 0; m < repOuter; ++m) + for (auto k = 0; k < repK; ++k) { + packVec({b, 2 * m, 2 * k}); + packVec({b, 2 * m + 1, 2 * k}); + packVec({b, 2 * m, 2 * k + 1}); + packVec({b, 2 * m + 1, 2 * k + 1}); + } + } else { + for (auto b = 0; b < batch; ++b) + for (auto n = 0; n < repOuter; ++n) + for (auto k = 0; k < repK; ++k) { + packVec({b, n, 2 * k}); + packVec({b, n, 2 * k + 1}); + } + } return vals; } @@ -389,15 +401,11 @@ LogicalResult convertDot(const LLVMTypeConverter *typeConverter, int bitwidth = aTensorTy.getElementType().getIntOrFloatBitWidth(); auto dotOpA = cast(aTensorTy.getEncoding()); - auto repA = - cast(dotOpA.getParent()) - .getMMAv2OrV3RepForOperand(aShapePerCTA, bitwidth, dotOpA.getKWidth(), - dotOpA.getOpIdx()); + auto repA = cast(dotOpA.getParent()) + .getRepForOperand(aShapePerCTA, bitwidth, dotOpA.getOpIdx()); auto dotOpB = cast(bTensorTy.getEncoding()); - auto repB = - cast(dotOpB.getParent()) - .getMMAv2OrV3RepForOperand(bShapePerCTA, bitwidth, dotOpB.getKWidth(), - dotOpB.getOpIdx()); + auto repB = cast(dotOpB.getParent()) + .getRepForOperand(bShapePerCTA, bitwidth, dotOpB.getOpIdx()); assert(repA[2] == repB[1]); assert(repA[0] == repB[0]); @@ -407,13 +415,8 @@ LogicalResult 
convertDot(const LLVMTypeConverter *typeConverter, auto ha = getValuesFromDotOperandLayoutStruct( typeConverter, loc, rewriter, loadedA, repBatch, repM, repK, aTensorTy); - // FIXME [Dot LL] - // max(repN / 2, 1) is wrong for repN = 1! - // This is also wrong in - // NvidiaMmaEncodingAttr::getTotalElemsPerThreadForOperand auto hb = getValuesFromDotOperandLayoutStruct( - typeConverter, loc, rewriter, loadedB, repBatch, std::max(repN / 2, 1), - repK, bTensorTy); + typeConverter, loc, rewriter, loadedB, repBatch, repN, repK, bTensorTy); auto fc = unpackLLElements(loc, loadedC, rewriter); auto numMmaRets = dTensorTy.getElementType().getIntOrFloatBitWidth() / 8; int numCPackedElem = 4 / numMmaRets; diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp index afe7f98be8..a439b89270 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -39,11 +39,22 @@ Value redundantDataMask(Type valueTy, ConversionPatternRewriter &rewriter, auto threadsPerWarp = triton::gpu::getThreadsPerWarp(layout); auto warpsPerCTA = triton::gpu::getWarpsPerCTA(layout); auto threadOrder = triton::gpu::getThreadOrder(layout); - auto warpOrder = triton::gpu::getWarpOrder(layout); + SmallVector warpOrder(rank); + if (auto enc = dyn_cast(layout)) { + warpOrder = + triton::gpu::getMatrixOrder(rank, /*rowMajor=*/enc.getOpIdx() == 1); + } else { + warpOrder = triton::gpu::getWarpOrder(layout); + } auto shapePerCTATile = triton::gpu::getShapePerCTATile(layout, shape); Value warpSize = i32_val(32); Value laneId = urem(tid, warpSize); Value warpId = udiv(tid, warpSize); + // TODO: [DOT LL] + // The delinearize function is not entirely correct for certain layouts, + // such as wgmma. The correct approach is to convert a legacy layout to its + // corresponding linear layout and use the linear layout's + // getFreeVariableMasks to identify redundant elements. 
SmallVector multiDimWarpId = delinearize(rewriter, loc, warpId, warpsPerCTA, warpOrder); SmallVector multiDimThreadId = diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/UpcastMXFPToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/UpcastMXFPToLLVM.cpp index 04518f1736..6cba3f45da 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/UpcastMXFPToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/UpcastMXFPToLLVM.cpp @@ -30,35 +30,6 @@ class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern { : ConvertOpToLLVMPattern(typeConverter, benefit), targetInfo(targetInfo) {} - llvm::SmallVector unpackFP4Elements(Location loc, - RewriterBase &rewriter, - ArrayRef vals) const { - - auto fp4x8ToBf16x2 = [&loc, &rewriter](Value v) { - llvm::SmallVector results(4); - for (int i = 0; i < 4; ++i) { - auto v_i = trunc(i8_ty, lshr(v, i32_val(8 * i))); - auto [e0, e1] = LLVM::convertMxfp4x2ToBf16x2(rewriter, loc, v_i); - // Swap as they come packed in big endian - results[i] = or_(zext(i32_ty, e0), shl(zext(i32_ty, e1), i32_val(16))); - } - return results; - }; - - // Split fp4x8 into 4 bf16x2 - llvm::SmallVector ret; - ret.reserve(vals.size() * 4); - for (int i = 0; i < vals.size(); ++i) { - auto vs = fp4x8ToBf16x2(vals[i]); - assert(vs.size() == 4); - for (auto v : vs) { - ret.push_back(v); - } - } - - return ret; - } - LogicalResult matchAndRewrite(UpcastMXFPOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { @@ -78,27 +49,8 @@ class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern { Value warpId = udiv(tid, warpSize); Value laneId = urem(tid, warpSize); - if (fpType == ScaleDotElemType::E2M1) { - xVals = unpackFP4Elements(loc, rewriter, xVals); - } - - auto scaleBf16x2 = [&loc, &rewriter](Value v, Value s) -> Value { - // Split bf16x2 into 2 bf16, scale each of them, and pack them back - // TODO Is it true that the bfloats are always packed as bf16x2? - auto bf16_0 = bitcast(trunc(i16_ty, v), bf16_ty); - auto bf16_1 = bitcast(trunc(i16_ty, lshr(v, i32_val(16))), bf16_ty); - auto scaleIsNan = icmp_eq(s, i8_val(0xff)); - auto scaleBf16 = bitcast(shl(zext(i16_ty, s), i16_val(7)), bf16_ty); - auto scaledBf16_0 = fmul(bf16_0, scaleBf16); - auto scaledBf16_1 = fmul(bf16_1, scaleBf16); - auto i16_0 = bitcast(scaledBf16_0, i16_ty); - auto i16_1 = bitcast(scaledBf16_1, i16_ty); - auto packed = - or_(zext(i32_ty, i16_0), shl(zext(i32_ty, i16_1), i32_val(16))); - // Account for NaN in the scale as per the mxfp specification - auto packed_nan = select(scaleIsNan, i32_val(0x7fff7fff), packed); - return packed_nan; - }; + if (fpType == ScaleDotElemType::E2M1) + xVals = LLVM::convertMxfp4x2ToBf16x2(rewriter, loc, xVals); // Each thread owns elements of 4 mxfp vectors so we need 4 scales // Letting c = tid / 4 * 2, we need the elements from threads c, c + 1, c + @@ -116,8 +68,9 @@ class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern { targetInfo.shuffleIdx(rewriter, loc, scaleVal, ci[3]), }; - for (int j = 0; j < 16; ++j) { - xVals[16 * i + j] = scaleBf16x2(xVals[16 * i + j], si[j / 4]); + for (int j = 0; j < 32; ++j) { + xVals[32 * i + j] = + LLVM::mxfpScaleBf16(rewriter, loc, xVals[32 * i + j], si[j / 8]); } }
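For reference, the bit manipulation in convertMxfp4x2ToBf16x2 above can be modeled by a short standalone Python sketch; it is not part of the patch, and the name mxfp4x2_to_bf16x2 is illustrative only. It walks the same three cases as the C++ code (rebias normal values, map the lone subnormal pattern 0bs001 to +-0.5, leave zero untouched) and returns the bf16 bit patterns for the high and low nibble, in that order.

def mxfp4x2_to_bf16x2(byte: int) -> list:
    # Model of the e2m1 (mxfp4) x2 -> bf16 x2 conversion: one packed byte in,
    # two bf16 bit patterns out (high nibble first, matching the C++ helper).
    out = []
    for sign_mask, em_mask, em_shift, sign_shift, subnormal, normal_mask in (
        (0x80, 0x70, 2, 8, 0x10, 0x60),   # high 4 bits
        (0x08, 0x07, 6, 12, 0x01, 0x06),  # low 4 bits
    ):
        em = byte & em_mask
        v = ((em << em_shift) | ((byte & sign_mask) << sign_shift)) & 0xFFFF
        if em & normal_mask:  # 1) normal and non-zero: correct the exponent bias
            v = (v + ((127 - 1) << 7)) & 0xFFFF
        if em == subnormal:   # 2) subnormal (0bs001): map to +/-0.5 (0x3F00 in bf16)
            v = 0x3F00 | (v & 0x8000)
        out.append(v)         # 3) zero: nothing to do
    return out

# 0x70 packs +6.0 in the high nibble and +0.0 in the low nibble:
assert mxfp4x2_to_bf16x2(0x70) == [0x40C0, 0x0000]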
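The register reordering that getValuesFromDotOperandLayoutStruct performs for the large-K case can likewise be checked with a standalone Python sketch; it is not part of the patch, and the name build_large_k_permutation plus the kWidth = 8 / bf16 defaults are illustrative assumptions. It reproduces the per-operand index lists quoted in the MMAv2.cpp comments.

def build_large_k_permutation(k_width: int, bitwidth: int, op_idx: int) -> list:
    # Split one large-K dot-operand tile into sub-MMAs whose K slice is 32 bits
    # per thread; mirrors the si index computation in MMAv2.cpp.
    num_elems_per_vec = 32 // bitwidth   # elements packed into one i32
    num_tiles = 4 if op_idx == 0 else 2  # A uses 4 i32 regs per MMA, B uses 2
    si = []
    for k_rep in range(k_width // num_elems_per_vec):
        for tile in range(num_tiles):
            for e in range(num_elems_per_vec):
                si.append(k_rep * num_elems_per_vec + tile * k_width + e)
    return si

# Operand A (opIdx = 0): starts [0, 1, 8, 9, 16, 17, 24, 25, 2, 3, 10, 11, ...]
print(build_large_k_permutation(k_width=8, bitwidth=16, op_idx=0))
# Operand B (opIdx = 1): [0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15]
print(build_large_k_permutation(k_width=8, bitwidth=16, op_idx=1))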