Skip to content

Commit

Permalink
[flang][cuda] Add verifier for cuda_alloc/cuda_free (llvm#90983)
Browse files Browse the repository at this point in the history
Adding a verifier to check the associated cuda attribute.
  • Loading branch information
clementval authored May 3, 2024
1 parent a4d1026 commit f8a9973
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 0 deletions.
4 changes: 4 additions & 0 deletions flang/include/flang/Optimizer/Dialect/FIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -3364,6 +3364,8 @@ def fir_CUDAAllocOp : fir_Op<"cuda_alloc", [AttrSizedOperandSegments,
CArg<"mlir::ValueRange", "{}">:$typeparams,
CArg<"mlir::ValueRange", "{}">:$shape,
CArg<"llvm::ArrayRef<mlir::NamedAttribute>", "{}">:$attributes)>];

let hasVerifier = 1;
}

def fir_CUDAFreeOp : fir_Op<"cuda_free", [MemoryEffects<[MemFree]>]> {
Expand All @@ -3381,6 +3383,8 @@ def fir_CUDAFreeOp : fir_Op<"cuda_free", [MemoryEffects<[MemFree]>]> {
);

let assemblyFormat = "$devptr `:` qualified(type($devptr)) attr-dict";

let hasVerifier = 1;
}

#endif
13 changes: 13 additions & 0 deletions flang/lib/Optimizer/Dialect/FIROps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4048,6 +4048,19 @@ void fir::CUDAAllocOp::build(
result.addAttributes(attributes);
}

template <typename Op>
static mlir::LogicalResult checkCudaAttr(Op op) {
if (op.getCudaAttr() == fir::CUDADataAttribute::Device ||
op.getCudaAttr() == fir::CUDADataAttribute::Managed ||
op.getCudaAttr() == fir::CUDADataAttribute::Unified)
return mlir::success();
return op.emitOpError("expect device, managed or unified cuda attribute");
}

mlir::LogicalResult fir::CUDAAllocOp::verify() { return checkCudaAttr(*this); }

mlir::LogicalResult fir::CUDAFreeOp::verify() { return checkCudaAttr(*this); }

//===----------------------------------------------------------------------===//
// FIROpsDialect
//===----------------------------------------------------------------------===//
Expand Down
18 changes: 18 additions & 0 deletions flang/test/Fir/cuf-invalid.fir
Original file line number Diff line number Diff line change
Expand Up @@ -85,3 +85,21 @@ func.func @_QPsub1() {
%13 = fir.cuda_deallocate %11 : !fir.ref<!fir.box<none>> errmsg(%16 : !fir.box<none>) {cuda_attr = #fir.cuda<device>} -> i32
return
}

// -----

func.func @_QPsub1() {
// expected-error@+1{{'fir.cuda_alloc' op expect device, managed or unified cuda attribute}}
%0 = fir.cuda_alloc f32 {bindc_name = "r", cuda_attr = #fir.cuda<pinned>, uniq_name = "_QFsub1Er"} -> !fir.ref<f32>
fir.cuda_free %0 : !fir.ref<f32> {cuda_attr = #fir.cuda<constant>}
return
}

// -----

func.func @_QPsub1() {
%0 = fir.cuda_alloc f32 {bindc_name = "r", cuda_attr = #fir.cuda<device>, uniq_name = "_QFsub1Er"} -> !fir.ref<f32>
// expected-error@+1{{'fir.cuda_free' op expect device, managed or unified cuda attribute}}
fir.cuda_free %0 : !fir.ref<f32> {cuda_attr = #fir.cuda<constant>}
return
}

0 comments on commit f8a9973

Please sign in to comment.