Skip to content

Commit

Permalink
[MLIR][NVVM] Add Op for TMA Prefetch (#116232)
Browse files Browse the repository at this point in the history
PR #115527 adds intrinsics for TMA prefetch.
This patch adds an NVVM Dialect Op for the same.

Lit tests are added to verify the lowering to LLVM intrinsics, as well as
verifier tests for the invalid cases.

PTX Spec reference:
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-prefetch-tensor

Signed-off-by: Durgadoss R <durgadossr@nvidia.com>
  • Loading branch information
durga4github authored Nov 15, 2024
1 parent 7b54976 commit 1b23ebe
Show file tree
Hide file tree
Showing 4 changed files with 203 additions and 10 deletions.
68 changes: 68 additions & 0 deletions mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1949,6 +1949,74 @@ def NVVM_PrefetchTensorMapOp : NVVM_Op<"prefetch.tensormap",
}];
}

def NVVM_CpAsyncBulkTensorPrefetchOp :
  NVVM_Op<"cp.async.bulk.tensor.prefetch", [AttrSizedOperandSegments]> {
  let arguments = (ins
    LLVM_AnyPointer:$tmaDescriptor,
    Variadic<I32>:$coordinates,
    Variadic<I16>:$im2colOffsets,
    Optional<I64>:$l2CacheHint);

  let description = [{
    Initiates an asynchronous prefetch operation on the tensor data from global
    memory to L2 cache.

    The Op has two modes:
    1) Tiled Mode: It's the default mode. The source multi-dimensional tensor
    layout is preserved at the destination.

    2) Im2col Mode: This mode is used when `im2colOffsets` operands are present.
    The elements in the Bounding Box of the source tensor are rearranged into
    columns at the destination. In this mode, the tensor has to be at least
    3-dimensional.

    The `l2CacheHint` operand is optional, and it is used to specify cache
    eviction policy that may be used during the memory access.

    [For more information, see PTX ISA]
    (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-prefetch-tensor)
  }];

  let assemblyFormat = [{
    $tmaDescriptor `,`
    `box` `[`$coordinates `]`
    (`im2col` `[` $im2colOffsets^ `]` )?
    (`l2_cache_hint` `=` $l2CacheHint^ )?
    attr-dict `:` type($tmaDescriptor)
  }];

  let extraClassDeclaration = [{
    static llvm::Intrinsic::ID getIntrinsicID(int tensorDims, bool isIm2Col);
  }];

  let hasVerifier = 1;

  string llvmBuilder = [{
    // The intrinsic takes, in order: the TMA descriptor, the tensor
    // coordinates, the im2col offsets (if any), the cache-hint value, and a
    // boolean flag telling whether the cache-hint operand is valid.
    llvm::SmallVector<llvm::Value *> translatedOperands;
    translatedOperands.push_back($tmaDescriptor);

    for (auto v : op.getCoordinates())
      translatedOperands.push_back(moduleTranslation.lookupValue(v));

    for (auto v : op.getIm2colOffsets())
      translatedOperands.push_back(moduleTranslation.lookupValue(v));

    llvm::LLVMContext &ctx = moduleTranslation.getLLVMContext();
    auto *i64Undef = llvm::UndefValue::get(llvm::IntegerType::get(ctx, 64));

    // When the optional cache-hint is absent, pass an i64 undef and set the
    // flag operand to false so the backend ignores the hint.
    bool isCacheHint = static_cast<bool>(op.getL2CacheHint());
    translatedOperands.push_back(isCacheHint ? $l2CacheHint : i64Undef);
    translatedOperands.push_back(builder.getInt1(isCacheHint));

    // The concrete intrinsic is selected by tensor rank and tile/im2col mode.
    auto intId = NVVM::CpAsyncBulkTensorPrefetchOp::getIntrinsicID(
        op.getCoordinates().size(), op.getIm2colOffsets().size() > 0);
    createIntrinsicCall(builder, intId, translatedOperands);
  }];
}

//===----------------------------------------------------------------------===//
// NVVM Wgmma Ops
//===----------------------------------------------------------------------===//
Expand Down
57 changes: 48 additions & 9 deletions mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,22 +75,32 @@ ParseResult VoteBallotOp::parse(OpAsmParser &parser, OperationState &result) {

void VoteBallotOp::print(OpAsmPrinter &p) { printNVVMIntrinsicOp(p, *this); }

// This verifier is shared across:
// CpAsyncBulkTensorGlobalToSharedClusterOp (TMA Load) and
// CpAsyncBulkTensorPrefetchOp (TMA Prefetch) Ops.
static LogicalResult CpAsyncBulkTensorCommonVerifier(size_t tensorDims,
                                                     size_t numIm2ColOffsets,
                                                     Location loc) {
  // TMA tensor operations support 1D through 5D tensors only.
  if (tensorDims < 1 || tensorDims > 5)
    return emitError(loc, "expects coordinates between 1 to 5 dimension");

  // Im2col mode needs a 3D-to-5D tensor, and exactly one offset per
  // dimension beyond the two minor-most ones.
  if (numIm2ColOffsets) {
    if (tensorDims < 3)
      return emitError(
          loc,
          "to use im2col mode, the tensor has to be at least 3-dimensional");
    if (tensorDims != (numIm2ColOffsets + 2))
      return emitError(
          loc, "im2col offsets must be 2 less than number of coordinates");
  }
  return success();
}

LogicalResult CpAsyncBulkTensorGlobalToSharedClusterOp::verify() {
  return CpAsyncBulkTensorCommonVerifier(getCoordinates().size(),
                                         getIm2colOffsets().size(), getLoc());
}

LogicalResult CpAsyncBulkTensorSharedCTAToGlobalOp::verify() {
if (getCoordinates().size() > 5)
return emitError("Maximum 5 coordinates and dimension is supported.");
Expand All @@ -108,6 +118,11 @@ LogicalResult CpAsyncOp::verify() {
return success();
}

LogicalResult CpAsyncBulkTensorPrefetchOp::verify() {
  // Delegate to the verifier shared with the TMA Load op; the constraints on
  // tensor rank and im2col offsets are identical.
  size_t numDims = getCoordinates().size();
  size_t numOffsets = getIm2colOffsets().size();
  return CpAsyncBulkTensorCommonVerifier(numDims, numOffsets, getLoc());
}

// Given the element type of an operand and whether or not it is an accumulator,
// this function returns the PTX type (`NVVM::MMATypes`) that corresponds to the
// operand's element type.
Expand Down Expand Up @@ -1055,6 +1070,30 @@ LogicalResult NVVM::BarrierOp::verify() {
return success();
}

llvm::Intrinsic::ID CpAsyncBulkTensorPrefetchOp::getIntrinsicID(int tensorDims,
                                                                bool isIm2Col) {
  // Tile-mode intrinsics, indexed by (tensorDims - 1).
  static constexpr llvm::Intrinsic::ID TileIDs[] = {
      llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_1d,
      llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_2d,
      llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_3d,
      llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_4d,
      llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_5d};
  // Im2col-mode intrinsics, indexed by (tensorDims - 3); the mode only
  // exists for 3D-5D tensors.
  static constexpr llvm::Intrinsic::ID Im2ColIDs[] = {
      llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_3d,
      llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_4d,
      llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_5d};

  if (tensorDims < 1 || tensorDims > 5)
    llvm_unreachable("Invalid TensorDim in CpAsyncBulkTensorPrefetchOp.");

  // 1D/2D requests fall back to tile mode even if isIm2Col is set, matching
  // the switch-based selection this replaces.
  if (isIm2Col && tensorDims >= 3)
    return Im2ColIDs[tensorDims - 3];
  return TileIDs[tensorDims - 1];
}

//===----------------------------------------------------------------------===//
// NVVMDialect initialization, type parsing, and registration.
//===----------------------------------------------------------------------===//
Expand Down
26 changes: 25 additions & 1 deletion mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,28 @@ llvm.func @nvvm_fence_proxy_release() {
// expected-error @below {{'nvvm.fence.proxy.release' op uni-directional proxies only support tensormap for to_proxy attribute}}
nvvm.fence.proxy.release #nvvm.mem_scope<cta> from_proxy=#nvvm.proxy_kind<generic> to_proxy=#nvvm.proxy_kind<generic>
llvm.return
}
}

// -----

// A prefetch with an empty coordinate list (0-D tensor) must be rejected:
// the verifier requires between 1 and 5 coordinates.
llvm.func @tma_prefetch_0d(%tma_desc : !llvm.ptr, %d0 : i32, %ch : i64) {
  // expected-error @below {{expects coordinates between 1 to 5 dimension}}
  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[] : !llvm.ptr
  llvm.return
}

// -----

// Im2col mode on a 2-D tensor must be rejected: the mode requires at least
// a 3-D tensor.
llvm.func @tma_prefetch_2d_im2col(%tma_desc : !llvm.ptr, %d0 : i32, %d1 : i32, %off0 : i16, %ch : i64) {
  // expected-error @below {{to use im2col mode, the tensor has to be at least 3-dimensional}}
  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1] im2col[%off0] l2_cache_hint = %ch : !llvm.ptr
  llvm.return
}

// -----

// A 5-D tensor with only two im2col offsets must be rejected: the number of
// offsets has to be exactly (number of coordinates - 2), i.e. three here.
llvm.func @tma_prefetch_5d_im2col(%tma_desc : !llvm.ptr, %d0 : i32, %d1 : i32, %d2 : i32, %d3 : i32, %d4 : i32, %off0 : i16, %off1 : i16, %off2 : i16, %ch : i64) {
  // expected-error @below {{im2col offsets must be 2 less than number of coordinates}}
  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3, %d4] im2col[%off0, %off1] : !llvm.ptr
  llvm.return
}
62 changes: 62 additions & 0 deletions mlir/test/Target/LLVMIR/nvvmir.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -715,3 +715,65 @@ llvm.func @nvvm_breakpoint() {
nvvm.breakpoint
llvm.return
}

// -----

// CHECK-LABEL: @tma_prefetch_1d
// 1-D tile-mode lowering, without and with an L2 cache hint. With no hint,
// the intrinsic receives an i64 undef and an i1 false flag.
llvm.func @tma_prefetch_1d(%tma_desc : !llvm.ptr, %d0 : i32, %ch : i64) {
  // CHECK-LLVM: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.1d(ptr %0, i32 %{{.*}}, i64 undef, i1 false)
  // CHECK-LLVM: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.1d(ptr %0, i32 %{{.*}}, i64 %{{.*}}, i1 true)
  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0] : !llvm.ptr
  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0] l2_cache_hint = %ch : !llvm.ptr
  llvm.return
}

// CHECK-LABEL: @tma_prefetch_2d
// 2-D tile-mode lowering, without and with an L2 cache hint.
llvm.func @tma_prefetch_2d(%tma_desc : !llvm.ptr, %d0 : i32, %d1 : i32, %ch : i64) {
  // CHECK-LLVM: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.2d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i64 undef, i1 false)
  // CHECK-LLVM: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.2d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i64 %{{.*}}, i1 true)
  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1] : !llvm.ptr
  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1] l2_cache_hint = %ch : !llvm.ptr
  llvm.return
}

// CHECK-LABEL: @tma_prefetch_3d
// 3-D lowering in both tile and im2col modes (3-D is the smallest rank that
// supports im2col), each without and with an L2 cache hint.
llvm.func @tma_prefetch_3d(%tma_desc : !llvm.ptr, %d0 : i32, %d1 : i32, %d2 : i32, %off0 : i16, %ch : i64) {
  // CHECK-LLVM: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.3d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i64 undef, i1 false)
  // CHECK-LLVM: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.3d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i64 %{{.*}}, i1 true)
  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2] : !llvm.ptr
  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2] l2_cache_hint = %ch : !llvm.ptr

  // CHECK-LLVM: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.3d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i16 %{{.*}}, i64 undef, i1 false)
  // CHECK-LLVM: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.3d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i16 %{{.*}}, i64 %{{.*}}, i1 true)
  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2] im2col[%off0] : !llvm.ptr
  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2] im2col[%off0] l2_cache_hint = %ch : !llvm.ptr
  llvm.return
}

// CHECK-LABEL: @tma_prefetch_4d
// 4-D lowering in both tile and im2col modes (two im2col offsets for a 4-D
// tensor), each without and with an L2 cache hint.
llvm.func @tma_prefetch_4d(%tma_desc : !llvm.ptr, %d0 : i32, %d1 : i32, %d2 : i32, %d3 : i32, %off0 : i16, %off1 : i16, %ch : i64) {
  // CHECK-LLVM: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.4d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i64 undef, i1 false)
  // CHECK-LLVM: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.4d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i64 %{{.*}}, i1 true)
  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3] : !llvm.ptr
  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3] l2_cache_hint = %ch : !llvm.ptr

  // CHECK-LLVM: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.4d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i16 %{{.*}}, i16 %{{.*}}, i64 undef, i1 false)
  // CHECK-LLVM: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.4d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i16 %{{.*}}, i16 %{{.*}}, i64 %{{.*}}, i1 true)
  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3] im2col[%off0, %off1] : !llvm.ptr
  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3] im2col[%off0, %off1] l2_cache_hint = %ch : !llvm.ptr
  llvm.return
}

// CHECK-LABEL: @tma_prefetch_5d
// 5-D lowering (the maximum supported rank) in both tile and im2col modes
// (three im2col offsets for a 5-D tensor), each without and with an L2
// cache hint.
llvm.func @tma_prefetch_5d(%tma_desc : !llvm.ptr, %d0 : i32, %d1 : i32, %d2 : i32, %d3 : i32, %d4 : i32, %off0 : i16, %off1 : i16, %off2 : i16, %ch : i64) {
  // CHECK-LLVM: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.5d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i64 undef, i1 false)
  // CHECK-LLVM: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.5d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i64 %{{.*}}, i1 true)
  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3, %d4] : !llvm.ptr
  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3, %d4] l2_cache_hint = %ch : !llvm.ptr

  // CHECK-LLVM: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.5d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i16 %{{.*}}, i16 %{{.*}}, i16 %{{.*}}, i64 undef, i1 false)
  // CHECK-LLVM: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.5d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i16 %{{.*}}, i16 %{{.*}}, i16 %{{.*}}, i64 %{{.*}}, i1 true)
  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3, %d4] im2col[%off0, %off1, %off2] : !llvm.ptr
  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3, %d4] im2col[%off0, %off1, %off2] l2_cache_hint = %ch : !llvm.ptr
  llvm.return
}

0 comments on commit 1b23ebe

Please sign in to comment.