Merge OpenAI Triton commit 1cf7b1b #2707

Merged · 2 commits · Nov 14, 2024
@@ -15,9 +15,6 @@ namespace mlir::triton {
 
 namespace gpu {
 
-SmallVector<Value> reorderValues(const SmallVector<Value> &values, Type inType,
-                                 Type ouType);
-
 Type getElementType(Value value);
 
 class MultipleOperandsRange
@@ -179,8 +176,6 @@ class ElementwiseOpConversionBase : public ConvertOpToLLVMPattern<SourceOp> {
     for (auto operand : adaptor.getOperands()) {
       auto argTy = op->getOperand(0).getType();
       auto subOperands = unpackLLElements(loc, operand, rewriter);
-      subOperands = unpackI32s(subOperands, argTy, rewriter, loc,
-                               this->getTypeConverter());
       allOperands.resize(subOperands.size());
       for (auto v : llvm::enumerate(subOperands))
         allOperands[v.index()].push_back(v.value());
@@ -201,13 +196,7 @@ class ElementwiseOpConversionBase : public ConvertOpToLLVMPattern<SourceOp> {
       }
       it += curr.size();
     }
-    if (op->getNumOperands() > 0) {
-      auto argTy = op->getOperand(0).getType();
-      resultVals = reorderValues(resultVals, argTy, resultTy);
-    }
     resultVals = maybeDeduplicate(op, resultVals);
-    resultVals =
-        packI32s(resultVals, resultTy, rewriter, loc, this->getTypeConverter());
     Value view = packLLElements(loc, this->getTypeConverter(), resultVals,
                                 rewriter, resultTy);
     rewriter.replaceOp(op, view);
73 changes: 8 additions & 65 deletions include/triton/Conversion/TritonGPUToLLVM/Utility.h
@@ -396,10 +396,14 @@ inline Value getSharedMemoryBase(Location loc, RewriterBase &rewriter,
 // MXFP utilities
 // -----------------------------------------------------------------------
 
-// Convert one int8, which contain, 2 packed mxfp4 values, into 2 bf16
-// standalone values and returns them as a pair for (high 4 bits, low 4 bits).
-std::pair<Value, Value> convertMxfp4x2ToBf16x2(RewriterBase &rewriter,
-                                               Location loc, Value v);
+// Convert each value, which is an int8 containing 2 packed mxfp4 values,
+// into 2 standalone bf16 values
+SmallVector<Value> convertMxfp4x2ToBf16x2(RewriterBase &rewriter, Location loc,
+                                          ArrayRef<Value> values);
+
+// Scale a mxfp4 value by a given scale.
+Value mxfpScaleBf16(RewriterBase &rewriter, Location loc, Value v, Value scale);
 
 } // namespace LLVM
 
 /* ------------------------------------ */
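
For context, mxfp4 here is the 4-bit OCP MX float format E2M1 (1 sign bit, 2 exponent bits with bias 1, 1 mantissa bit), so each int8 carries two values. The following is a minimal standalone sketch of the numeric mapping the new batched helper performs, not the MLIR implementation; the low-nibble-first ordering is an assumption, since this diff does not show the nibble order:

```cpp
// Sketch: decode one mxfp4 (E2M1) nibble to a float that is exactly
// representable in bf16. Representable magnitudes: 0, 0.5, 1, 1.5, 2, 3, 4, 6.
#include <cstdint>
#include <utility>

float decodeE2M1(uint8_t nibble) {
  int sign = (nibble >> 3) & 1; // 1 sign bit
  int exp = (nibble >> 1) & 3;  // 2 exponent bits, bias 1
  int man = nibble & 1;         // 1 mantissa bit
  // exp == 0 encodes the subnormals {0, 0.5}; otherwise (1 + man/2) * 2^(exp-1).
  float mag =
      exp == 0 ? 0.5f * man : (1.0f + 0.5f * man) * float(1 << (exp - 1));
  return sign ? -mag : mag;
}

// Hypothetical host-side analogue of convertMxfp4x2ToBf16x2 for one packed
// byte; assumes the low nibble holds the first element.
std::pair<float, float> decodeMxfp4x2(uint8_t packed) {
  return {decodeE2M1(packed & 0xF), decodeE2M1(packed >> 4)};
}
```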
@@ -1397,67 +1401,6 @@ inline Value getStructFromSharedMemoryObject(Location loc,
   return llvmStruct;
 }
 
-// For some reasons, LLVM's NVPTX backend inserts unnecessary (?) integer
-// instructions to pack & unpack sub-word integers. A workaround is to
-// store the results of tensors with dot operand encodings in i32 to
-// facilitate instructions such as `ldmatrix`.
-//
-// TODO: Confirm if the problem is still there.
-inline bool requiresI32Conversion(Type type) {
-  auto tensorTy = dyn_cast<RankedTensorType>(type);
-  if (!tensorTy)
-    return false;
-  auto dotOpEnc = dyn_cast<DotOperandEncodingAttr>(tensorTy.getEncoding());
-  if (!dotOpEnc)
-    return false;
-  auto parent = dyn_cast<NvidiaMmaEncodingAttr>(dotOpEnc.getParent());
-  if (!(parent && parent.getVersionMajor() < 3))
-    return false;
-  return true;
-}
-
-inline SmallVector<Value> packI32s(const SmallVector<Value> &inValues,
-                                   Type type, RewriterBase &rewriter,
-                                   Location loc,
-                                   const LLVMTypeConverter *typeConverter) {
-  if (!requiresI32Conversion(type))
-    return inValues;
-  Type eltTy =
-      typeConverter->convertType(cast<RankedTensorType>(type).getElementType());
-
-  SmallVector<Value> outValues;
-  int vecWidth = 32 / eltTy.getIntOrFloatBitWidth();
-  auto vecTy = vec_ty(eltTy, vecWidth);
-  for (int i = 0; i < inValues.size(); i += vecWidth) {
-    Value vec = undef(vecTy);
-    for (int j = 0; j < vecWidth; j++) {
-      vec = insert_element(vec, inValues[i + j], i32_val(j));
-    }
-    outValues.push_back(bitcast(vec, i32_ty));
-  }
-  return outValues;
-}
-
-inline SmallVector<Value> unpackI32s(const SmallVector<Value> &inValues,
-                                     Type type, RewriterBase &rewriter,
-                                     Location loc,
-                                     const LLVMTypeConverter *typeConverter) {
-  if (!requiresI32Conversion(type))
-    return inValues;
-  Type eltTy =
-      typeConverter->convertType(cast<RankedTensorType>(type).getElementType());
-
-  SmallVector<Value> outValues;
-  for (auto v : inValues) {
-    auto vecTy = vec_ty(eltTy, 32 / eltTy.getIntOrFloatBitWidth());
-    auto vec = bitcast(v, vecTy);
-    for (int i = 0; i < 32 / eltTy.getIntOrFloatBitWidth(); i++) {
-      outValues.push_back(extract_element(vec, i32_val(i)));
-    }
-  }
-  return outValues;
-}
-
 inline SmallVector<Value> unpackLLElements(Location loc, Value llvmStruct,
                                            RewriterBase &rewriter) {
   assert(bool(llvmStruct) && "can not unpack null values");
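
The deleted helpers above implemented a simple sub-word packing trick: 32/bitwidth lanes per i32, lane 0 in the element-0 position of the bitcast vector. A host-side analogue for 16-bit lanes, shown only to illustrate the layout (the function names are hypothetical; the real helpers emitted LLVM vector ops, and lane 0 landing in the low bits assumes little-endian lane order):

```cpp
// Analogue of the removed packI32s/unpackI32s for 16-bit elements:
// vecWidth = 32 / 16 = 2 lanes per packed word.
#include <cstdint>
#include <vector>

std::vector<uint32_t> packU16s(const std::vector<uint16_t> &in) {
  std::vector<uint32_t> out;
  for (size_t i = 0; i + 2 <= in.size(); i += 2)
    out.push_back(uint32_t(in[i]) | uint32_t(in[i + 1]) << 16);
  return out;
}

std::vector<uint16_t> unpackU16s(const std::vector<uint32_t> &in) {
  std::vector<uint16_t> out;
  for (uint32_t v : in) {
    out.push_back(uint16_t(v & 0xFFFF)); // lane 0: low half
    out.push_back(uint16_t(v >> 16));    // lane 1: high half
  }
  return out;
}
```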
6 changes: 2 additions & 4 deletions include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
@@ -1212,8 +1212,6 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is:
     bool isAmpere() const;
     bool isHopper() const;
 
-    unsigned getElemsPerThreadOfOperand(int opIdx, ArrayRef<int64_t> shape) const;
-
     // Get [isARow, isBRow, isAVec4, isBVec4, id] from versionMinor
     std::tuple<bool, bool, bool, bool, int> decodeVoltaLayoutStates() const;
@@ -1230,8 +1228,8 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is:
     SmallVector<int> getMMAv1Rep(int opIdx) const;
     SmallVector<int> getMMAv1ShapePerWarp(int opIdx) const;
     int getMMAv1Vec(int opIdx) const;
-    SmallVector<int64_t> getMMAv2OrV3RepForOperand(ArrayRef<int64_t> shape,
-                                                   int bitwidth, int kWidth, int opIdx) const;
+    SmallVector<int64_t> getRepForOperand(ArrayRef<int64_t> shape,
+                                          int bitwidth, int opIdx) const;
 
     bool supportReduction() const {
       if (isAmpere() || isHopper()) {
11 changes: 3 additions & 8 deletions lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
@@ -357,7 +357,6 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
                       : idx;
       outVals[i] = inVals[srcIdx];
     }
-    outVals = packI32s(outVals, dstTy, rewriter, loc, getTypeConverter());
     Value result = packLLElements(loc, getTypeConverter(), outVals, rewriter,
                                   op.getType());
     rewriter.replaceOp(op, result);
@@ -392,11 +391,9 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
       if (useLegacyMMAConversion) {
         return false;
       }
-      // FIXME [Dot LL]
-      // Enabling LL path for buggy kWidth path
-      bool largeKWidth =
-          dotOperand.getKWidth() * dstTy.getElementTypeBitWidth() > 64;
-      return largeKWidth && nvidiaMma.isAmpere();
+      if (nvidiaMma.isAmpere()) {
+        return true;
+      }
     }
     return false;
   }
@@ -440,7 +437,6 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
         inVals[it.index()] = ptrtoint(llvmElemTy, it.value());
       }
     }
-    inVals = unpackI32s(inVals, srcTy, rewriter, loc, getTypeConverter());
 
     // Pretty sure this is the identity function ATM
     // It'd be better to simply call `quotient({kBlock})` and
@@ -460,7 +456,6 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
       }
     }
 
-    outVals = packI32s(outVals, dstTy, rewriter, loc, getTypeConverter());
     Value result = packLLElements(loc, getTypeConverter(), outVals, rewriter,
                                   op.getType());
     rewriter.replaceOp(op, result);
@@ -90,16 +90,10 @@ void decomposeBlockedToDotLayoutConversion(ModuleOp module) {
     auto dstDotOp =
         dyn_cast<triton::gpu::DotOperandEncodingAttr>(dstType.getEncoding());
     if (srcBlocked && dstDotOp) {
-      // FIXME [Dot LL]
-      // We support this one via LLs, as the LocalLoad path is buggy
-      if (auto mma = dyn_cast<NvidiaMmaEncodingAttr>(dstDotOp.getParent())) {
-        bool largeKWidth =
-            dstDotOp.getKWidth() * dstType.getElementTypeBitWidth() > 64;
-        if (mma.isAmpere() && largeKWidth) {
-          return;
-        }
-      }
+      auto dotParent = dyn_cast<NvidiaMmaEncodingAttr>(dstDotOp.getParent());
+      if (dotParent && dotParent.isAmpere()) {
+        return;
+      }
 
       Attribute sharedMemorySpace =
           triton::gpu::SharedMemorySpaceAttr::get(srcType.getContext());
       auto tmpType = MemDescType::get(
140 changes: 7 additions & 133 deletions lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp
@@ -11,138 +11,23 @@ using namespace mlir::triton::gpu;
 
 namespace mlir::triton::gpu {
 
-namespace {
-
-bool isDotOpTensorAndPacked(Type srcTy) {
-  auto tensorTy = dyn_cast<RankedTensorType>(srcTy);
-  if (!tensorTy)
-    return false;
-  auto encoding = dyn_cast<DotOperandEncodingAttr>(tensorTy.getEncoding());
-  if (!encoding)
-    return false;
-  auto parentEnc = dyn_cast<NvidiaMmaEncodingAttr>(encoding.getParent());
-  // By code convention, values for Hopper's dotOp-encoded tensors are not
-  // packed
-  if (!parentEnc || parentEnc.isHopper())
-    return false;
-  return true;
-}
-
-} // namespace
-
 Type getElementType(Value value) {
   auto type = value.getType();
   if (auto tensorType = dyn_cast<RankedTensorType>(type))
     return tensorType.getElementType();
   return type;
 }
-// MMA encoding has a different order depending on the element's bit width;
-// reorder if we're in this case.
-SmallVector<Value> reorderValues(const SmallVector<Value> &values, Type inType,
-                                 Type ouType) {
-  auto inTensorTy = dyn_cast<RankedTensorType>(inType);
-  auto ouTensorTy = dyn_cast<RankedTensorType>(ouType);
-  if (!inTensorTy || !ouTensorTy)
-    return values;
-  auto inEncoding = dyn_cast<DotOperandEncodingAttr>(inTensorTy.getEncoding());
-  auto ouEncoding = dyn_cast<DotOperandEncodingAttr>(ouTensorTy.getEncoding());
-  assert(inEncoding == ouEncoding);
-  if (!inEncoding)
-    return values;
-  // If the parent of the dot operand is in block encoding, we don't need to
-  // reorder elements
-  auto parentEncoding = dyn_cast<NvidiaMmaEncodingAttr>(ouEncoding.getParent());
-  if (!parentEncoding || parentEncoding.isHopper())
-    return values;
-  size_t inBitWidth = inTensorTy.getElementType().getIntOrFloatBitWidth();
-  size_t ouBitWidth = ouTensorTy.getElementType().getIntOrFloatBitWidth();
-  auto ouEltTy = ouTensorTy.getElementType();
-  if (inBitWidth == ouBitWidth)
-    return values;
-  if (inBitWidth == 16 && ouBitWidth == 32) {
-    // Register layout conversion:
-    //
-    //   [0, 1], [4, 5] ⟶ [0], [1], [4], [5]
-    //   [2, 3], [6, 7]   [2], [3], [6], [7]
-    //
-    // Original access order:
-    //
-    //   [0, 1], [2, 3], [4, 5], [6, 7]
-    //
-    // Transformed access order:
-    //
-    //   [0], [2], [1], [3], [4], [6], [5], [7]
-    SmallVector<Value> ret;
-    for (unsigned i = 0; i < values.size(); i += 8) {
-      ret.push_back(values[i]);
-      ret.push_back(values[i + 2]);
-      ret.push_back(values[i + 1]);
-      ret.push_back(values[i + 3]);
-      ret.push_back(values[i + 4]);
-      ret.push_back(values[i + 6]);
-      ret.push_back(values[i + 5]);
-      ret.push_back(values[i + 7]);
-    }
-    return ret;
-  }
-  if (inBitWidth == 8 && ouBitWidth == 16) {
-    // Register layout conversion:
-    //
-    //   [0, 1, 2, 3], [8, 9, 10, 11]   ⟶ [0, 1], [2, 3], [8, 9], [10, 11]
-    //   [4, 5, 6, 7], [12, 13, 14, 15]   [4, 5], [6, 7], [12, 13], [14, 15]
-    //
-    // Original access order:
-    //
-    //   [0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]
-    //
-    // Transformed access order:
-    //
-    //   [0, 1], [4, 5], [2, 3], [6, 7], [8, 9], [12, 13], [10, 11], [14, 15]
-    SmallVector<Value> ret;
-    for (unsigned i = 0; i < values.size(); i += 16) {
-      ret.push_back(values[i]);
-      ret.push_back(values[i + 1]);
-      ret.push_back(values[i + 4]);
-      ret.push_back(values[i + 5]);
-      ret.push_back(values[i + 2]);
-      ret.push_back(values[i + 3]);
-      ret.push_back(values[i + 6]);
-      ret.push_back(values[i + 7]);
-      ret.push_back(values[i + 8]);
-      ret.push_back(values[i + 9]);
-      ret.push_back(values[i + 12]);
-      ret.push_back(values[i + 13]);
-      ret.push_back(values[i + 10]);
-      ret.push_back(values[i + 11]);
-      ret.push_back(values[i + 14]);
-      ret.push_back(values[i + 15]);
-    }
-    return ret;
-  }
-  llvm_unreachable("unimplemented code path");
-}
 
 int getNumElementsPerThreads(Type type,
                              const LLVMTypeConverter *typeConverter) {
   int numElemsPerThread = 1;
-  auto tensorTy = dyn_cast<RankedTensorType>(type);
-  if (!tensorTy)
-    return numElemsPerThread;
-  auto structType =
-      dyn_cast<LLVM::LLVMStructType>(typeConverter->convertType(type));
-  if (structType) {
-    numElemsPerThread = structType.getBody().size();
+  if (auto tensorTy = dyn_cast<RankedTensorType>(type)) {
+    auto structType =
+        dyn_cast<LLVM::LLVMStructType>(typeConverter->convertType(type));
+    if (structType)
+      numElemsPerThread = structType.getBody().size();
   }
-  auto encoding = dyn_cast<DotOperandEncodingAttr>(tensorTy.getEncoding());
-  if (!(encoding && isa<NvidiaMmaEncodingAttr>(encoding.getParent())))
-    return numElemsPerThread;
-  auto eltType = tensorTy.getElementType();
-  assert(eltType.getIntOrFloatBitWidth() <= 32 &&
-         "Only support element type with bit width <= 32 in dot operand mma "
-         "layout");
-  // dot operand data are packed into i32 elements so use the following formula
-  // to get the number of elements per thread.
-  return (32 / eltType.getIntOrFloatBitWidth()) * numElemsPerThread;
+  return numElemsPerThread;
 }
 
 } // namespace mlir::triton::gpu
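
The removed reorderValues applied fixed per-group register permutations: {0, 2, 1, 3, 4, 6, 5, 7} over groups of 8 for the 16-bit to 32-bit case, and {0, 1, 4, 5, 2, 3, 6, 7, 8, 9, 12, 13, 10, 11, 14, 15} over groups of 16 for the 8-bit to 16-bit case. A minimal sketch of that permutation, detached from MLIR (permuteGroups is a hypothetical name; the index tables are read straight off the push_back order in the removed code):

```cpp
// Apply a fixed permutation to each consecutive group of perm.size() values,
// mirroring what the deleted reorderValues did to per-thread registers.
#include <cassert>
#include <cstddef>
#include <vector>

template <typename T>
std::vector<T> permuteGroups(const std::vector<T> &vals,
                             const std::vector<int> &perm) {
  assert(vals.size() % perm.size() == 0);
  std::vector<T> out;
  out.reserve(vals.size());
  for (size_t i = 0; i < vals.size(); i += perm.size())
    for (int p : perm)
      out.push_back(vals[i + p]);
  return out;
}

// 16-bit -> 32-bit case: groups of 8.
const std::vector<int> perm16to32 = {0, 2, 1, 3, 4, 6, 5, 7};
// 8-bit -> 16-bit case: groups of 16.
const std::vector<int> perm8to16 = {0, 1, 4,  5,  2,  3,  6,  7,
                                    8, 9, 12, 13, 10, 11, 14, 15};
```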
@@ -473,8 +358,7 @@ struct ElementwiseInlineAsmOpConversion
     for (auto operand : adaptor.getOperands()) {
       auto argTy = op->getOperand(0).getType();
       auto subOperands = unpackLLElements(loc, operand, rewriter);
-      unpackedOperands.push_back(
-          unpackI32s(subOperands, argTy, rewriter, loc, getTypeConverter()));
+      unpackedOperands.push_back(subOperands);
     }
 
     int numElemsPerThread = getNumElementsPerThreads(op->getResult(0).getType(),
@@ -527,16 +411,6 @@ struct ElementwiseInlineAsmOpConversion
     // Reorder and pack the results.
     SmallVector<Value> outs;
     for (int i = 0; i < unpackedResults.size(); i++) {
-      // We reordered all the inputs so they match operand 0. Reorder the
-      // outputs accordingly.
-      if (op->getNumOperands() > 0) {
-        unpackedResults[i] = reorderValues(
-            unpackedResults[i], /*inType=*/op->getOperand(0).getType(),
-            /*ouType=*/op->getResult(i).getType());
-      }
-      auto dstTy = op->getResult(i).getType();
-      unpackedResults[i] = packI32s(unpackedResults[i], dstTy, rewriter, loc,
-                                    getTypeConverter());
       outs.push_back(packLLElements(loc, getTypeConverter(), unpackedResults[i],
                                     rewriter, op->getResult(i).getType()));
     }
2 changes: 0 additions & 2 deletions lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp
@@ -173,7 +173,6 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
     auto dstLayout = dstTy.getEncoding();
     assert((dstShape.size() <= 2 || isSupportedDotOpLayout(dstLayout)) &&
            "Unexpected rank of ConvertLayout(shared->distributed)");
-    auto inOrd = getOrder(srcSharedLayout);
 
     auto smemObj = LLVM::getSharedMemoryObjectFromStruct(
         loc, adaptor.getSrc(),
@@ -183,7 +182,6 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
     SmallVector<Value> outVals = loadSharedToDistributed(
         dstTy, srcTy, elemLlvmTy, smemObj, loc, rewriter, targetInfo);
 
-    outVals = packI32s(outVals, dstTy, rewriter, loc, typeConverter);
     Value result = packLLElements(loc, typeConverter, outVals, rewriter, dstTy);
     rewriter.replaceOp(op, result);
 