Skip to content

Commit

Permalink
Merge commit '19d14209ad667d89ae9b2dedfd0a82512354d0a3'
Browse files Browse the repository at this point in the history
  • Loading branch information
whitneywhtsang committed Apr 30, 2024
2 parents 0ddbd2f + 19d1420 commit 689ec3c
Show file tree
Hide file tree
Showing 50 changed files with 391 additions and 334 deletions.
3 changes: 1 addition & 2 deletions bin/RegisterTritonDialects.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#pragma once
#include "triton/Dialect/NVGPU/IR/Dialect.h"
#include "third_party/nvidia/include/Dialect/NVGPU/IR/Dialect.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGEN/IR/TritonGENDialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
Expand Down Expand Up @@ -27,7 +27,6 @@
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/InitAllPasses.h"
#include "triton/Tools/Sys/GetEnv.hpp"

namespace mlir {
namespace test {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
#include "TargetInfoBase.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "triton/Analysis/AxisInfo.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"

using namespace mlir;
using namespace mlir::triton;

Expand Down Expand Up @@ -33,6 +33,7 @@ void populateElementwiseOpToLLVMPatterns(
PatternBenefit benefit);

void populateMemoryOpToLLVMPattern(LLVMTypeConverter &typeConverter,
const TargetInfoBase &targetInfo,
RewritePatternSet &patterns,
PatternBenefit benefit);

Expand All @@ -42,6 +43,7 @@ void populateAssertOpToLLVMPattern(LLVMTypeConverter &typeConverter,
PatternBenefit benefit);

void populateMakeRangeOpToLLVMPattern(LLVMTypeConverter &typeConverter,
const TargetInfoBase &targetInfo,
RewritePatternSet &patterns,
PatternBenefit benefit);

Expand Down
2 changes: 2 additions & 0 deletions include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ class TargetInfoBase {
public:
virtual bool supportMaximumMinimum() const = 0;

virtual Value getClusterCTAId(RewriterBase &rewriter, Location loc) const = 0;

virtual Value ballot(ConversionPatternRewriter &rewriter, Location loc,
Type type, Value cmp) const = 0;

Expand Down
92 changes: 33 additions & 59 deletions include/triton/Conversion/TritonGPUToLLVM/Utility.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "triton/Analysis/Utility.h"
#include "triton/Conversion/MLIRTypes.h"
#include "triton/Dialect/NVGPU/IR/Dialect.h"
#include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h"
#include "triton/Dialect/Triton/IR/Utility.h"
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
#include "llvm/Support/ErrorHandling.h"
Expand Down Expand Up @@ -59,8 +59,6 @@ using namespace mlir::triton;
rewriter.create<LLVM::ExtractElementOp>(loc, __VA_ARGS__)
#define load(...) rewriter.create<LLVM::LoadOp>(loc, __VA_ARGS__)
#define store(...) rewriter.create<LLVM::StoreOp>(loc, __VA_ARGS__)
#define load_dsmem(...) LLVM::createLoadDSmem(loc, rewriter, __VA_ARGS__)
#define store_dsmem(...) LLVM::createStoreDSmem(loc, rewriter, __VA_ARGS__)
#define fcmp_ogt(lhs, rhs) \
rewriter.create<LLVM::FCmpOp>(loc, rewriter.getI1Type(), \
LLVM::FCmpPredicate::ogt, lhs, rhs)
Expand Down Expand Up @@ -222,29 +220,6 @@ Value createIndexConstant(OpBuilder &builder, Location loc,
Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, short width,
int64_t value);

/// Usage of macro load_dsmem
/// (1) load_dsmem(addr, ctaId)
/// (2) load_dsmem(addr, ctaId, vec)
Value createLoadDSmem(Location loc, PatternRewriter &rewriter, Value addr,
Value ctaId, Type elemTy);
SmallVector<Value> createLoadDSmem(Location loc, PatternRewriter &rewriter,
Value addr, Value ctaId, unsigned vec,
Type elemTy);

/// Usage of macro store_dsmem
/// (1) store_dsmem(addr, ctaId, value, pred)
/// (2) store_dsmem(addr, ctaId, value)
/// (3) store_dsmem(addr, ctaId, values, pred)
/// (4) store_dsmem(addr, ctaId, values)
void createStoreDSmem(Location loc, PatternRewriter &rewriter, Value addr,
Value ctaId, Value value, Value pred);
void createStoreDSmem(Location loc, PatternRewriter &rewriter, Value addr,
Value ctaId, Value value);
void createStoreDSmem(Location loc, PatternRewriter &rewriter, Value addr,
Value ctaId, ArrayRef<Value> values, Value pred);
void createStoreDSmem(Location loc, PatternRewriter &rewriter, Value addr,
Value ctaId, ArrayRef<Value> values);

/// Helper function to get strides from a given shape and its order
SmallVector<Value> getStridesFromShapeAndOrder(ArrayRef<int64_t> shape,
ArrayRef<unsigned> order,
Expand Down Expand Up @@ -354,6 +329,7 @@ Value addStringToModule(Location loc, ConversionPatternRewriter &rewriter,
// smallest CTA tile that is common between input and output layouts.
SmallVector<Value> getMultiDimOffset(Attribute layout, Location loc,
ConversionPatternRewriter &rewriter,
const TargetInfoBase &targetInfo,
unsigned elemId, RankedTensorType type,
ArrayRef<unsigned> multiDimCTAInRepId,
ArrayRef<unsigned> shapePerCTATile);
Expand Down Expand Up @@ -416,11 +392,6 @@ inline Value getThreadId(RewriterBase &rewriter, Location loc) {
return tid;
}

inline Value getClusterCTAId(RewriterBase &rewriter, Location loc) {
return rewriter.create<triton::nvgpu::ClusterCTAIdOp>(loc,
rewriter.getI32Type());
}

// -----------------------------------------------------------------------
// Shared memory utilities
// -----------------------------------------------------------------------
Expand Down Expand Up @@ -1023,6 +994,7 @@ emitOffsetForSliceLayout(const SliceEncodingAttr &sliceLayout,

inline SmallVector<Value> emitCTAOffsetForLayout(Location loc,
RewriterBase &rewriter,
const TargetInfoBase &target,
Attribute layout,
ArrayRef<int64_t> shape) {
unsigned rank = shape.size();
Expand All @@ -1033,7 +1005,7 @@ inline SmallVector<Value> emitCTAOffsetForLayout(Location loc,
triton::gpu::getShapePerCTA(CTASplitNum, shape);

// Delinearize clusterCTAId
Value clusterCTAId = getClusterCTAId(rewriter, loc);
Value clusterCTAId = target.getClusterCTAId(rewriter, loc);
SmallVector<Value> multiDimClusterCTAId =
delinearize(rewriter, loc, clusterCTAId, CTAsPerCGA, CTAOrder);

Expand All @@ -1051,11 +1023,10 @@ inline SmallVector<Value> emitCTAOffsetForLayout(Location loc,
return CTAOffset;
}

inline SmallVector<Value> emitBaseIndexForLayoutImpl(Location loc,
RewriterBase &rewriter,
Attribute layout,
RankedTensorType type,
bool withCTAOffset) {
inline SmallVector<Value>
emitBaseIndexForLayoutImpl(Location loc, RewriterBase &rewriter,
const TargetInfoBase &target, Attribute layout,
RankedTensorType type, bool withCTAOffset) {
auto shape = type.getShape();

SmallVector<Value> baseIndex;
Expand All @@ -1080,16 +1051,17 @@ inline SmallVector<Value> emitBaseIndexForLayoutImpl(Location loc,
auto parentShape = sliceLayout.paddedShape(type.getShape());
RankedTensorType parentTy =
RankedTensorType::get(parentShape, type.getElementType(), parentLayout);
result = emitBaseIndexForLayoutImpl(loc, rewriter, parentLayout, parentTy,
withCTAOffset);
result = emitBaseIndexForLayoutImpl(loc, rewriter, target, parentLayout,
parentTy, withCTAOffset);
result.erase(result.begin() + sliceLayout.getDim());
// CTAOffset has been added in emitBaseIndexForLayout of parentLayout
return result;
} else {
llvm_unreachable("unsupported emitBaseIndexForLayout");
}
if (withCTAOffset) {
auto CTAOffset = emitCTAOffsetForLayout(loc, rewriter, layout, shape);
auto CTAOffset =
emitCTAOffsetForLayout(loc, rewriter, target, layout, shape);
assert(CTAOffset.size() == result.size() && "Rank mismatch");
for (unsigned k = 0; k < result.size(); ++k) {
// Individual elements of `result` may be null. In the caller
Expand All @@ -1104,10 +1076,11 @@ inline SmallVector<Value> emitBaseIndexForLayoutImpl(Location loc,
}

inline SmallVector<Value>
emitBaseIndexForLayout(Location loc, RewriterBase &rewriter, Attribute layout,
emitBaseIndexForLayout(Location loc, RewriterBase &rewriter,
const TargetInfoBase &target, Attribute layout,
RankedTensorType type, bool withCTAOffset) {
SmallVector<Value> idx =
emitBaseIndexForLayoutImpl(loc, rewriter, layout, type, withCTAOffset);
SmallVector<Value> idx = emitBaseIndexForLayoutImpl(
loc, rewriter, target, layout, type, withCTAOffset);

// Check that any null values were sliced out.
for (Value v : idx) {
Expand Down Expand Up @@ -1151,11 +1124,11 @@ emitOffsetForLayout(Attribute layout, RankedTensorType type) {
// Emit indices calculation within each ConversionPattern, and returns a
// [elemsPerThread X rank] index matrix.
inline SmallVector<SmallVector<Value>>
emitIndices(Location loc, RewriterBase &rewriter, Attribute layout,
RankedTensorType type, bool withCTAOffset) {
emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
Attribute layout, RankedTensorType type, bool withCTAOffset) {
// step 1, delinearize threadId to get the base index
auto multiDimBase =
emitBaseIndexForLayout(loc, rewriter, layout, type, withCTAOffset);
auto multiDimBase = emitBaseIndexForLayout(loc, rewriter, target, layout,
type, withCTAOffset);
// step 2, get offset of each element
auto offset = emitOffsetForLayout(layout, type);
// step 3, add offset to base, and reorder the sequence
Expand All @@ -1175,9 +1148,9 @@ emitIndices(Location loc, RewriterBase &rewriter, Attribute layout,
/* ---------------- */
/* ---------------- */
inline DenseMap<unsigned, Value> getSwizzledSharedPtrs(
Location loc, unsigned inVec, RankedTensorType srcTy,
triton::gpu::SharedEncodingAttr resSharedLayout, Type resElemTy,
SharedMemoryObject smemObj, RewriterBase &rewriter,
Location loc, const TargetInfoBase &target, unsigned inVec,
RankedTensorType srcTy, triton::gpu::SharedEncodingAttr resSharedLayout,
Type resElemTy, SharedMemoryObject smemObj, RewriterBase &rewriter,
SmallVectorImpl<Value> &offsetVals, SmallVectorImpl<Value> &srcStrides) {
// This utility computes the pointers for accessing the provided swizzled
// shared memory layout `resSharedLayout`. More specifically, it computes,
Expand Down Expand Up @@ -1224,7 +1197,8 @@ inline DenseMap<unsigned, Value> getSwizzledSharedPtrs(
outVec * maxPhase <= srcShape[outOrder[0]] &&
"Swizzling would generate out of bounds memory accesses");
// Tensor indices held by the current thread, as LLVM values
auto srcIndices = emitIndices(loc, rewriter, srcEncoding, srcTy, false);
auto srcIndices =
emitIndices(loc, rewriter, target, srcEncoding, srcTy, false);
// Swizzling with leading offsets (e.g. Hopper GMMA)
unsigned swizzlingByteWidth = 0;
if (resSharedLayout.getHasLeadingOffset()) {
Expand Down Expand Up @@ -1336,10 +1310,9 @@ inline DenseMap<unsigned, Value> getSwizzledSharedPtrs(
return ret;
}

inline SmallVector<Value>
loadSharedToDistributed(Value dst, Value src, SharedMemoryObject smemObj,
Type elemTy, Location loc,
ConversionPatternRewriter &rewriter) {
inline SmallVector<Value> loadSharedToDistributed(
Value dst, Value src, SharedMemoryObject smemObj, Type elemTy, Location loc,
ConversionPatternRewriter &rewriter, const TargetInfoBase &target) {
auto dstTy = cast<RankedTensorType>(dst.getType());
auto dstShape = dstTy.getShape();
assert(dstShape.size() <= 2 && "Unexpected rank of loadSharedToDistributed");
Expand Down Expand Up @@ -1373,7 +1346,7 @@ loadSharedToDistributed(Value dst, Value src, SharedMemoryObject smemObj,
SmallVector<Value> offsetVals = {smemObj.strides.size(), i32_val(0)};

DenseMap<unsigned, Value> sharedPtrs =
getSwizzledSharedPtrs(loc, outVec, dstTy, srcSharedLayout, elemTy,
getSwizzledSharedPtrs(loc, target, outVec, dstTy, srcSharedLayout, elemTy,
smemObj, rewriter, offsetVals, smemObj.strides);
assert(outElems % minVec == 0 && "Unexpected number of elements");
unsigned numVecs = outElems / minVec;
Expand All @@ -1395,7 +1368,8 @@ loadSharedToDistributed(Value dst, Value src, SharedMemoryObject smemObj,
inline void storeDistributedToShared(Value src, ArrayRef<Value> inVals,
ArrayRef<Value> dstStrides, Value dst,
Value smemBase, Type elemTy, Location loc,
ConversionPatternRewriter &rewriter) {
ConversionPatternRewriter &rewriter,
const TargetInfoBase &target) {
auto srcTy = cast<RankedTensorType>(src.getType());
auto srcShape = srcTy.getShape();
auto rank = srcShape.size();
Expand Down Expand Up @@ -1432,8 +1406,8 @@ inline void storeDistributedToShared(Value src, ArrayRef<Value> inVals,
SharedMemoryObject smemObj(smemBase, elemTy, srcStrides, offsetVals);

DenseMap<unsigned, Value> sharedPtrs =
getSwizzledSharedPtrs(loc, inVec, srcTy, dstSharedLayout, elemTy, smemObj,
rewriter, offsetVals, srcStrides);
getSwizzledSharedPtrs(loc, target, inVec, srcTy, dstSharedLayout, elemTy,
smemObj, rewriter, offsetVals, srcStrides);
LDBG("storeDistributedToShared: numElems = " << numElems << " minVec = "
<< minVec << " " << wordTy);
for (unsigned i = 0; i < numElems; ++i) {
Expand Down
1 change: 0 additions & 1 deletion include/triton/Dialect/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,3 @@ add_subdirectory(TritonGEN)
add_subdirectory(TritonGPU)
add_subdirectory(TritonIntelGPU)
add_subdirectory(TritonNvidiaGPU)
add_subdirectory(NVGPU)
2 changes: 0 additions & 2 deletions include/triton/Dialect/NVGPU/CMakeLists.txt

This file was deleted.

1 change: 1 addition & 0 deletions lib/Conversion/TritonGPUToLLVM/AssertOpToLLVM.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h"
#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
Expand Down
32 changes: 20 additions & 12 deletions lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h"
#include "triton/Conversion/TritonGPUToLLVM/Utility.h"

#include "triton/Analysis/Allocation.h"
Expand Down Expand Up @@ -25,8 +27,11 @@ namespace {
struct LocalLoadOpConversion
: public ConvertOpToLLVMPattern<triton::gpu::LocalLoadOp> {
public:
using ConvertOpToLLVMPattern<
triton::gpu::LocalLoadOp>::ConvertOpToLLVMPattern;
LocalLoadOpConversion(LLVMTypeConverter &typeConverter,
const TargetInfoBase &targetInfo,
PatternBenefit benefit = 1)
: ConvertOpToLLVMPattern(typeConverter, benefit), targetInfo(targetInfo) {
}

LogicalResult
matchAndRewrite(triton::gpu::LocalLoadOp op, OpAdaptor adaptor,
Expand Down Expand Up @@ -93,25 +98,28 @@ struct LocalLoadOpConversion
auto srcStrides =
getStridesFromShapeAndOrder(srcTy.getShape(), inOrd, loc, rewriter);

SmallVector<Value> outVals = loadSharedToDistributed(
op.getResult(), op.getSrc(), smemObj, elemTy, loc, rewriter);
SmallVector<Value> outVals =
loadSharedToDistributed(op.getResult(), op.getSrc(), smemObj, elemTy,
loc, rewriter, targetInfo);

Value result = packLLElements(loc, typeConverter, outVals, rewriter, dstTy);
rewriter.replaceOp(op, result);

return success();
}

private:
const TargetInfoBase &targetInfo;
};

struct ConvertLayoutOpConversion
: public ConvertOpToLLVMPattern<triton::gpu::ConvertLayoutOp> {
public:
explicit ConvertLayoutOpConversion(LLVMTypeConverter &typeConverter,
const TargetInfoBase &targetInfo,
PatternBenefit benefit = 1)
: ConvertOpToLLVMPattern<triton::gpu::ConvertLayoutOp>(typeConverter,
benefit),
targetInfo(targetInfo) {}
ConvertLayoutOpConversion(LLVMTypeConverter &typeConverter,
const TargetInfoBase &targetInfo,
PatternBenefit benefit = 1)
: ConvertOpToLLVMPattern(typeConverter, benefit), targetInfo(targetInfo) {
}

LogicalResult
matchAndRewrite(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
Expand Down Expand Up @@ -179,7 +187,7 @@ struct ConvertLayoutOpConversion
// of performance issue observed.
for (unsigned elemId = 0; elemId < accumSizePerThread; elemId += vec) {
SmallVector<Value> multiDimOffset =
getMultiDimOffset(layout, loc, rewriter, elemId, type,
getMultiDimOffset(layout, loc, rewriter, targetInfo, elemId, type,
multiDimCTAInRepId, shapePerCTATile);
SmallVector<Value> multiDimOffsetWrapped = getWrappedMultiDimOffset(
rewriter, loc, multiDimOffset, origRepShape, shapePerCTATile,
Expand Down Expand Up @@ -315,5 +323,5 @@ void mlir::triton::populateConvertLayoutOpToLLVMPatterns(
LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo,
RewritePatternSet &patterns, PatternBenefit benefit) {
patterns.add<ConvertLayoutOpConversion>(typeConverter, targetInfo, benefit);
patterns.add<LocalLoadOpConversion>(typeConverter, benefit);
patterns.add<LocalLoadOpConversion>(typeConverter, targetInfo, benefit);
}
4 changes: 2 additions & 2 deletions lib/Conversion/TritonGPUToLLVM/HistogramOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -185,8 +185,8 @@ struct HistogramOpConversion
LLVM::getSharedMemoryBase(loc, rewriter, op.getOperation());
auto dstType = op.getType();
Attribute dstEncoding = dstType.getEncoding();
auto indices =
emitIndices(op.getLoc(), rewriter, dstEncoding, dstType, true);
auto indices = emitIndices(op.getLoc(), rewriter, targetInfo, dstEncoding,
dstType, true);
SmallVector<Value> innerDimIndices;
for (int i = 0; i < indices.size(); ++i)
innerDimIndices.push_back(indices[i][0]);
Expand Down
Loading

0 comments on commit 689ec3c

Please sign in to comment.