Skip to content

Commit

Permalink
[SYCL] Implement SYCLConditionalCallOnDevicePass pass (#14228)
Browse files Browse the repository at this point in the history
  • Loading branch information
dm-vodopyanov authored Jul 9, 2024
1 parent 1fcc1cf commit 19e471f
Show file tree
Hide file tree
Showing 8 changed files with 289 additions and 0 deletions.
4 changes: 4 additions & 0 deletions clang/lib/CodeGen/BackendUtil.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
#include "llvm/SYCLLowerIR/MutatePrintfAddrspace.h"
#include "llvm/SYCLLowerIR/RecordSYCLAspectNames.h"
#include "llvm/SYCLLowerIR/SYCLAddOptLevelAttribute.h"
#include "llvm/SYCLLowerIR/SYCLConditionalCallOnDevice.h"
#include "llvm/SYCLLowerIR/SYCLPropagateAspectsUsage.h"
#include "llvm/SYCLLowerIR/SYCLPropagateJointMatrixUsage.h"
#include "llvm/SYCLLowerIR/UtilsSYCLNativeCPU.h"
Expand Down Expand Up @@ -994,6 +995,9 @@ void EmitAssemblyHelper::RunOptimizationPipeline(
MPM.addPass(ESIMDVerifierPass(LangOpts.SYCLESIMDForceStatelessMem));
if (Level == OptimizationLevel::O0)
MPM.addPass(ESIMDRemoveOptnoneNoinlinePass());
// SYCLConditionalCallOnDevicePass should be run before
// SYCLPropagateAspectsUsagePass
MPM.addPass(SYCLConditionalCallOnDevicePass(LangOpts.SYCLUniquePrefix));
MPM.addPass(SYCLPropagateAspectsUsagePass(
/*FP64ConvEmu=*/CodeGenOpts.FP64ConvEmu,
/*ExcludeAspects=*/{"fp64"}));
Expand Down
37 changes: 37 additions & 0 deletions llvm/include/llvm/SYCLLowerIR/SYCLConditionalCallOnDevice.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
//===-- SYCLConditionalCallOnDevice.h - SYCLConditionalCallOnDevice Pass --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Pass performs transformations on functions which represent the conditional
// call to application's callable object. The conditional call is based on the
// SYCL device's aspects or architecture passed to the functions.
//
//===----------------------------------------------------------------------===//
//
#ifndef LLVM_SYCL_CONDITIONAL_CALL_ON_DEVICE_H
#define LLVM_SYCL_CONDITIONAL_CALL_ON_DEVICE_H

#include "llvm/IR/PassManager.h"

#include <string>

namespace llvm {

class SYCLConditionalCallOnDevicePass
: public PassInfoMixin<SYCLConditionalCallOnDevicePass> {
public:
SYCLConditionalCallOnDevicePass(std::string SYCLUniquePrefix = "")
: UniquePrefix(SYCLUniquePrefix) {}
PreservedAnalyses run(Module &M, ModuleAnalysisManager &);

private:
std::string UniquePrefix;
};

} // namespace llvm

#endif // LLVM_SYCL_CONDITIONAL_CALL_ON_DEVICE_H
1 change: 1 addition & 0 deletions llvm/lib/Passes/PassBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@
#include "llvm/SYCLLowerIR/MutatePrintfAddrspace.h"
#include "llvm/SYCLLowerIR/RecordSYCLAspectNames.h"
#include "llvm/SYCLLowerIR/SYCLAddOptLevelAttribute.h"
#include "llvm/SYCLLowerIR/SYCLConditionalCallOnDevice.h"
#include "llvm/SYCLLowerIR/SYCLPropagateAspectsUsage.h"
#include "llvm/SYCLLowerIR/SYCLPropagateJointMatrixUsage.h"
#include "llvm/Support/CommandLine.h"
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/Passes/PassRegistry.def
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ MODULE_PASS("sycllowerwglocalmemory", SYCLLowerWGLocalMemoryPass())
MODULE_PASS("lower-esimd-kernel-attrs", SYCLFixupESIMDKernelWrapperMDPass())
MODULE_PASS("esimd-remove-host-code", ESIMDRemoveHostCodePass());
MODULE_PASS("esimd-remove-optnone-noinline", ESIMDRemoveOptnoneNoinlinePass());
MODULE_PASS("sycl-conditional-call-on-device", SYCLConditionalCallOnDevicePass())
MODULE_PASS("sycl-propagate-aspects-usage", SYCLPropagateAspectsUsagePass())
MODULE_PASS("sycl-propagate-joint-matrix-usage", SYCLPropagateJointMatrixUsagePass())
MODULE_PASS("sycl-add-opt-level-attribute", SYCLAddOptLevelAttributePass())
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/SYCLLowerIR/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ add_llvm_component_library(LLVMSYCLLowerIR
MutatePrintfAddrspace.cpp
SpecConstants.cpp
SYCLAddOptLevelAttribute.cpp
SYCLConditionalCallOnDevice.cpp
SYCLDeviceLibReqMask.cpp
SYCLDeviceRequirements.cpp
SYCLKernelParamOptInfo.cpp
Expand Down
150 changes: 150 additions & 0 deletions llvm/lib/SYCLLowerIR/SYCLConditionalCallOnDevice.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
//===-- SYCLConditionalCallOnDevice.cpp - SYCLConditionalCallOnDevice Pass
//--===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Pass performs transformations on functions which represent the conditional
// call to application's callable object. The conditional call is based on the
// SYCL device's aspects or architecture passed to the functions.
//
//===----------------------------------------------------------------------===//

#include "llvm/SYCLLowerIR/SYCLConditionalCallOnDevice.h"

#include "llvm/IR/Function.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/Support/CommandLine.h"

using namespace llvm;

cl::opt<std::string>
UniquePrefixOpt("sycl-conditional-call-on-device-unique-prefix",
cl::Optional, cl::Hidden,
cl::desc("Set unique prefix for a translation unit, "
"required for funtions with external linkage"),
cl::init(""));

PreservedAnalyses
SYCLConditionalCallOnDevicePass::run(Module &M, ModuleAnalysisManager &) {
// find call_if_on_device_conditionally function
SmallVector<Function *, 4> FCallers;
for (Function &F : M.functions()) {
if (F.isDeclaration())
continue;

if (CallingConv::SPIR_KERNEL == F.getCallingConv())
continue;

if (F.hasFnAttribute("sycl-call-if-on-device-conditionally"))
FCallers.push_back(&F);
}

// A vector instead of DenseMap to make LIT tests predictable
SmallVector<std::pair<Function *, Function *>, 8> FCallersToFActions;
for (Function *FCaller : FCallers) {
// Find call to @CallableXXX in call_if_on_device_conditionally function
// (FAction). FAction should be a literal (i.e. not a pointer). The
// structure of the header file ensures that there is exactly one such
// instruction.
bool CallFound = false;
for (Instruction &I : instructions(FCaller)) {
if (auto *CI = dyn_cast<CallInst>(&I);
CI && (Intrinsic::IndependentIntrinsics::not_intrinsic ==
CI->getIntrinsicID())) {
assert(
!CallFound &&
"The call_if_on_device_conditionally function must have only one "
"call instruction (w/o taking into account any calls to various "
"intrinsics). More than one found.");
FCallersToFActions.push_back(
std::make_pair(FCaller, CI->getCalledFunction()));
CallFound = true;
}
}
assert(CallFound &&
"The call_if_on_device_conditionally function must have a "
"call instruction (w/o taking into account any calls to various "
"intrinsics). Call not found.");
}

int FCallerIndex = 1;
for (const auto &FCallerToFAction : FCallersToFActions) {
Function *FCaller = FCallerToFAction.first;
Function *FAction = FCallerToFAction.second;

// Create a new function type with an additional function pointer argument
SmallVector<Type *, 4> NewParamTypes;
Type *FActionType = FAction->getType();
NewParamTypes.push_back(
PointerType::getUnqual(FActionType)); // Add function pointer to FAction
FunctionType *OldFCallerType = FCaller->getFunctionType();
for (Type *Ty : OldFCallerType->params())
NewParamTypes.push_back(Ty);

auto *NewFCallerType =
FunctionType::get(OldFCallerType->getReturnType(), NewParamTypes,
OldFCallerType->isVarArg());

// Create a new function with the updated type and rename it to
// call_if_on_device_conditionally_GUID_N
if (!UniquePrefixOpt.empty())
UniquePrefix = UniquePrefixOpt;
// Also change to external linkage
auto *NewFCaller =
Function::Create(NewFCallerType, Function::ExternalLinkage,
Twine(FCaller->getName()) + "_" + UniquePrefix + "_" +
Twine(FCallerIndex),
&M);

NewFCaller->setCallingConv(FCaller->getCallingConv());

DenseMap<CallInst *, CallInst *> OldCallsToNewCalls;

// Replace all calls to the old function with the new one
for (auto &U : FCaller->uses()) {
auto *Call = dyn_cast<CallInst>(U.getUser());

if (!Call)
continue;

SmallVector<Value *, 4> Args;
// Add the function pointer as the first argument
Args.push_back(FAction);
for (unsigned I = 0; I < Call->arg_size(); ++I)
Args.push_back(Call->getArgOperand(I));

// Create the new call instruction
auto *NewCall =
CallInst::Create(NewFCaller, Args, /* NameStr = */ "", Call);
NewCall->setCallingConv(Call->getCallingConv());
NewCall->setDebugLoc(Call->getDebugLoc());

OldCallsToNewCalls[Call] = NewCall;
}

for (const auto &OldCallToNewCall : OldCallsToNewCalls) {
auto *OldCall = OldCallToNewCall.first;
auto *NewCall = OldCallToNewCall.second;

// Replace the old call with the new call
OldCall->replaceAllUsesWith(NewCall);
OldCall->eraseFromParent();
}

// Remove the body of the new function
NewFCaller->deleteBody();

// Remove the old function from the module
FCaller->eraseFromParent();

FCallerIndex++;
}

return PreservedAnalyses::none();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
; RUN: opt -passes=sycl-conditional-call-on-device -sycl-conditional-call-on-device-unique-prefix="PREFIX" < %s -S | FileCheck %s

%class.anon = type { ptr addrspace(4) }
%"struct.std::integer_sequence.3" = type { i8 }

define internal spir_func void @call_if_on_device_conditionally_helper(ptr noundef byval(%class.anon) align 8 %fn, ptr noundef byval(%"struct.std::integer_sequence.3") align 1 %0) #2 !srcloc !0 {
entry:
%agg.tmp = alloca %class.anon, align 8
%fn.ascast = addrspacecast ptr %fn to ptr addrspace(4)
call spir_func void @call_if_on_device_conditionally(ptr noundef byval(%class.anon) align 8 %agg.tmp, i32 noundef -2, i32 noundef 251660032) #9
ret void
}

; CHECK-NOT: call spir_func void @call_if_on_device_conditionally(
; CHECK: call spir_func void @call_if_on_device_conditionally_PREFIX_1(ptr @CallableFunc, ptr %agg.tmp, i32 -2, i32 251660032)

define internal spir_func void @call_if_on_device_conditionally(ptr noundef byval(%class.anon) align 8 %fn, i32 noundef %0, i32 noundef %1) #7 !srcloc !1 {
entry:
%fn.ascast = addrspacecast ptr %fn to ptr addrspace(4)
call spir_func void @CallableFunc(ptr addrspace(4) noundef align 8 dereferenceable_or_null(8) %fn.ascast) #9
ret void
}

; CHECK-NOT: define internal spir_func void @call_if_on_device_conditionally(

define internal spir_func void @CallableFunc(ptr addrspace(4) noundef align 8 dereferenceable_or_null(8) %this) #6 align 2 !srcloc !2 {
entry:
ret void
}

; CHECK: declare spir_func void @call_if_on_device_conditionally_PREFIX_1(ptr, ptr, i32, i32)

attributes #2 = { convergent mustprogress norecurse nounwind "frame-pointer"="all" "no-trapping-math"="true" "stack-protector-buffer-size"="8" }
attributes #6 = { convergent inlinehint mustprogress norecurse nounwind "frame-pointer"="all" "no-trapping-math"="true" "stack-protector-buffer-size"="8" }
attributes #7 = { convergent mustprogress norecurse nounwind "frame-pointer"="all" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "sycl-call-if-on-device-conditionally"="true" }
attributes #9 = { convergent nounwind }

!0 = !{i32 74241}
!1 = !{i32 69449}
!2 = !{i32 835}
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
; RUN: opt -passes=sycl-conditional-call-on-device -sycl-conditional-call-on-device-unique-prefix="PREFIX" < %s -S | FileCheck %s --implicit-check-not="{{call|define internal}} spir_func void @call_if_on_device_conditionally{{1|2}}("

%class.anon = type { ptr addrspace(4) }
%"struct.std::integer_sequence.3" = type { i8 }

define internal spir_func void @call_if_on_device_conditionally_helper1(ptr noundef byval(%class.anon) align 8 %fn, ptr noundef byval(%"struct.std::integer_sequence.3") align 1 %0) #2 !srcloc !0 {
entry:
%agg.tmp = alloca %class.anon, align 8
%fn.ascast = addrspacecast ptr %fn to ptr addrspace(4)
call spir_func void @call_if_on_device_conditionally1(ptr noundef byval(%class.anon) align 8 %agg.tmp, i32 noundef -2, i32 noundef 251660032) #9
ret void
}

; CHECK: call spir_func void @call_if_on_device_conditionally1_PREFIX_1(ptr @CallableFunc, ptr %agg.tmp, i32 -2, i32 251660032)

define internal spir_func void @call_if_on_device_conditionally_helper2(ptr noundef byval(%class.anon) align 8 %fn, ptr noundef byval(%"struct.std::integer_sequence.3") align 1 %0) #2 !srcloc !0 {
entry:
%agg.tmp = alloca %class.anon, align 8
%fn.ascast = addrspacecast ptr %fn to ptr addrspace(4)
call spir_func void @call_if_on_device_conditionally2(ptr noundef byval(%class.anon) align 8 %agg.tmp, i32 noundef -2, i32 noundef 251660032) #9
ret void
}

; CHECK: call spir_func void @call_if_on_device_conditionally2_PREFIX_2(ptr @CallableFunc, ptr %agg.tmp, i32 -2, i32 251660032)

define internal spir_func void @call_if_on_device_conditionally1(ptr noundef byval(%class.anon) align 8 %fn, i32 noundef %0, i32 noundef %1) #7 !srcloc !1 {
entry:
%fn.ascast = addrspacecast ptr %fn to ptr addrspace(4)
call spir_func void @CallableFunc(ptr addrspace(4) noundef align 8 dereferenceable_or_null(8) %fn.ascast) #9
ret void
}

define internal spir_func void @call_if_on_device_conditionally2(ptr noundef byval(%class.anon) align 8 %fn, i32 noundef %0, i32 noundef %1) #7 !srcloc !1 {
entry:
%fn.ascast = addrspacecast ptr %fn to ptr addrspace(4)
call spir_func void @CallableFunc(ptr addrspace(4) noundef align 8 dereferenceable_or_null(8) %fn.ascast) #9
ret void
}

define internal spir_func void @CallableFunc(ptr addrspace(4) noundef align 8 dereferenceable_or_null(8) %this) #6 align 2 !srcloc !2 {
entry:
ret void
}

; CHECK: declare spir_func void @call_if_on_device_conditionally1_PREFIX_1(ptr, ptr, i32, i32)
; CHECK: declare spir_func void @call_if_on_device_conditionally2_PREFIX_2(ptr, ptr, i32, i32)

attributes #2 = { convergent mustprogress norecurse nounwind "frame-pointer"="all" "no-trapping-math"="true" "stack-protector-buffer-size"="8" }
attributes #6 = { convergent inlinehint mustprogress norecurse nounwind "frame-pointer"="all" "no-trapping-math"="true" "stack-protector-buffer-size"="8" }
attributes #7 = { convergent mustprogress norecurse nounwind "frame-pointer"="all" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "sycl-call-if-on-device-conditionally"="true" }
attributes #9 = { convergent nounwind }

!0 = !{i32 74241}
!1 = !{i32 69449}
!2 = !{i32 835}

0 comments on commit 19e471f

Please sign in to comment.