Skip to content

Commit

Permalink
Address review comments
Browse files (browse the repository at this point in the history)
  • Loading branch information
skatrak committed Oct 30, 2024
1 parent 8f88a5b commit a3727a0
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 15 deletions.
4 changes: 2 additions & 2 deletions llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4035,8 +4035,6 @@ OpenMPIRBuilder::createCanonicalLoop(const LocationDescription &Loc,
Value *OpenMPIRBuilder::calculateCanonicalLoopTripCount(
const LocationDescription &Loc, Value *Start, Value *Stop, Value *Step,
bool IsSigned, bool InclusiveStop, const Twine &Name) {
updateToLocation(Loc);

// Consider the following difficulties (assuming 8-bit signed integers):
// * Adding \p Step to the loop counter which passes \p Stop may overflow:
// DO I = 1, 100, 50
Expand All @@ -4048,6 +4046,8 @@ Value *OpenMPIRBuilder::calculateCanonicalLoopTripCount(
assert(IndVarTy == Stop->getType() && "Stop type mismatch");
assert(IndVarTy == Step->getType() && "Step type mismatch");

updateToLocation(Loc);

ConstantInt *Zero = ConstantInt::get(IndVarTy, 0);
ConstantInt *One = ConstantInt::get(IndVarTy, 1);

Expand Down
16 changes: 3 additions & 13 deletions llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1427,8 +1427,7 @@ TEST_F(OpenMPIRBuilderTest, CanonicalLoopSimple) {
EXPECT_EQ(&Loop->getAfter()->front(), RetInst);
}

TEST_F(OpenMPIRBuilderTest, CanonicalLoopBounds) {
using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
TEST_F(OpenMPIRBuilderTest, CanonicalLoopTripCount) {
OpenMPIRBuilder OMPBuilder(*M);
OMPBuilder.initialize();
IRBuilder<> Builder(BB);
Expand All @@ -1444,17 +1443,8 @@ TEST_F(OpenMPIRBuilderTest, CanonicalLoopBounds) {
Value *StartVal = ConstantInt::get(LCTy, Start);
Value *StopVal = ConstantInt::get(LCTy, Stop);
Value *StepVal = ConstantInt::get(LCTy, Step);
auto LoopBodyGenCB = [&](InsertPointTy CodeGenIP, llvm::Value *LC) {
return Error::success();
};
Expected<CanonicalLoopInfo *> LoopResult =
OMPBuilder.createCanonicalLoop(Loc, LoopBodyGenCB, StartVal, StopVal,
StepVal, IsSigned, InclusiveStop);
assert(LoopResult && "unexpected error");
CanonicalLoopInfo *Loop = *LoopResult;
Loop->assertOK();
Builder.restoreIP(Loop->getAfterIP());
Value *TripCount = Loop->getTripCount();
Value *TripCount = OMPBuilder.calculateCanonicalLoopTripCount(
Loc, StartVal, StopVal, StepVal, IsSigned, InclusiveStop);
return cast<ConstantInt>(TripCount)->getValue().getZExtValue();
};

Expand Down
46 changes: 46 additions & 0 deletions mlir/test/Target/LLVMIR/omptarget-host-eval.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s

// Translation test input: a single function containing an `omp.target` region
// whose `host_eval` clause forwards five i32 values defined outside the region
// (number of teams, thread limit, and the loop lower bound / upper bound /
// step) to the operations nested inside it. The module sets
// `omp.is_target_device = false` and lists "amdgcn-amd-amdhsa" as the only
// target triple. The FileCheck directives at the end of this file match the
// kernel-argument stores and the `__tgt_target_kernel` call in the output.
module attributes {omp.is_target_device = false, omp.target_triples = ["amdgcn-amd-amdhsa"]} {
llvm.func @omp_target_region_() {
// Constants defined outside the target region. The distinct values
// (1000 / 2000 / 3000) make each one easy to recognise in the expected output.
%out_teams = llvm.mlir.constant(1000 : i32) : i32
%out_threads = llvm.mlir.constant(2000 : i32) : i32
%out_lb = llvm.mlir.constant(0 : i32) : i32
%out_ub = llvm.mlir.constant(3000 : i32) : i32
%out_step = llvm.mlir.constant(1 : i32) : i32

// `host_eval` maps each outer value to a name usable inside the region
// (%out_X -> %X); all five are i32.
omp.target
host_eval(%out_teams -> %teams, %out_threads -> %threads,
%out_lb -> %lb, %out_ub -> %ub, %out_step -> %step :
i32, i32, i32, i32, i32) {
// %teams feeds num_teams and %threads feeds thread_limit.
omp.teams num_teams(to %teams : i32) thread_limit(%threads : i32) {
// parallel, distribute and wsloop are nested around one loop_nest, and each
// of the three wrappers carries the `omp.composite` attribute.
omp.parallel {
omp.distribute {
omp.wsloop {
// Loop bounds and step come from the host_eval-mapped values; the body is
// empty apart from the yield.
omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) {
omp.yield
}
} {omp.composite}
} {omp.composite}
omp.terminator
} {omp.composite}
omp.terminator
}
omp.terminator
}
llvm.return
}
}

// CHECK-LABEL: define void @omp_target_region_
// CHECK: %[[ARGS:.*]] = alloca %struct.__tgt_kernel_arguments

// CHECK: %[[TRIPCOUNT_ADDR:.*]] = getelementptr inbounds nuw %struct.__tgt_kernel_arguments, ptr %[[ARGS]], i32 0, i32 8
// CHECK: store i64 3000, ptr %[[TRIPCOUNT_ADDR]]

// CHECK: %[[TEAMS_ADDR:.*]] = getelementptr inbounds nuw %struct.__tgt_kernel_arguments, ptr %[[ARGS]], i32 0, i32 10
// CHECK: store [3 x i32] [i32 1000, i32 0, i32 0], ptr %[[TEAMS_ADDR]]

// CHECK: %[[THREADS_ADDR:.*]] = getelementptr inbounds nuw %struct.__tgt_kernel_arguments, ptr %[[ARGS]], i32 0, i32 11
// CHECK: store [3 x i32] [i32 2000, i32 0, i32 0], ptr %[[THREADS_ADDR]]

// CHECK: call i32 @__tgt_target_kernel(ptr @{{.*}}, i64 {{.*}}, i32 1000, i32 2000, ptr @{{.*}}, ptr %[[ARGS]])

0 comments on commit a3727a0

Please sign in to comment.