diff --git a/enzyme/Enzyme/Clang/EnzymeClang.cpp b/enzyme/Enzyme/Clang/EnzymeClang.cpp index 0072c958b517..9faf6ee21802 100644 --- a/enzyme/Enzyme/Clang/EnzymeClang.cpp +++ b/enzyme/Enzyme/Clang/EnzymeClang.cpp @@ -256,7 +256,6 @@ struct EnzymeFunctionLikeAttrInfo : public ParsedAttrInfo { AttrHandling handleDeclAttribute(Sema &S, Decl *D, const ParsedAttr &Attr) const override { - auto FD = cast(D); if (Attr.getNumArgs() != 1) { unsigned ID = S.getDiagnostics().getCustomDiagID( DiagnosticsEngine::Error, @@ -274,92 +273,9 @@ struct EnzymeFunctionLikeAttrInfo : public ParsedAttrInfo { return AttributeNotApplied; } - // if (FD->isLateTemplateParsed()) return; - auto &AST = S.getASTContext(); - DeclContext *declCtx = FD->getDeclContext(); - for (auto tmpCtx = declCtx; tmpCtx; tmpCtx = tmpCtx->getParent()) { - if (tmpCtx->isRecord()) { - declCtx = tmpCtx->getParent(); - } - } - auto loc = FD->getLocation(); - RecordDecl *RD; - if (S.getLangOpts().CPlusPlus) - RD = CXXRecordDecl::Create(AST, StructKind, declCtx, loc, loc, - nullptr); // rId); - else - RD = RecordDecl::Create(AST, StructKind, declCtx, loc, loc, - nullptr); // rId); - RD->setAnonymousStructOrUnion(true); - RD->setImplicit(); - RD->startDefinition(); - auto Tinfo = nullptr; - auto Tinfo0 = nullptr; - auto FT = AST.getPointerType(FD->getType()); - auto CharTy = AST.getIntTypeForBitwidth(8, false); - auto FD0 = FieldDecl::Create(AST, RD, loc, loc, /*Ud*/ nullptr, FT, Tinfo0, - /*expr*/ nullptr, /*mutable*/ true, - /*inclassinit*/ ICIS_NoInit); - FD0->setAccess(AS_public); - RD->addDecl(FD0); - auto FD1 = FieldDecl::Create( - AST, RD, loc, loc, /*Ud*/ nullptr, AST.getPointerType(CharTy), Tinfo0, - /*expr*/ nullptr, /*mutable*/ true, /*inclassinit*/ ICIS_NoInit); - FD1->setAccess(AS_public); - RD->addDecl(FD1); - RD->completeDefinition(); - assert(RD->getDefinition()); - auto &Id = AST.Idents.get("__enzyme_function_like_autoreg_" + - FD->getNameAsString()); - auto T = AST.getRecordType(RD); - auto V = VarDecl::Create(AST, declCtx, loc, loc, &Id, T, Tinfo, SC_None); - V->setStorageClass(SC_PrivateExtern); - V->addAttr(clang::UsedAttr::CreateImplicit(AST)); - TemplateArgumentListInfo *TemplateArgs = nullptr; - auto DR = DeclRefExpr::Create(AST, NestedNameSpecifierLoc(), loc, FD, false, - loc, FD->getType(), ExprValueKind::VK_LValue, - FD, TemplateArgs); -#if LLVM_VERSION_MAJOR >= 13 - auto rval = ExprValueKind::VK_PRValue; -#else - auto rval = ExprValueKind::VK_RValue; -#endif - StringRef cstr = Literal->getString(); - Expr *exprs[2] = { -#if LLVM_VERSION_MAJOR >= 12 - ImplicitCastExpr::Create(AST, FT, CastKind::CK_FunctionToPointerDecay, DR, - nullptr, rval, FPOptionsOverride()), - ImplicitCastExpr::Create( - AST, AST.getPointerType(CharTy), CastKind::CK_ArrayToPointerDecay, - StringLiteral::Create( - AST, cstr, stringkind, - /*Pascal*/ false, - AST.getStringLiteralArrayType(CharTy, cstr.size()), loc), - nullptr, rval, FPOptionsOverride()) -#else - ImplicitCastExpr::Create(AST, FT, CastKind::CK_FunctionToPointerDecay, DR, - nullptr, rval), - ImplicitCastExpr::Create( - AST, AST.getPointerType(CharTy), CastKind::CK_ArrayToPointerDecay, - StringLiteral::Create( - AST, cstr, stringkind, - /*Pascal*/ false, - AST.getStringLiteralArrayType(CharTy, cstr.size()), loc), - nullptr, rval) -#endif - }; - auto IL = new (AST) InitListExpr(AST, loc, exprs, loc); - V->setInit(IL); - IL->setType(T); - if (IL->isValueDependent()) { - unsigned ID = S.getDiagnostics().getCustomDiagID( - DiagnosticsEngine::Error, "use of attribute 'enzyme_function_like' " - "in a templated context not yet supported"); - S.Diag(Attr.getLoc(), ID); - return AttributeNotApplied; - } - S.MarkVariableReferenced(loc, V); - S.getASTConsumer().HandleTopLevelDecl(DeclGroupRef(V)); + D->addAttr(AnnotateAttr::Create( + S.Context, ("enzyme_function_like=" + Literal->getString()).str(), + nullptr, 0, Attr.getRange())); return AttributeApplied; } }; @@ -409,71 +325,8 @@ struct EnzymeInactiveAttrInfo : public ParsedAttrInfo { return AttributeNotApplied; } - auto &AST = S.getASTContext(); - DeclContext *declCtx = D->getDeclContext(); - for (auto tmpCtx = declCtx; tmpCtx; tmpCtx = tmpCtx->getParent()) { - if (tmpCtx->isRecord()) { - declCtx = tmpCtx->getParent(); - } - } - auto loc = D->getLocation(); - RecordDecl *RD; - if (S.getLangOpts().CPlusPlus) - RD = CXXRecordDecl::Create(AST, StructKind, declCtx, loc, loc, - nullptr); // rId); - else - RD = RecordDecl::Create(AST, StructKind, declCtx, loc, loc, - nullptr); // rId); - RD->setAnonymousStructOrUnion(true); - RD->setImplicit(); - RD->startDefinition(); - auto T = isa(D) ? cast(D)->getType() - : cast(D)->getType(); - auto Name = isa(D) ? cast(D)->getNameAsString() - : cast(D)->getNameAsString(); - auto FT = AST.getPointerType(T); - auto subname = isa(D) ? "inactivefn" : "inactive_global"; - auto &Id = AST.Idents.get( - (StringRef("__enzyme_") + subname + "_autoreg_" + Name).str()); - auto V = VarDecl::Create(AST, declCtx, loc, loc, &Id, FT, nullptr, SC_None); - V->setStorageClass(SC_PrivateExtern); - V->addAttr(clang::UsedAttr::CreateImplicit(AST)); - TemplateArgumentListInfo *TemplateArgs = nullptr; - auto DR = DeclRefExpr::Create( - AST, NestedNameSpecifierLoc(), loc, cast(D), false, loc, T, - ExprValueKind::VK_LValue, cast(D), TemplateArgs); -#if LLVM_VERSION_MAJOR >= 13 - auto rval = ExprValueKind::VK_PRValue; -#else - auto rval = ExprValueKind::VK_RValue; -#endif - Expr *expr = nullptr; - if (isa(D)) { -#if LLVM_VERSION_MAJOR >= 12 - expr = - ImplicitCastExpr::Create(AST, FT, CastKind::CK_FunctionToPointerDecay, - DR, nullptr, rval, FPOptionsOverride()); -#else - expr = ImplicitCastExpr::Create( - AST, FT, CastKind::CK_FunctionToPointerDecay, DR, nullptr, rval); -#endif - } else { - expr = - UnaryOperator::Create(AST, DR, UnaryOperatorKind::UO_AddrOf, FT, rval, - clang::ExprObjectKind ::OK_Ordinary, loc, - /*canoverflow*/ false, FPOptionsOverride()); - } - - if (expr->isValueDependent()) { - unsigned ID = S.getDiagnostics().getCustomDiagID( - DiagnosticsEngine::Error, "use of attribute 'enzyme_inactive' " - "in a templated context not yet supported"); - S.Diag(Attr.getLoc(), ID); - return AttributeNotApplied; - } - V->setInit(expr); - S.MarkVariableReferenced(loc, V); - S.getASTConsumer().HandleTopLevelDecl(DeclGroupRef(V)); + D->addAttr(AnnotateAttr::Create(S.Context, "enzyme_inactive", nullptr, 0, + Attr.getRange())); return AttributeApplied; } }; @@ -522,71 +375,8 @@ struct EnzymeNoFreeAttrInfo : public ParsedAttrInfo { S.Diag(Attr.getLoc(), ID); return AttributeNotApplied; } - - auto &AST = S.getASTContext(); - DeclContext *declCtx = D->getDeclContext(); - for (auto tmpCtx = declCtx; tmpCtx; tmpCtx = tmpCtx->getParent()) { - if (tmpCtx->isRecord()) { - declCtx = tmpCtx->getParent(); - } - } - auto loc = D->getLocation(); - RecordDecl *RD; - if (S.getLangOpts().CPlusPlus) - RD = CXXRecordDecl::Create(AST, StructKind, declCtx, loc, loc, - nullptr); // rId); - else - RD = RecordDecl::Create(AST, StructKind, declCtx, loc, loc, - nullptr); // rId); - RD->setAnonymousStructOrUnion(true); - RD->setImplicit(); - RD->startDefinition(); - auto T = isa(D) ? cast(D)->getType() - : cast(D)->getType(); - auto Name = isa(D) ? cast(D)->getNameAsString() - : cast(D)->getNameAsString(); - auto FT = AST.getPointerType(T); - auto &Id = AST.Idents.get( - (StringRef("__enzyme_nofree") + "_autoreg_" + Name).str()); - auto V = VarDecl::Create(AST, declCtx, loc, loc, &Id, FT, nullptr, SC_None); - V->setStorageClass(SC_PrivateExtern); - V->addAttr(clang::UsedAttr::CreateImplicit(AST)); - TemplateArgumentListInfo *TemplateArgs = nullptr; - auto DR = DeclRefExpr::Create( - AST, NestedNameSpecifierLoc(), loc, cast(D), false, loc, T, - ExprValueKind::VK_LValue, cast(D), TemplateArgs); -#if LLVM_VERSION_MAJOR >= 13 - auto rval = ExprValueKind::VK_PRValue; -#else - auto rval = ExprValueKind::VK_RValue; -#endif - Expr *expr = nullptr; - if (isa(D)) { -#if LLVM_VERSION_MAJOR >= 12 - expr = - ImplicitCastExpr::Create(AST, FT, CastKind::CK_FunctionToPointerDecay, - DR, nullptr, rval, FPOptionsOverride()); -#else - expr = ImplicitCastExpr::Create( - AST, FT, CastKind::CK_FunctionToPointerDecay, DR, nullptr, rval); -#endif - } else { - expr = - UnaryOperator::Create(AST, DR, UnaryOperatorKind::UO_AddrOf, FT, rval, - clang::ExprObjectKind ::OK_Ordinary, loc, - /*canoverflow*/ false, FPOptionsOverride()); - } - - if (expr->isValueDependent()) { - unsigned ID = S.getDiagnostics().getCustomDiagID( - DiagnosticsEngine::Error, "use of attribute 'enzyme_nofree' " - "in a templated context not yet supported"); - S.Diag(Attr.getLoc(), ID); - return AttributeNotApplied; - } - V->setInit(expr); - S.MarkVariableReferenced(loc, V); - S.getASTConsumer().HandleTopLevelDecl(DeclGroupRef(V)); + D->addAttr(AnnotateAttr::Create(S.Context, "enzyme_nofree", nullptr, 0, + Attr.getRange())); return AttributeApplied; } }; @@ -631,65 +421,10 @@ struct EnzymeSparseAccumulateAttrInfo : public ParsedAttrInfo { S.Diag(Attr.getLoc(), ID); return AttributeNotApplied; } - - auto &AST = S.getASTContext(); - DeclContext *declCtx = D->getDeclContext(); - for (auto tmpCtx = declCtx; tmpCtx; tmpCtx = tmpCtx->getParent()) { - if (tmpCtx->isRecord()) { - declCtx = tmpCtx->getParent(); - } - } - auto loc = D->getLocation(); - RecordDecl *RD; - if (S.getLangOpts().CPlusPlus) - RD = CXXRecordDecl::Create(AST, StructKind, declCtx, loc, loc, - nullptr); // rId); - else - RD = RecordDecl::Create(AST, StructKind, declCtx, loc, loc, - nullptr); // rId); - RD->setAnonymousStructOrUnion(true); - RD->setImplicit(); - RD->startDefinition(); - auto T = cast(D)->getType(); - auto Name = cast(D)->getNameAsString(); - auto FT = AST.getPointerType(T); - auto &Id = AST.Idents.get( - (StringRef("__enzyme_sparse_accumulate") + "_autoreg_" + Name).str()); - auto V = VarDecl::Create(AST, declCtx, loc, loc, &Id, FT, nullptr, SC_None); - V->setStorageClass(SC_PrivateExtern); - V->addAttr(clang::UsedAttr::CreateImplicit(AST)); - TemplateArgumentListInfo *TemplateArgs = nullptr; - auto DR = DeclRefExpr::Create( - AST, NestedNameSpecifierLoc(), loc, cast(D), false, loc, T, - ExprValueKind::VK_LValue, cast(D), TemplateArgs); -#if LLVM_VERSION_MAJOR >= 13 - auto rval = ExprValueKind::VK_PRValue; -#else - auto rval = ExprValueKind::VK_RValue; -#endif - Expr *expr = nullptr; -#if LLVM_VERSION_MAJOR >= 12 - expr = - ImplicitCastExpr::Create(AST, FT, CastKind::CK_FunctionToPointerDecay, - DR, nullptr, rval, FPOptionsOverride()); -#else - expr = ImplicitCastExpr::Create( - AST, FT, CastKind::CK_FunctionToPointerDecay, DR, nullptr, rval); -#endif - - if (expr->isValueDependent()) { - unsigned ID = S.getDiagnostics().getCustomDiagID( - DiagnosticsEngine::Error, - "use of attribute 'enzyme_sparse_accumulate' " - "in a templated context not yet supported"); - S.Diag(Attr.getLoc(), ID); - return AttributeNotApplied; - } - V->setInit(expr); - S.MarkVariableReferenced(loc, V); - S.getASTConsumer().HandleTopLevelDecl(DeclGroupRef(V)); + D->addAttr(AnnotateAttr::Create(S.Context, "enzyme_sparse_accumulate", + nullptr, 0, Attr.getRange())); return AttributeApplied; - } + }; }; static ParsedAttrInfoRegistry::Add diff --git a/enzyme/Enzyme/PreserveNVVM.cpp b/enzyme/Enzyme/PreserveNVVM.cpp index 84b8b5b9540a..1f13aefcd8fa 100644 --- a/enzyme/Enzyme/PreserveNVVM.cpp +++ b/enzyme/Enzyme/PreserveNVVM.cpp @@ -368,6 +368,106 @@ bool preserveNVVM(bool Begin, Function &F) { "__enzyme_register_derivative"; constexpr static const char splitderivative_handler_name[] = "__enzyme_register_splitderivative"; + + if (Begin) + if (GlobalVariable *GA = + F.getParent()->getGlobalVariable("llvm.global.annotations")) { + if (GA->hasInitializer()) { + auto AOp = GA->getInitializer(); + // all metadata are stored in an array of struct of metadata + if (ConstantArray *CA = dyn_cast(AOp)) { + // so iterate over the operands + SmallVector replacements; + for (Value *CAOp : CA->operands()) { + // get the struct, which holds a pointer to the annotated function + // as first field, and the annotation as second field + if (ConstantStruct *CS = dyn_cast(CAOp)) { + if (CS->getNumOperands() >= 2) { + // the second field is a pointer to a global constant Array that + // holds the string + if (GlobalVariable *GAnn = dyn_cast( + CS->getOperand(1)->getOperand(0))) { + if (ConstantDataArray *A = + dyn_cast(GAnn->getOperand(0))) { + // we have the annotation! Check it's an epona annotation + // and process + StringRef AS = A->getAsCString(); + + Constant *Val = + cast(CS->getOperand(0)->getOperand(0)); + while (auto CE = dyn_cast(Val)) + Val = CE->getOperand(0); + + Function *Func = dyn_cast(Val); + GlobalVariable *Glob = dyn_cast(Val); + + if (AS == "enzyme_inactive" && Func && Func == &F) { + Func->addAttribute(AttributeList::FunctionIndex, + Attribute::get(Func->getContext(), + "enzyme_inactive")); + changed = true; + changed |= preserveLinkage(Begin, F); + replacements.push_back( + Constant::getNullValue(CAOp->getType())); + continue; + } + + if (AS == "enzyme_inactive" && Glob) { + Glob->setMetadata("enzyme_inactive", + MDNode::get(Glob->getContext(), {})); + changed = true; + replacements.push_back( + Constant::getNullValue(CAOp->getType())); + continue; + } + + if (AS == "enzyme_nofree" && Func && Func == &F) { + Func->addAttribute(AttributeList::FunctionIndex, + Attribute::get(Func->getContext(), + Attribute::NoFree)); + changed = true; + changed |= preserveLinkage(Begin, F); + replacements.push_back( + Constant::getNullValue(CAOp->getType())); + continue; + } + + if (startsWith(AS, "enzyme_function_like") && Func && + Func == &F) { + auto val = AS.substr(1 + AS.find('=')); + Func->addAttribute(AttributeList::FunctionIndex, + Attribute::get(Func->getContext(), + "enzyme_math", val)); + changed = true; + changed |= preserveLinkage(Begin, F); + replacements.push_back( + Constant::getNullValue(CAOp->getType())); + continue; + } + + if (AS == "enzyme_sparse_accumulate" && Func && + Func == &F) { + Func->addAttribute( + AttributeList::FunctionIndex, + Attribute::get(Func->getContext(), + "enzyme_sparse_accumulate")); + changed = true; + changed |= preserveLinkage(Begin, F); + replacements.push_back( + Constant::getNullValue(CAOp->getType())); + continue; + } + } + } + } + } + replacements.push_back(cast(CAOp)); + } + GA->setInitializer(ConstantArray::get(CA->getType(), replacements)); + } + } + } + for (GlobalVariable &g : F.getParent()->globals()) { if (g.getName().contains(gradient_handler_name) || g.getName().contains(derivative_handler_name) ||