Skip to content

Commit

Permalink
Add CUDAKernelCallExpr Clone function
Browse files Browse the repository at this point in the history
  • Loading branch information
kchristin22 committed Nov 6, 2024
1 parent c874ca5 commit dc74261
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 24 deletions.
21 changes: 20 additions & 1 deletion include/clad/Differentiator/Compatibility.h
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ static inline IfStmt* IfStmt_Create(const ASTContext &Ctx,
#endif
}

// Compatibility helper function for creation CallExpr.
// Compatibility helper function for creation CallExpr and CUDAKernelCallExpr.
// Clang 12 and above use one extra param.

#if CLANG_VERSION_MAJOR < 12
Expand All @@ -188,13 +188,32 @@ static inline CallExpr* CallExpr_Create(const ASTContext &Ctx, Expr *Fn, ArrayRe
{
return CallExpr::Create(Ctx, Fn, Args, Ty, VK, RParenLoc, MinNumArgs, UsesADL);
}

static inline CUDAKernelCallExpr*
CUDAKernelCallExpr_Create(const ASTContext& Ctx, Expr* Fn, CallExpr* Config,
ArrayRef<Expr*> Args, QualType Ty, ExprValueKind VK,
SourceLocation RParenLoc, unsigned MinNumArgs = 0,
CallExpr::ADLCallKind UsesADL = CallExpr::NotADL) {
return CUDAKernelCallExpr::Create(Ctx, Fn, Config, Args, Ty, VK, RParenLoc,
MinNumArgs);
}
#elif CLANG_VERSION_MAJOR >= 12
static inline CallExpr* CallExpr_Create(const ASTContext &Ctx, Expr *Fn, ArrayRef< Expr *> Args,
QualType Ty, ExprValueKind VK, SourceLocation RParenLoc, FPOptionsOverride FPFeatures,
unsigned MinNumArgs = 0, CallExpr::ADLCallKind UsesADL = CallExpr::NotADL)
{
return CallExpr::Create(Ctx, Fn, Args, Ty, VK, RParenLoc, FPFeatures, MinNumArgs, UsesADL);
}

static inline CUDAKernelCallExpr*
CUDAKernelCallExpr_Create(const ASTContext& Ctx, Expr* Fn, CallExpr* Config,
ArrayRef<Expr*> Args, QualType Ty, ExprValueKind VK,
SourceLocation RParenLoc,
FPOptionsOverride FPFeatures, unsigned MinNumArgs = 0,
CallExpr::ADLCallKind UsesADL = CallExpr::NotADL) {
return CUDAKernelCallExpr::Create(Ctx, Fn, Config, Args, Ty, VK, RParenLoc,
FPFeatures, MinNumArgs);
}
#endif

// Clang 12 and above use one extra param.
Expand Down
1 change: 1 addition & 0 deletions include/clad/Differentiator/StmtClone.h
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ namespace utils {
DECLARE_CLONE_FN(ExtVectorElementExpr)
DECLARE_CLONE_FN(UnaryExprOrTypeTraitExpr)
DECLARE_CLONE_FN(CallExpr)
DECLARE_CLONE_FN(CUDAKernelCallExpr)
DECLARE_CLONE_FN(ShuffleVectorExpr)
DECLARE_CLONE_FN(ExprWithCleanups)
DECLARE_CLONE_FN(CXXOperatorCallExpr)
Expand Down
30 changes: 7 additions & 23 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1930,15 +1930,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
diag(DiagnosticsEngine::Error, CE->getEndLoc(),
"Failed to create cudaMemcpy call; cudaMemcpyDeviceToHost not "
"found. Creating kernel pullback aborted.");
for (std::size_t a = 0; a < CE->getNumArgs(); ++a)
CallArgs.push_back(
Clone(CE->getArg(a))); // create a non-const copy
Expr* call =
m_Sema
.ActOnCallExpr(getCurrentScope(), Clone(CE->getCallee()),
Loc, CallArgs, Loc, CUDAExecConfig)
.get();
return StmtDiff(call);
return Clone(CE);
}
CXXScopeSpec SS;
Expr* deviceToHostExpr =
Expand All @@ -1947,20 +1939,12 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
/*ADL=*/false)
.get();
if (!deviceToHostExpr) {
diag(
DiagnosticsEngine::Error, CE->getEndLoc(),
"Failed to create cudaMemcpy call; Failed to create expression "
"for cudaMemcpyDeviceToHost. Creating kernel pullback "
"aborted.");
for (std::size_t a = 0; a < CE->getNumArgs(); ++a)
CallArgs.push_back(
Clone(CE->getArg(a))); // create a non-const copy
Expr* call =
m_Sema
.ActOnCallExpr(getCurrentScope(), Clone(CE->getCallee()),
Loc, CallArgs, Loc, CUDAExecConfig)
.get();
return StmtDiff(call);
diag(DiagnosticsEngine::Error, CE->getEndLoc(),
"Failed to create cudaMemcpy call; Failed to create "
"expression "
"for cudaMemcpyDeviceToHost. Creating kernel pullback "
"aborted.");
return Clone(CE);
}

// Add calls to cudaMalloc, cudaMemset, cudaMemcpy, and cudaFree
Expand Down
15 changes: 15 additions & 0 deletions lib/Differentiator/StmtClone.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,21 @@ Stmt* StmtClone::VisitCallExpr(CallExpr* Node) {
return result;
}

Stmt* StmtClone::VisitCUDAKernelCallExpr(CUDAKernelCallExpr* Node) {
CUDAKernelCallExpr* result = clad_compat::CUDAKernelCallExpr_Create(
Ctx, Clone(Node->getCallee()), Clone(Node->getConfig()),
llvm::ArrayRef<Expr*>(), CloneType(Node->getType()), Node->getValueKind(),
Node->getRParenLoc() CLAD_COMPAT_CLANG8_CallExpr_ExtraParams);
result->setNumArgsUnsafe(Node->getNumArgs());
for (unsigned i = 0, e = Node->getNumArgs(); i < e; ++i)
result->setArg(i, Clone(Node->getArg(i)));

// Copy Value and Type dependent
clad_compat::ExprSetDeps(result, Node);

return result;
}

Stmt* StmtClone::VisitUnresolvedLookupExpr(UnresolvedLookupExpr* Node) {
TemplateArgumentListInfo TemplateArgs;
if (Node->hasExplicitTemplateArgs())
Expand Down

0 comments on commit dc74261

Please sign in to comment.