Skip to content

Commit

Permalink
[flang][cuda] Specialize entry point for scalar to desc data transfer
Browse files Browse the repository at this point in the history
  • Loading branch information
clementval committed Nov 16, 2024
1 parent 131d73e commit 7054e5c
Show file tree
Hide file tree
Showing 6 changed files with 56 additions and 16 deletions.
4 changes: 4 additions & 0 deletions flang/include/flang/Runtime/CUDA/memory.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ void RTDECL(CUFDataTransferPtrDesc)(void *dst, Descriptor *src,
void RTDECL(CUFDataTransferDescDesc)(Descriptor *dst, Descriptor *src,
unsigned mode, const char *sourceFile = nullptr, int sourceLine = 0);

/// Data transfer from a scalar descriptor to a descriptor.
void RTDECL(CUFDataTransferCstDesc)(Descriptor *dst, Descriptor *src,
unsigned mode, const char *sourceFile = nullptr, int sourceLine = 0);

/// Data transfer from a descriptor to a descriptor.
void RTDECL(CUFDataTransferDescDescNoRealloc)(Descriptor *dst, Descriptor *src,
unsigned mode, const char *sourceFile = nullptr, int sourceLine = 0);
Expand Down
8 changes: 6 additions & 2 deletions flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -563,8 +563,9 @@ struct CUFDataTransferOpConversion
// until we have more infrastructure.
mlir::Value src = emboxSrc(rewriter, op, symtab);
mlir::Value dst = emboxDst(rewriter, op, symtab);
mlir::func::FuncOp func = fir::runtime::getRuntimeFunc<mkRTKey(
CUFDataTransferDescDescNoRealloc)>(loc, builder);
mlir::func::FuncOp func =
fir::runtime::getRuntimeFunc<mkRTKey(CUFDataTransferCstDesc)>(
loc, builder);
auto fTy = func.getFunctionType();
mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
mlir::Value sourceLine =
Expand Down Expand Up @@ -648,6 +649,9 @@ struct CUFDataTransferOpConversion
mlir::Value src = op.getSrc();
if (!mlir::isa<fir::BaseBoxType>(srcTy)) {
src = emboxSrc(rewriter, op, symtab);
if (fir::isa_trivial(srcTy))
func = fir::runtime::getRuntimeFunc<mkRTKey(CUFDataTransferCstDesc)>(
loc, builder);
}
auto materializeBoxIfNeeded = [&](mlir::Value val) -> mlir::Value {
if (mlir::isa<fir::EmboxOp>(val.getDefiningOp())) {
Expand Down
19 changes: 19 additions & 0 deletions flang/runtime/CUDA/memory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//

#include "flang/Runtime/CUDA/memory.h"
#include "../assign-impl.h"
#include "../terminator.h"
#include "flang/Runtime/CUDA/common.h"
#include "flang/Runtime/CUDA/descriptor.h"
Expand Down Expand Up @@ -120,6 +121,24 @@ void RTDECL(CUFDataTransferDescDesc)(Descriptor *dstDesc, Descriptor *srcDesc,
*dstDesc, *srcDesc, terminator, MaybeReallocate, memmoveFct);
}

void RTDECL(CUFDataTransferCstDesc)(Descriptor *dstDesc, Descriptor *srcDesc,
unsigned mode, const char *sourceFile, int sourceLine) {
MemmoveFct memmoveFct;
Terminator terminator{sourceFile, sourceLine};
if (mode == kHostToDevice) {
memmoveFct = &MemmoveHostToDevice;
} else if (mode == kDeviceToHost) {
memmoveFct = &MemmoveDeviceToHost;
} else if (mode == kDeviceToDevice) {
memmoveFct = &MemmoveDeviceToDevice;
} else {
terminator.Crash("host to host copy not supported");
}

Fortran::runtime::DoFromSourceAssign(
*dstDesc, *srcDesc, terminator, memmoveFct);
}

void RTDECL(CUFDataTransferDescDescNoRealloc)(Descriptor *dstDesc,
Descriptor *srcDesc, unsigned mode, const char *sourceFile,
int sourceLine) {
Expand Down
17 changes: 15 additions & 2 deletions flang/runtime/assign-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,29 @@
#ifndef FORTRAN_RUNTIME_ASSIGN_IMPL_H_
#define FORTRAN_RUNTIME_ASSIGN_IMPL_H_

#include "flang/Runtime/freestanding-tools.h"

namespace Fortran::runtime {
class Descriptor;
class Terminator;

using MemmoveFct = void *(*)(void *, const void *, std::size_t);

// Assign one object to another via allocate statement from source specifier.
// Note that if allocate object and source expression have the same rank, the
// value of the allocate object becomes the value provided; otherwise the value
// of each element of allocate object becomes the value provided (9.7.1.2(7)).
RT_API_ATTRS void DoFromSourceAssign(
Descriptor &, const Descriptor &, Terminator &);
#ifdef RT_DEVICE_COMPILATION
static RT_API_ATTRS void *MemmoveWrapper(
void *dest, const void *src, std::size_t count) {
return Fortran::runtime::memmove(dest, src, count);
}
RT_API_ATTRS void DoFromSourceAssign(Descriptor &, const Descriptor &,
Terminator &, MemmoveFct memmoveFct = &MemmoveWrapper);
#else
RT_API_ATTRS void DoFromSourceAssign(Descriptor &, const Descriptor &,
Terminator &, MemmoveFct memmoveFct = &Fortran::runtime::memmove);
#endif

} // namespace Fortran::runtime
#endif // FORTRAN_RUNTIME_ASSIGN_IMPL_H_
12 changes: 6 additions & 6 deletions flang/runtime/assign.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -509,8 +509,8 @@ RT_API_ATTRS void Assign(Descriptor &to, const Descriptor &from,

RT_OFFLOAD_API_GROUP_BEGIN

RT_API_ATTRS void DoFromSourceAssign(
Descriptor &alloc, const Descriptor &source, Terminator &terminator) {
RT_API_ATTRS void DoFromSourceAssign(Descriptor &alloc,
const Descriptor &source, Terminator &terminator, MemmoveFct memmoveFct) {
if (alloc.rank() > 0 && source.rank() == 0) {
// The value of each element of allocate object becomes the value of source.
DescriptorAddendum *allocAddendum{alloc.Addendum()};
Expand All @@ -523,17 +523,17 @@ RT_API_ATTRS void DoFromSourceAssign(
alloc.IncrementSubscripts(allocAt)) {
Descriptor allocElement{*Descriptor::Create(*allocDerived,
reinterpret_cast<void *>(alloc.Element<char>(allocAt)), 0)};
Assign(allocElement, source, terminator, NoAssignFlags);
Assign(allocElement, source, terminator, NoAssignFlags, memmoveFct);
}
} else { // intrinsic type
for (std::size_t n{alloc.Elements()}; n-- > 0;
alloc.IncrementSubscripts(allocAt)) {
Fortran::runtime::memmove(alloc.Element<char>(allocAt),
source.raw().base_addr, alloc.ElementBytes());
memmoveFct(alloc.Element<char>(allocAt), source.raw().base_addr,
alloc.ElementBytes());
}
}
} else {
Assign(alloc, source, terminator, NoAssignFlags);
Assign(alloc, source, terminator, NoAssignFlags, memmoveFct);
}
}

Expand Down
12 changes: 6 additions & 6 deletions flang/test/Fir/CUDA/cuda-data-transfer.fir
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ func.func @_QPsub2() {
// CHECK: fir.store %[[EMBOX]] to %[[TEMP_BOX]] : !fir.ref<!fir.box<i32>>
// CHECK: %[[ADEV_BOX:.*]] = fir.convert %[[ADEV]]#0 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> !fir.ref<!fir.box<none>>
// CHECK: %[[TEMP_CONV:.*]] = fir.convert %[[TEMP_BOX]] : (!fir.ref<!fir.box<i32>>) -> !fir.ref<!fir.box<none>>
// CHECK: fir.call @_FortranACUFDataTransferDescDesc(%[[ADEV_BOX]], %[[TEMP_CONV]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.ref<!fir.box<none>>, i32, !fir.ref<i8>, i32) -> none
// CHECK: fir.call @_FortranACUFDataTransferCstDesc(%[[ADEV_BOX]], %[[TEMP_CONV]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.ref<!fir.box<none>>, i32, !fir.ref<i8>, i32) -> none

func.func @_QPsub3() {
%0 = cuf.alloc !fir.box<!fir.heap<!fir.array<?xi32>>> {bindc_name = "adev", data_attr = #cuf.cuda<device>, uniq_name = "_QFsub3Eadev"} -> !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
Expand All @@ -58,7 +58,7 @@ func.func @_QPsub3() {
// CHECK: fir.store %[[EMBOX]] to %[[TEMP_BOX]] : !fir.ref<!fir.box<i32>>
// CHECK: %[[ADEV_BOX:.*]] = fir.convert %[[ADEV]]#0 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> !fir.ref<!fir.box<none>>
// CHECK: %[[V_CONV:.*]] = fir.convert %[[TEMP_BOX]] : (!fir.ref<!fir.box<i32>>) -> !fir.ref<!fir.box<none>>
// CHECK: fir.call @_FortranACUFDataTransferDescDesc(%[[ADEV_BOX]], %[[V_CONV]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.ref<!fir.box<none>>, i32, !fir.ref<i8>, i32) -> none
// CHECK: fir.call @_FortranACUFDataTransferCstDesc(%[[ADEV_BOX]], %[[V_CONV]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.ref<!fir.box<none>>, i32, !fir.ref<i8>, i32) -> none

func.func @_QPsub4() {
%0 = cuf.alloc !fir.box<!fir.heap<!fir.array<?xi32>>> {bindc_name = "adev", data_attr = #cuf.cuda<device>, uniq_name = "_QFsub4Eadev"} -> !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
Expand Down Expand Up @@ -297,7 +297,7 @@ func.func @_QPscalar_to_array() {
}

// CHECK-LABEL: func.func @_QPscalar_to_array()
// CHECK: _FortranACUFDataTransferDescDescNoRealloc
// CHECK: _FortranACUFDataTransferCstDesc

func.func @_QPtest_type() {
%0 = cuf.alloc !fir.type<_QMbarTcmplx{id:i32,c:complex<f32>}> {bindc_name = "a", data_attr = #cuf.cuda<device>, uniq_name = "_QFtest_typeEa"} -> !fir.ref<!fir.type<_QMbarTcmplx{id:i32,c:complex<f32>}>>
Expand Down Expand Up @@ -344,7 +344,7 @@ func.func @_QPshape_shift() {
}

// CHECK-LABEL: func.func @_QPshape_shift()
// CHECK: fir.call @_FortranACUFDataTransferDescDescNoRealloc
// CHECK: fir.call @_FortranACUFDataTransferCstDesc

func.func @_QPshape_shift2() {
%c11 = arith.constant 11 : index
Expand Down Expand Up @@ -383,7 +383,7 @@ func.func @_QPdevice_addr_conv() {
// CHECK: %[[DEV_ADDR:.*]] = fir.call @_FortranACUFGetDeviceAddress(%{{.*}}, %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8>
// CHECK: %[[DEV_ADDR_CONV:.*]] = fir.convert %[[DEV_ADDR]] : (!fir.llvm_ptr<i8>) -> !fir.ref<!fir.array<4xf32>>
// CHECK: fir.embox %[[DEV_ADDR_CONV]](%{{.*}}) : (!fir.ref<!fir.array<4xf32>>, !fir.shape<1>) -> !fir.box<!fir.array<4xf32>>
// CHECK: fir.call @_FortranACUFDataTransferDescDescNoRealloc
// CHECK: fir.call @_FortranACUFDataTransferCstDesc

func.func @_QQchar_transfer() attributes {fir.bindc_name = "char_transfer"} {
%c1 = arith.constant 1 : index
Expand Down Expand Up @@ -464,6 +464,6 @@ func.func @_QPlogical_cst() {
// CHECK: %[[EMBOX:.*]] = fir.embox %[[CONST]] : (!fir.ref<!fir.logical<4>>) -> !fir.box<!fir.logical<4>>
// CHECK: fir.store %[[EMBOX]] to %[[DESC]] : !fir.ref<!fir.box<!fir.logical<4>>>
// CHECK: %[[BOX_NONE:.*]] = fir.convert %[[DESC]] : (!fir.ref<!fir.box<!fir.logical<4>>>) -> !fir.ref<!fir.box<none>>
// CHECK: fir.call @_FortranACUFDataTransferDescDesc(%{{.*}}, %[[BOX_NONE]], %{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.ref<!fir.box<none>>, i32, !fir.ref<i8>, i32) -> none
// CHECK: fir.call @_FortranACUFDataTransferCstDesc(%{{.*}}, %[[BOX_NONE]], %{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.ref<!fir.box<none>>, i32, !fir.ref<i8>, i32) -> none

} // end of module

0 comments on commit 7054e5c

Please sign in to comment.