Skip to content

Commit

Permalink
Merge commit '1cf7b1b31cde8c62611e421becd4648c7284d76c'
Browse files Browse the repository at this point in the history
  • Loading branch information
whitneywhtsang committed Nov 14, 2024
2 parents e30e00f + 1cf7b1b commit 876ce90
Show file tree
Hide file tree
Showing 20 changed files with 289 additions and 544 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,6 @@ namespace mlir::triton {

namespace gpu {

SmallVector<Value> reorderValues(const SmallVector<Value> &values, Type inType,
Type ouType);

Type getElementType(Value value);

class MultipleOperandsRange
Expand Down Expand Up @@ -179,8 +176,6 @@ class ElementwiseOpConversionBase : public ConvertOpToLLVMPattern<SourceOp> {
for (auto operand : adaptor.getOperands()) {
auto argTy = op->getOperand(0).getType();
auto subOperands = unpackLLElements(loc, operand, rewriter);
subOperands = unpackI32s(subOperands, argTy, rewriter, loc,
this->getTypeConverter());
allOperands.resize(subOperands.size());
for (auto v : llvm::enumerate(subOperands))
allOperands[v.index()].push_back(v.value());
Expand All @@ -201,13 +196,7 @@ class ElementwiseOpConversionBase : public ConvertOpToLLVMPattern<SourceOp> {
}
it += curr.size();
}
if (op->getNumOperands() > 0) {
auto argTy = op->getOperand(0).getType();
resultVals = reorderValues(resultVals, argTy, resultTy);
}
resultVals = maybeDeduplicate(op, resultVals);
resultVals =
packI32s(resultVals, resultTy, rewriter, loc, this->getTypeConverter());
Value view = packLLElements(loc, this->getTypeConverter(), resultVals,
rewriter, resultTy);
rewriter.replaceOp(op, view);
Expand Down
73 changes: 8 additions & 65 deletions include/triton/Conversion/TritonGPUToLLVM/Utility.h
Original file line number Diff line number Diff line change
Expand Up @@ -396,10 +396,14 @@ inline Value getSharedMemoryBase(Location loc, RewriterBase &rewriter,
// MXFP utilities
// -----------------------------------------------------------------------

// Convert one int8, which contains 2 packed mxfp4 values, into 2 standalone
// bf16 values, returned as a pair of (high 4 bits, low 4 bits).
std::pair<Value, Value> convertMxfp4x2ToBf16x2(RewriterBase &rewriter,
Location loc, Value v);
// Convert each value, which is an int8 containing 2 packed mxfp4 values,
// into 2 standalone bf16 values
SmallVector<Value> convertMxfp4x2ToBf16x2(RewriterBase &rewriter, Location loc,
ArrayRef<Value> values);

// Scale a mxfp4 value by a given scale.
Value mxfpScaleBf16(RewriterBase &rewriter, Location loc, Value v, Value scale);

} // namespace LLVM

/* ------------------------------------ */
Expand Down Expand Up @@ -1397,67 +1401,6 @@ inline Value getStructFromSharedMemoryObject(Location loc,
return llvmStruct;
}

// For some reason, LLVM's NVPTX backend inserts unnecessary (?) integer
// instructions to pack & unpack sub-word integers. A workaround is to
// store the results of tensors with dot operand encodings in i32 to
// facilitate instructions such as `ldmatrix`.
//
// TODO: Confirm if the problem is still there.
inline bool requiresI32Conversion(Type type) {
  // i32 packing applies only to dot-operand-encoded ranked tensors whose
  // parent is an NVIDIA MMA encoding with version major < 3 (pre-Hopper).
  if (auto tensorTy = dyn_cast<RankedTensorType>(type))
    if (auto dotOpEnc =
            dyn_cast<DotOperandEncodingAttr>(tensorTy.getEncoding()))
      if (auto parent = dyn_cast<NvidiaMmaEncodingAttr>(dotOpEnc.getParent()))
        return parent.getVersionMajor() < 3;
  return false;
}

// Pack groups of sub-word elements into i32 words for tensors that use the
// i32 dot-operand storage convention (see requiresI32Conversion); all other
// types pass through unchanged.
// NOTE(review): assumes inValues.size() is a multiple of the number of
// elements per 32-bit word — TODO confirm callers guarantee this.
inline SmallVector<Value> packI32s(const SmallVector<Value> &inValues,
                                   Type type, RewriterBase &rewriter,
                                   Location loc,
                                   const LLVMTypeConverter *typeConverter) {
  if (!requiresI32Conversion(type))
    return inValues;
  Type elemTy =
      typeConverter->convertType(cast<RankedTensorType>(type).getElementType());

  int lanesPerWord = 32 / elemTy.getIntOrFloatBitWidth();
  auto packedVecTy = vec_ty(elemTy, lanesPerWord);
  SmallVector<Value> packed;
  for (size_t base = 0; base < inValues.size(); base += lanesPerWord) {
    Value word = undef(packedVecTy);
    for (int lane = 0; lane < lanesPerWord; ++lane)
      word = insert_element(word, inValues[base + lane], i32_val(lane));
    packed.push_back(bitcast(word, i32_ty));
  }
  return packed;
}

// Inverse of packI32s: splits each packed i32 word back into its sub-word
// elements for tensors that use the i32 dot-operand storage convention
// (see requiresI32Conversion); all other types pass through unchanged.
inline SmallVector<Value> unpackI32s(const SmallVector<Value> &inValues,
                                     Type type, RewriterBase &rewriter,
                                     Location loc,
                                     const LLVMTypeConverter *typeConverter) {
  if (!requiresI32Conversion(type))
    return inValues;
  Type eltTy =
      typeConverter->convertType(cast<RankedTensorType>(type).getElementType());

  // The per-word element count and vector type are loop-invariant; the
  // original recomputed both on every iteration.
  int vecWidth = 32 / eltTy.getIntOrFloatBitWidth();
  auto vecTy = vec_ty(eltTy, vecWidth);
  SmallVector<Value> outValues;
  outValues.reserve(inValues.size() * vecWidth);
  for (auto v : inValues) {
    auto vec = bitcast(v, vecTy);
    for (int i = 0; i < vecWidth; i++)
      outValues.push_back(extract_element(vec, i32_val(i)));
  }
  return outValues;
}

inline SmallVector<Value> unpackLLElements(Location loc, Value llvmStruct,
RewriterBase &rewriter) {
assert(bool(llvmStruct) && "can not unpack null values");
Expand Down
6 changes: 2 additions & 4 deletions include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
Original file line number Diff line number Diff line change
Expand Up @@ -1212,8 +1212,6 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is:
bool isAmpere() const;
bool isHopper() const;

unsigned getElemsPerThreadOfOperand(int opIdx, ArrayRef<int64_t> shape) const;

// Get [isARow, isBRow, isAVec4, isBVec4, id] from versionMinor
std::tuple<bool, bool, bool, bool, int> decodeVoltaLayoutStates() const;

Expand All @@ -1230,8 +1228,8 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is:
SmallVector<int> getMMAv1Rep(int opIdx) const;
SmallVector<int> getMMAv1ShapePerWarp(int opIdx) const;
int getMMAv1Vec(int opIdx) const;
SmallVector<int64_t> getMMAv2OrV3RepForOperand(ArrayRef<int64_t> shape,
int bitwidth, int kWidth, int opIdx) const;
SmallVector<int64_t> getRepForOperand(ArrayRef<int64_t> shape,
int bitwidth, int opIdx) const;

bool supportReduction() const {
if (isAmpere() || isHopper()) {
Expand Down
11 changes: 3 additions & 8 deletions lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,6 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
: idx;
outVals[i] = inVals[srcIdx];
}
outVals = packI32s(outVals, dstTy, rewriter, loc, getTypeConverter());
Value result = packLLElements(loc, getTypeConverter(), outVals, rewriter,
op.getType());
rewriter.replaceOp(op, result);
Expand Down Expand Up @@ -392,11 +391,9 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
if (useLegacyMMAConversion) {
return false;
}
// FIXME [Dot LL]
// Enabling LL path for buggy kWidth path
bool largeKWidth =
dotOperand.getKWidth() * dstTy.getElementTypeBitWidth() > 64;
return largeKWidth && nvidiaMma.isAmpere();
if (nvidiaMma.isAmpere()) {
return true;
}
}
return false;
}
Expand Down Expand Up @@ -440,7 +437,6 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
inVals[it.index()] = ptrtoint(llvmElemTy, it.value());
}
}
inVals = unpackI32s(inVals, srcTy, rewriter, loc, getTypeConverter());

// Pretty sure this is the identity function ATM
// It'd be better to simply call `quotient({kBlock})` and
Expand All @@ -460,7 +456,6 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
}
}

outVals = packI32s(outVals, dstTy, rewriter, loc, getTypeConverter());
Value result = packLLElements(loc, getTypeConverter(), outVals, rewriter,
op.getType());
rewriter.replaceOp(op, result);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,16 +90,10 @@ void decomposeBlockedToDotLayoutConversion(ModuleOp module) {
auto dstDotOp =
dyn_cast<triton::gpu::DotOperandEncodingAttr>(dstType.getEncoding());
if (srcBlocked && dstDotOp) {
// FIXME [Dot LL]
// We support this one via LLs, as the LocalLoad path is buggy
if (auto mma = dyn_cast<NvidiaMmaEncodingAttr>(dstDotOp.getParent())) {
bool largeKWidth =
dstDotOp.getKWidth() * dstType.getElementTypeBitWidth() > 64;
if (mma.isAmpere() && largeKWidth) {
return;
}
auto dotParent = dyn_cast<NvidiaMmaEncodingAttr>(dstDotOp.getParent());
if (dotParent && dotParent.isAmpere()) {
return;
}

Attribute sharedMemorySpace =
triton::gpu::SharedMemorySpaceAttr::get(srcType.getContext());
auto tmpType = MemDescType::get(
Expand Down
140 changes: 7 additions & 133 deletions lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,138 +11,23 @@ using namespace mlir::triton::gpu;

namespace mlir::triton::gpu {

namespace {

// A value is "packed" when its type is a ranked tensor with a dot-operand
// encoding whose parent is a non-Hopper NVIDIA MMA encoding; by code
// convention, values for Hopper's dotOp-encoded tensors are not packed.
bool isDotOpTensorAndPacked(Type srcTy) {
  if (auto tensorTy = dyn_cast<RankedTensorType>(srcTy))
    if (auto dotEnc = dyn_cast<DotOperandEncodingAttr>(tensorTy.getEncoding()))
      if (auto mmaEnc = dyn_cast<NvidiaMmaEncodingAttr>(dotEnc.getParent()))
        return !mmaEnc.isHopper();
  return false;
}

} // namespace

// Returns the element type for ranked-tensor-typed values; for any other
// value the type itself is returned unchanged.
Type getElementType(Value value) {
  Type ty = value.getType();
  auto rankedTy = dyn_cast<RankedTensorType>(ty);
  return rankedTy ? rankedTy.getElementType() : ty;
}
// MMA encoding has a different order depending on the element's bit width;
// reorder if we're in this case.
SmallVector<Value> reorderValues(const SmallVector<Value> &values, Type inType,
                                 Type ouType) {
  auto inTensorTy = dyn_cast<RankedTensorType>(inType);
  auto ouTensorTy = dyn_cast<RankedTensorType>(ouType);
  if (!inTensorTy || !ouTensorTy)
    return values;
  auto inEncoding = dyn_cast<DotOperandEncodingAttr>(inTensorTy.getEncoding());
  auto ouEncoding = dyn_cast<DotOperandEncodingAttr>(ouTensorTy.getEncoding());
  assert(inEncoding == ouEncoding);
  if (!inEncoding)
    return values;
  // If the parent of the dot operand is in block encoding, we don't need to
  // reorder elements; by code convention Hopper values are not packed either.
  auto parentEncoding = dyn_cast<NvidiaMmaEncodingAttr>(ouEncoding.getParent());
  if (!parentEncoding || parentEncoding.isHopper())
    return values;
  size_t inBitWidth = inTensorTy.getElementType().getIntOrFloatBitWidth();
  size_t ouBitWidth = ouTensorTy.getElementType().getIntOrFloatBitWidth();
  if (inBitWidth == ouBitWidth)
    return values;

  // Apply `perm` to each consecutive group of perm.size() values.
  auto permuteGroups = [&](std::initializer_list<int> perm) {
    SmallVector<Value> ret;
    ret.reserve(values.size());
    for (unsigned i = 0; i < values.size(); i += perm.size())
      for (int p : perm)
        ret.push_back(values[i + p]);
    return ret;
  };

  if (inBitWidth == 16 && ouBitWidth == 32) {
    // Register layout conversion:
    //
    //   [0, 1], [4, 5] ⟶ [0], [1], [4], [5]
    //   [2, 3], [6, 7]   [2], [3], [6], [7]
    //
    // Original access order:  [0, 1], [2, 3], [4, 5], [6, 7]
    // Transformed access order: [0], [2], [1], [3], [4], [6], [5], [7]
    return permuteGroups({0, 2, 1, 3, 4, 6, 5, 7});
  }
  if (inBitWidth == 8 && ouBitWidth == 16) {
    // Register layout conversion:
    //
    //   [0, 1, 2, 3], [8, 9, 10, 11]   ⟶ [0, 1], [2, 3], [8, 9], [10, 11]
    //   [4, 5, 6, 7], [12, 13, 14, 15]   [4, 5], [6, 7], [12, 13], [14, 15]
    //
    // Original access order:
    //   [0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]
    // Transformed access order:
    //   [0, 1], [4, 5], [2, 3], [6, 7], [8, 9], [12, 13], [10, 11], [14, 15]
    return permuteGroups(
        {0, 1, 4, 5, 2, 3, 6, 7, 8, 9, 12, 13, 10, 11, 14, 15});
  }
  llvm_unreachable("unimplemented code path");
}

int getNumElementsPerThreads(Type type,
const LLVMTypeConverter *typeConverter) {
int numElemsPerThread = 1;
auto tensorTy = dyn_cast<RankedTensorType>(type);
if (!tensorTy)
return numElemsPerThread;
auto structType =
dyn_cast<LLVM::LLVMStructType>(typeConverter->convertType(type));
if (structType) {
numElemsPerThread = structType.getBody().size();
if (auto tensorTy = dyn_cast<RankedTensorType>(type)) {
auto structType =
dyn_cast<LLVM::LLVMStructType>(typeConverter->convertType(type));
if (structType)
numElemsPerThread = structType.getBody().size();
}
auto encoding = dyn_cast<DotOperandEncodingAttr>(tensorTy.getEncoding());
if (!(encoding && isa<NvidiaMmaEncodingAttr>(encoding.getParent())))
return numElemsPerThread;
auto eltType = tensorTy.getElementType();
assert(eltType.getIntOrFloatBitWidth() <= 32 &&
"Only support element type with bit width <= 32 in dot operand mma "
"layout");
// dot operand data are packed into i32 elements so use the following formula
// to get the number of elements per thread.
return (32 / eltType.getIntOrFloatBitWidth()) * numElemsPerThread;
return numElemsPerThread;
}

} // namespace mlir::triton::gpu
Expand Down Expand Up @@ -473,8 +358,7 @@ struct ElementwiseInlineAsmOpConversion
for (auto operand : adaptor.getOperands()) {
auto argTy = op->getOperand(0).getType();
auto subOperands = unpackLLElements(loc, operand, rewriter);
unpackedOperands.push_back(
unpackI32s(subOperands, argTy, rewriter, loc, getTypeConverter()));
unpackedOperands.push_back(subOperands);
}

int numElemsPerThread = getNumElementsPerThreads(op->getResult(0).getType(),
Expand Down Expand Up @@ -527,16 +411,6 @@ struct ElementwiseInlineAsmOpConversion
// Reorder and pack the results.
SmallVector<Value> outs;
for (int i = 0; i < unpackedResults.size(); i++) {
// We reordered all the inputs so they match operand 0. Reorder the
// outputs accordingly.
if (op->getNumOperands() > 0) {
unpackedResults[i] = reorderValues(
unpackedResults[i], /*inType=*/op->getOperand(0).getType(),
/*ouType=*/op->getResult(i).getType());
}
auto dstTy = op->getResult(i).getType();
unpackedResults[i] = packI32s(unpackedResults[i], dstTy, rewriter, loc,
getTypeConverter());
outs.push_back(packLLElements(loc, getTypeConverter(), unpackedResults[i],
rewriter, op->getResult(i).getType()));
}
Expand Down
2 changes: 0 additions & 2 deletions lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,6 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
auto dstLayout = dstTy.getEncoding();
assert((dstShape.size() <= 2 || isSupportedDotOpLayout(dstLayout)) &&
"Unexpected rank of ConvertLayout(shared->distributed)");
auto inOrd = getOrder(srcSharedLayout);

auto smemObj = LLVM::getSharedMemoryObjectFromStruct(
loc, adaptor.getSrc(),
Expand All @@ -183,7 +182,6 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
SmallVector<Value> outVals = loadSharedToDistributed(
dstTy, srcTy, elemLlvmTy, smemObj, loc, rewriter, targetInfo);

outVals = packI32s(outVals, dstTy, rewriter, loc, typeConverter);
Value result = packLLElements(loc, typeConverter, outVals, rewriter, dstTy);
rewriter.replaceOp(op, result);

Expand Down
Loading

0 comments on commit 876ce90

Please sign in to comment.