From ff15cd8eec441dfd6ea563bc4f4eb7a2bd1774db Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 30 Jan 2024 13:58:10 -0500 Subject: [PATCH 001/106] [MLIR][ActivityAnalysis] create activity interface (#1648) * [MLIR][ActivityAnalysis] create activity interface * wip * fixup --- enzyme/BUILD | 32 +++- .../Enzyme/MLIR/Analysis/ActivityAnalysis.cpp | 164 ++++++------------ .../Analysis/DataFlowActivityAnalysis.cpp | 19 ++ .../MLIR/Implementations/ArithDerivatives.td | 25 +-- .../MLIR/Implementations/CMakeLists.txt | 11 ++ enzyme/Enzyme/MLIR/Implementations/Common.td | 30 ++++ .../CoreDialectsAutoDiffImplementations.h | 1 + .../LLVMAutoDiffOpInterfaceImpl.cpp | 18 ++ .../MLIR/Implementations/LLVMDerivatives.td | 17 ++ .../NVVMAutoDiffOpInterfaceImpl.cpp | 34 ++++ .../MLIR/Implementations/NVVMDerivatives.td | 4 + .../SCFAutoDiffOpInterfaceImpl.cpp | 16 ++ .../MLIR/Interfaces/AutoDiffOpInterface.td | 21 +++ enzyme/Enzyme/MLIR/enzymemlir-opt.cpp | 1 + enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp | 19 ++ 15 files changed, 272 insertions(+), 140 deletions(-) create mode 100644 enzyme/Enzyme/MLIR/Implementations/Common.td create mode 100644 enzyme/Enzyme/MLIR/Implementations/LLVMDerivatives.td create mode 100644 enzyme/Enzyme/MLIR/Implementations/NVVMAutoDiffOpInterfaceImpl.cpp create mode 100644 enzyme/Enzyme/MLIR/Implementations/NVVMDerivatives.td diff --git a/enzyme/BUILD b/enzyme/BUILD index c9ae1c3cdb71..8954d5433e0a 100644 --- a/enzyme/BUILD +++ b/enzyme/BUILD @@ -392,7 +392,35 @@ gentbl( )], tblgen = ":enzyme-tblgen", td_file = "Enzyme/MLIR/Implementations/ArithDerivatives.td", - td_srcs = ["Enzyme/MLIR/Implementations/ArithDerivatives.td"], + td_srcs = ["Enzyme/MLIR/Implementations/ArithDerivatives.td", "Enzyme/MLIR/Implementations/Common.td"], + deps = [ + ":enzyme-tblgen", + ], +) + +gentbl( + name = "llvm-derivatives", + tbl_outs = [( + "-gen-mlir-derivatives", + "Enzyme/MLIR/Implementations/LLVMDerivatives.inc", + )], + tblgen = ":enzyme-tblgen", + td_file = "Enzyme/MLIR/Implementations/LLVMDerivatives.td", + td_srcs = ["Enzyme/MLIR/Implementations/LLVMDerivatives.td", "Enzyme/MLIR/Implementations/Common.td"], + deps = [ + ":enzyme-tblgen", + ], +) + +gentbl( + name = "nvvm-derivatives", + tbl_outs = [( + "-gen-mlir-derivatives", + "Enzyme/MLIR/Implementations/NVVMDerivatives.inc", + )], + tblgen = ":enzyme-tblgen", + td_file = "Enzyme/MLIR/Implementations/NVVMDerivatives.td", + td_srcs = ["Enzyme/MLIR/Implementations/NVVMDerivatives.td", "Enzyme/MLIR/Implementations/Common.td"], deps = [ ":enzyme-tblgen", ], @@ -420,6 +448,8 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":arith-derivatives", + ":llvm-derivatives", + ":nvvm-derivatives", ":EnzymeOpsIncGen", ":EnzymePassesIncGen", ":EnzymeTypesIncGen", diff --git a/enzyme/Enzyme/MLIR/Analysis/ActivityAnalysis.cpp b/enzyme/Enzyme/MLIR/Analysis/ActivityAnalysis.cpp index c1a471277915..bb8c8f2ccc9c 100644 --- a/enzyme/Enzyme/MLIR/Analysis/ActivityAnalysis.cpp +++ b/enzyme/Enzyme/MLIR/Analysis/ActivityAnalysis.cpp @@ -3,7 +3,6 @@ #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/LLVMTypes.h" -#include "mlir/Dialect/LLVMIR/NVVMDialect.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/Matchers.h" @@ -19,6 +18,8 @@ #include "llvm/IR/Intrinsics.h" #include "llvm/Support/ModRef.h" +#include "Interfaces/AutoDiffOpInterface.h" + const char *KnownInactiveFunctionsStartingWith[] = { "f90io", "$ss5print", @@ -467,9 +468,14 @@ bool mlir::enzyme::ActivityAnalyzer::isConstantOperation(MTypeResults const &TR, if (isa(I)) return true; + if (auto ifaceOp = dyn_cast(I)) { + if (ifaceOp.isInactive()) { + return true; + } + } + // Branch, unreachable, and previously computed constants are inactive - if (isa(I) /*|| isa(I)*/ || - ConstantOperations.contains(I)) { + if (/*|| isa(I)*/ ConstantOperations.contains(I)) { return true; } @@ -592,12 +598,6 @@ bool mlir::enzyme::ActivityAnalyzer::isConstantOperation(MTypeResults const &TR, // *I // << "\n"; - if (isa(I)) { - InsertConstantOperation(TR, I); - } - // if (auto II = dyn_cast(I)) { // switch (II->getIntrinsicID()) { // case Intrinsic::nvvm_barrier0: @@ -1121,13 +1121,6 @@ bool mlir::enzyme::ActivityAnalyzer::isConstantValue(MTypeResults const &TR, if (Val.getType().isa()) return true; - // All function pointers are considered active in case an augmented primal - // or reverse is needed - if (Val.getDefiningOp() && - isa(Val.getDefiningOp())) { - return false; - } - /// If we've already shown this value to be inactive if (ConstantValues.find(Val) != ConstantValues.end()) { return true; @@ -1142,44 +1135,11 @@ bool mlir::enzyme::ActivityAnalyzer::isConstantValue(MTypeResults const &TR, if (matchPattern(Val, m_Constant())) return true; - // if (auto CD = dyn_cast(Val)) { - // // inductively assume inactive - // ConstantValues.insert(CD); - // for (size_t i = 0, len = CD->getNumElements(); i < len; i++) { - // if (!isConstantValue(TR, CD->getElementAsConstant(i))) { - // ConstantValues.erase(CD); - // ActiveValues.insert(CD); - // return false; - // } - // } - // return true; - // } - // if (auto CD = dyn_cast(Val)) { - // // inductively assume inactive - // ConstantValues.insert(CD); - // for (size_t i = 0, len = CD->getNumOperands(); i < len; i++) { - // if (!isConstantValue(TR, CD->getOperand(i))) { - // ConstantValues.erase(CD); - // ActiveValues.insert(CD); - // return false; - // } - // } - // return true; - // } - if (Operation *definingOp = Val.getDefiningOp()) { - // Undef and non-global constants are inactive. - if (isa(definingOp)) { - return true; - } - - // Ops derived from intrinsics. - // NOTE: this was written with the assumption that Value is-a Operation, - // which is not the case in MLIR. - if (isa(definingOp)) { - return true; + if (auto ifaceOp = dyn_cast(definingOp)) { + if (ifaceOp.isInactive()) { + return true; + } } } @@ -1494,6 +1454,17 @@ bool mlir::enzyme::ActivityAnalyzer::isConstantValue(MTypeResults const &TR, } } + if (auto op = TmpOrig.getDefiningOp()) + if (auto ifaceOp = dyn_cast(op)) { + if (ifaceOp.isInactive()) { + InsertConstantValue(TR, Val); + if (TmpOrig != Val) { + InsertConstantValue(TR, TmpOrig); + } + return true; + } + } + UpHypothesis = std::shared_ptr( new mlir::enzyme::ActivityAnalyzer(*this, UP)); UpHypothesis->ConstantValues.insert(Val); @@ -1828,16 +1799,12 @@ bool mlir::enzyme::ActivityAnalyzer::isConstantValue(MTypeResults const &TR, if (notForAnalysis.count(op->getBlock())) return false; - if (auto iasm = dyn_cast(op)) { - if (iasm.getAsmString().contains("exit") || - iasm.getAsmString().contains("cpuid")) - return false; - } - if (isa(op)) { - return true; - } + if (auto op = TmpOrig.getDefiningOp()) + if (auto ifaceOp = dyn_cast(op)) { + if (ifaceOp.isInactive()) { + return false; + } + } // If this is a malloc or free, this doesn't impact the activity if (auto CI = dyn_cast(op)) { @@ -2537,21 +2504,16 @@ bool mlir::enzyme::ActivityAnalyzer::isOperationInactiveFromOrigin( return false; } - // if (EnzymePrintActivity) - // llvm::errs() << " < UPSEARCH" << (int)directions << ">" << *inst << - // "\n"; - - // cpuid is explicitly an inactive instruction - if (auto iasm = dyn_cast(op)) { - if (iasm.getAsmString().contains("cpuid")) { - // if (EnzymePrintActivity) - // llvm::errs() << " constant instruction from known cpuid instruction - // " - // << *inst << "\n"; + if (auto ifaceOp = dyn_cast(op)) { + if (ifaceOp.isInactive()) { return true; } } + // if (EnzymePrintActivity) + // llvm::errs() << " < UPSEARCH" << (int)directions << ">" << *inst << + // "\n"; + if (auto store = dyn_cast(op)) { if (isConstantValue(TR, store.getValue()) || isConstantValue(TR, store.getAddr())) { @@ -2643,15 +2605,6 @@ bool mlir::enzyme::ActivityAnalyzer::isOperationInactiveFromOrigin( return true; } } - // Intrinsics known always to be inactive - if (isa(op)) { - // if (EnzymePrintActivity) - // llvm::errs() << "constant(" << (int)directions << ") up-intrinsic " - // << *inst << "\n"; - return true; - } if (auto gep = dyn_cast(op)) { // A gep's only args that could make it active is the pointer operand @@ -2731,13 +2684,7 @@ bool mlir::enzyme::ActivityAnalyzer::isOperationInactiveFromOrigin( return false; } - if (isa(op)) { - // if (EnzymePrintActivity) - // llvm::errs() << "constant(" << (int)directions << ") up-fpcst:" << - // *inst - // << "\n"; - return true; - } else { + { bool seenuse = false; //! TODO does not consider reading from global memory that is active and not //! an argument @@ -2871,6 +2818,13 @@ bool mlir::enzyme::ActivityAnalyzer::isValueInactiveFromUsers( // } } + if (UA != UseActivity::AllStores) { + if (auto ifaceOp = dyn_cast(a)) { + if (ifaceOp.isArgInactive(parent)) + return true; + } + } + // if (EnzymePrintActivity) // llvm::errs() << " considering use of " << *val << " - " << *a // << "\n"; @@ -3078,14 +3032,6 @@ bool mlir::enzyme::ActivityAnalyzer::isValueInactiveFromUsers( continue; } - if (isa( - a)) { - // if (EnzymePrintActivity) - // llvm::errs() << "found constant(" << (int)directions - // << ") si-fp use:" << *val << " user " << *a << "\n"; - continue; - } - // // TODO: this should not happen in valid MLIR... // @@ -3367,10 +3313,6 @@ bool mlir::enzyme::ActivityAnalyzer::isValueInactiveFromUsers( LLVM::SExtOp, LLVM::ZExtOp, LLVM::TruncOp, - LLVM::SIToFPOp, - LLVM::UIToFPOp, - LLVM::FPToSIOp, - LLVM::FPToUIOp, LLVM::FPExtOp, LLVM::FPTruncOp // clang-format on @@ -3451,6 +3393,11 @@ bool mlir::enzyme::ActivityAnalyzer::isValueActivelyStoredOrReturned( continue; } + if (auto ifaceOp = dyn_cast(a)) { + if (ifaceOp.isArgInactive(val)) + return true; + } + if (isa(a)) { if (ActiveReturns == DIFFE_TYPE::CONSTANT) continue; @@ -3509,19 +3456,6 @@ bool mlir::enzyme::ActivityAnalyzer::isValueActivelyStoredOrReturned( // TODO: in MLIR, users are always operations // if (Operation *inst = a) { - auto mayWriteToMemory = [](Operation *op) { - auto iface = dyn_cast(op); - if (!iface) - return true; - - SmallVector effects; - iface.getEffects(effects); - return llvm::any_of( - effects, [](const MemoryEffects::EffectInstance &effect) { - return isa(effect.getEffect()); - }); - }; - if (!mayWriteToMemory(inst) /*|| (isa(inst) && AA.onlyReadsMemory(cast(inst)))*/) { // // if not written to memory and returning a known constant, this diff --git a/enzyme/Enzyme/MLIR/Analysis/DataFlowActivityAnalysis.cpp b/enzyme/Enzyme/MLIR/Analysis/DataFlowActivityAnalysis.cpp index 608bf75f40a5..4f922d6eb5d5 100644 --- a/enzyme/Enzyme/MLIR/Analysis/DataFlowActivityAnalysis.cpp +++ b/enzyme/Enzyme/MLIR/Analysis/DataFlowActivityAnalysis.cpp @@ -43,6 +43,8 @@ #include "mlir/Analysis/AliasAnalysis/LocalAliasAnalysis.h" +#include "Interfaces/AutoDiffOpInterface.h" + using namespace mlir; using namespace mlir::dataflow; using enzyme::AliasClassLattice; @@ -508,6 +510,15 @@ class DenseForwardActivityAnalysis ForwardMemoryActivity *after) override { join(after, before); ChangeResult result = ChangeResult::NoChange; + + // TODO If we know this is inactive by definition + // if (auto ifaceOp = dyn_cast(op)) { + // if (ifaceOp.isInactive()) { + // propagateIfChanged(after, result); + // return; + // } + // } + auto memory = dyn_cast(op); // If we can't reason about the memory effects, then conservatively assume // we can't deduce anything about activity via side-effects. @@ -657,6 +668,14 @@ class DenseBackwardActivityAnalysis void visitOperation(Operation *op, const BackwardMemoryActivity &after, BackwardMemoryActivity *before) override { + + // TODO: If we know this is inactive by definition + // if (auto ifaceOp = dyn_cast(op)) { + // if (ifaceOp.isInactive()) { + // return; + // } + // } + // Initialize the return activity of arguments. if (op->hasTrait() && op->getParentOp() == parentOp) { for (const auto &[arg, argActivity] : diff --git a/enzyme/Enzyme/MLIR/Implementations/ArithDerivatives.td b/enzyme/Enzyme/MLIR/Implementations/ArithDerivatives.td index bb713ef61799..fb7f113f16fe 100644 --- a/enzyme/Enzyme/MLIR/Implementations/ArithDerivatives.td +++ b/enzyme/Enzyme/MLIR/Implementations/ArithDerivatives.td @@ -1,25 +1,5 @@ -class MLIRDerivative resultOps> { - string dialect = dialect_; - string opName = opName_; - dag PatternToMatch = patternToMatch; - list ArgDerivatives = resultOps; -} +include "Common.td" -class Operation { - bit usesPrimal = usesPrimal_; - bit usesShadow = usesShadow_; - bit usesCustom = usesCustom_; -} - -class DiffeRetIndex indices_> { - list indices = indices_; -} -def DiffeRet : DiffeRetIndex<[-1]>; - -class Inst : Operation { - string name = mnemonic; - string dialect = dialect_; -} class ArithInst : Inst; def AddF : ArithInst<"arith::AddFOp">; @@ -32,9 +12,6 @@ def RemF : ArithInst<"arith::RemFOp">; def CheckedMulF : ArithInst<"arith::MulFOp">; def CheckedDivF : ArithInst<"arith::DivFOp">; -def Op { -} - def : MLIRDerivative<"arith", "AddFOp", (Op $x, $y), [ (DiffeRet), diff --git a/enzyme/Enzyme/MLIR/Implementations/CMakeLists.txt b/enzyme/Enzyme/MLIR/Implementations/CMakeLists.txt index a41ee2133c68..521ba76c22bc 100644 --- a/enzyme/Enzyme/MLIR/Implementations/CMakeLists.txt +++ b/enzyme/Enzyme/MLIR/Implementations/CMakeLists.txt @@ -3,9 +3,18 @@ set(LLVM_TARGET_DEFINITIONS ArithDerivatives.td) enzyme_tablegen(ArithDerivatives.inc -gen-mlir-derivatives) add_public_tablegen_target(ArithDerivativesIncGen) +set(LLVM_TARGET_DEFINITIONS LLVMDerivatives.td) +enzyme_tablegen(LLVMDerivatives.inc -gen-mlir-derivatives) +add_public_tablegen_target(LLVMDerivativesIncGen) + +set(LLVM_TARGET_DEFINITIONS NVVMDerivatives.td) +enzyme_tablegen(NVVMDerivatives.inc -gen-mlir-derivatives) +add_public_tablegen_target(NVVMDerivativesIncGen) + add_mlir_library(MLIREnzymeImplementations ArithAutoDiffOpInterfaceImpl.cpp LLVMAutoDiffOpInterfaceImpl.cpp + NVVMAutoDiffOpInterfaceImpl.cpp MemRefAutoDiffOpInterfaceImpl.cpp LinalgAutoDiffOpInterfaceImpl.cpp BuiltinAutoDiffTypeInterfaceImpl.cpp @@ -14,6 +23,8 @@ add_mlir_library(MLIREnzymeImplementations DEPENDS MLIRAutoDiffOpInterfaceIncGen ArithDerivativesIncGen + LLVMDerivativesIncGen + NVVMDerivativesIncGen LINK_LIBS PUBLIC MLIRArithDialect diff --git a/enzyme/Enzyme/MLIR/Implementations/Common.td b/enzyme/Enzyme/MLIR/Implementations/Common.td new file mode 100644 index 000000000000..3909405320d1 --- /dev/null +++ b/enzyme/Enzyme/MLIR/Implementations/Common.td @@ -0,0 +1,30 @@ +class InactiveOp { + string dialect = dialect_; + string opName = opName_; +} + +class MLIRDerivative resultOps> { + string dialect = dialect_; + string opName = opName_; + dag PatternToMatch = patternToMatch; + list ArgDerivatives = resultOps; +} + +class Operation { + bit usesPrimal = usesPrimal_; + bit usesShadow = usesShadow_; + bit usesCustom = usesCustom_; +} + +class DiffeRetIndex indices_> { + list indices = indices_; +} +def DiffeRet : DiffeRetIndex<[-1]>; + +class Inst : Operation { + string name = mnemonic; + string dialect = dialect_; +} + +def Op { +} diff --git a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h index 669b028998c6..56af04b30133 100644 --- a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h +++ b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h @@ -19,6 +19,7 @@ namespace enzyme { void registerArithDialectAutoDiffInterface(DialectRegistry ®istry); void registerBuiltinDialectAutoDiffInterface(DialectRegistry ®istry); void registerLLVMDialectAutoDiffInterface(DialectRegistry ®istry); +void registerNVVMDialectAutoDiffInterface(DialectRegistry ®istry); void registerMemRefDialectAutoDiffInterface(DialectRegistry ®istry); void registerSCFDialectAutoDiffInterface(DialectRegistry ®istry); void registerLinalgDialectAutoDiffInterface(DialectRegistry ®istry); diff --git a/enzyme/Enzyme/MLIR/Implementations/LLVMAutoDiffOpInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/LLVMAutoDiffOpInterfaceImpl.cpp index 3c60fd421c7a..079dd1cb64e9 100644 --- a/enzyme/Enzyme/MLIR/Implementations/LLVMAutoDiffOpInterfaceImpl.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/LLVMAutoDiffOpInterfaceImpl.cpp @@ -23,6 +23,23 @@ using namespace mlir; using namespace mlir::enzyme; namespace { +#include "Implementations/LLVMDerivatives.inc" +} // namespace + +namespace { +struct InlineAsmActivityInterface + : public ActivityOpInterface::ExternalModel { + bool isInactive(Operation *op) const { + auto asmOp = cast(op); + auto str = asmOp.getAsmString(); + return str.contains("cpuid") || str.contains("exit"); + } + bool isArgInactive(Operation *op, mlir::Value) const { + return isInactive(op); + } +}; + struct LoadOpInterface : public AutoDiffOpInterface::ExternalModel { LogicalResult createForwardModeTangent(Operation *op, OpBuilder &builder, @@ -102,5 +119,6 @@ void mlir::enzyme::registerLLVMDialectAutoDiffInterface( LLVM::StoreOp::attachInterface(*context); LLVM::AllocaOp::attachInterface(*context); LLVM::LLVMPointerType::attachInterface(*context); + registerInterfaces(context); }); } diff --git a/enzyme/Enzyme/MLIR/Implementations/LLVMDerivatives.td b/enzyme/Enzyme/MLIR/Implementations/LLVMDerivatives.td new file mode 100644 index 000000000000..9e5f28e41665 --- /dev/null +++ b/enzyme/Enzyme/MLIR/Implementations/LLVMDerivatives.td @@ -0,0 +1,17 @@ +include "Common.td" + +def : InactiveOp<"LLVM", "SIToFPOp">; +def : InactiveOp<"LLVM", "UIToFPOp">; +def : InactiveOp<"LLVM", "FPToSIOp">; +def : InactiveOp<"LLVM", "FPToUIOp">; +def : InactiveOp<"LLVM", "AssumeOp">; +def : InactiveOp<"LLVM", "StackSaveOp">; +def : InactiveOp<"LLVM", "StackRestoreOp">; +def : InactiveOp<"LLVM", "LifetimeStartOp">; +def : InactiveOp<"LLVM", "LifetimeEndOp">; +def : InactiveOp<"LLVM", "Prefetch">; +def : InactiveOp<"LLVM", "MemsetOp">; + +def : InactiveOp<"LLVM", "UndefOp">; +def : InactiveOp<"LLVM", "ConstantOp">; +def : InactiveOp<"LLVM", "UnreachableOp">; diff --git a/enzyme/Enzyme/MLIR/Implementations/NVVMAutoDiffOpInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/NVVMAutoDiffOpInterfaceImpl.cpp new file mode 100644 index 000000000000..4d8116ce011b --- /dev/null +++ b/enzyme/Enzyme/MLIR/Implementations/NVVMAutoDiffOpInterfaceImpl.cpp @@ -0,0 +1,34 @@ +//===- LLVMAutoDiffOpInterfaceImpl.cpp - Interface external model --------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains the external model implementation of the automatic +// differentiation op interfaces for the upstream LLVM dialect. +// +//===----------------------------------------------------------------------===// + +#include "Implementations/CoreDialectsAutoDiffImplementations.h" +#include "Interfaces/AutoDiffOpInterface.h" +#include "Interfaces/AutoDiffTypeInterface.h" +#include "Interfaces/GradientUtils.h" +#include "mlir/Dialect/LLVMIR/NVVMDialect.h" +#include "mlir/IR/DialectRegistry.h" +#include "mlir/Support/LogicalResult.h" + +using namespace mlir; +using namespace mlir::enzyme; + +namespace { +#include "Implementations/NVVMDerivatives.inc" +} // namespace + +void mlir::enzyme::registerNVVMDialectAutoDiffInterface( + DialectRegistry ®istry) { + registry.addExtension(+[](MLIRContext *context, NVVM::NVVMDialect *) { + registerInterfaces(context); + }); +} diff --git a/enzyme/Enzyme/MLIR/Implementations/NVVMDerivatives.td b/enzyme/Enzyme/MLIR/Implementations/NVVMDerivatives.td new file mode 100644 index 000000000000..f34dfb564cbc --- /dev/null +++ b/enzyme/Enzyme/MLIR/Implementations/NVVMDerivatives.td @@ -0,0 +1,4 @@ +include "Common.td" + +// TODO in reverse replicate in reverse pass +def : InactiveOp<"NVVM", "Barrier0Op">; diff --git a/enzyme/Enzyme/MLIR/Implementations/SCFAutoDiffOpInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/SCFAutoDiffOpInterfaceImpl.cpp index 52f48ce7e2d1..71d3db0c2905 100644 --- a/enzyme/Enzyme/MLIR/Implementations/SCFAutoDiffOpInterfaceImpl.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/SCFAutoDiffOpInterfaceImpl.cpp @@ -32,6 +32,8 @@ struct ForOpInterface MGradientUtils *gutils) const { auto forOp = cast(op); auto nFor = cast(gutils->getNewFromOriginal(op)); + // Get a list of all the return types, which is the original return types + // alongside any shadow return types SmallVector nTypes; for (auto r : forOp->getResults()) { // TODO only if used @@ -43,6 +45,8 @@ struct ForOpInterface nTypes.push_back(adTypeIface.getShadowType()); } } + + // Get a list of all args, which is original args, and any shadows SmallVector nArgs; for (const auto &[initVal, iterArg] : llvm::zip(forOp.getInitArgs(), forOp.getRegionIterArgs())) { @@ -51,12 +55,16 @@ struct ForOpInterface if (!gutils->isConstantValue(iterArg)) nArgs.push_back(gutils->invertPointerM(initVal, builder)); } + + // Create the new modified for loop auto repFor = builder.create( forOp.getLoc(), gutils->getNewFromOriginal(forOp.getLowerBound()), gutils->getNewFromOriginal(forOp.getUpperBound()), gutils->getNewFromOriginal(forOp.getStep()), nArgs); repFor.getRegion().takeBody(nFor.getRegion()); + // Inject the mapping for the new results into GradientUtil's shadow + // table SmallVector reps; size_t idx = 0; for (Value r : forOp.getResults()) { @@ -72,13 +80,19 @@ struct ForOpInterface idx++; } } + + // Replace all uses of original results nFor.replaceAllUsesWith(reps); gutils->erase(nFor); + + // differentiate body for (Operation &o : llvm::make_early_inc_range(forOp.getBody()->without_terminator())) { if (failed(gutils->visitChild(&o))) return failure(); } + + // Fix terminator (yield) operations Operation *oldYield = repFor.getBody()->getTerminator(); builder.setInsertionPointToEnd(repFor.getBody()); SmallVector nYields; @@ -93,6 +107,8 @@ struct ForOpInterface Operation *newYield = builder.clone(*oldYield); newYield->setOperands(nYields); gutils->erase(oldYield); + + // Done return success(); } }; diff --git a/enzyme/Enzyme/MLIR/Interfaces/AutoDiffOpInterface.td b/enzyme/Enzyme/MLIR/Interfaces/AutoDiffOpInterface.td index 61099c316cd2..f45b641f9cb8 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/AutoDiffOpInterface.td +++ b/enzyme/Enzyme/MLIR/Interfaces/AutoDiffOpInterface.td @@ -76,4 +76,25 @@ def ReverseAutoDiffOpInterface : OpInterface<"ReverseAutoDiffOpInterface"> { ]; } +def ActivityOpInterface + : OpInterface<"ActivityOpInterface"> { + let cppNamespace = "::mlir::enzyme"; + + let methods = [ + InterfaceMethod< + /*desc=*/[{ + }], + /*retTy=*/"bool", + /*methodName=*/"isInactive" + >, + InterfaceMethod< + /*desc=*/[{ + }], + /*retTy=*/"bool", + /*methodName=*/"isArgInactive", + /*args=*/(ins "::mlir::Value":$val) + > + ]; +} + #endif // ENZYME_MLIR_INTERFACES_AUTODIFFOPINTERFACES diff --git a/enzyme/Enzyme/MLIR/enzymemlir-opt.cpp b/enzyme/Enzyme/MLIR/enzymemlir-opt.cpp index 589ffa610a28..b6bb33c51df9 100644 --- a/enzyme/Enzyme/MLIR/enzymemlir-opt.cpp +++ b/enzyme/Enzyme/MLIR/enzymemlir-opt.cpp @@ -100,6 +100,7 @@ int main(int argc, char **argv) { enzyme::registerArithDialectAutoDiffInterface(registry); enzyme::registerBuiltinDialectAutoDiffInterface(registry); enzyme::registerLLVMDialectAutoDiffInterface(registry); + enzyme::registerNVVMDialectAutoDiffInterface(registry); enzyme::registerMemRefDialectAutoDiffInterface(registry); enzyme::registerSCFDialectAutoDiffInterface(registry); enzyme::registerLinalgDialectAutoDiffInterface(registry); diff --git a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp index 11384a2284e8..1f4a3dc1b892 100644 --- a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp +++ b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp @@ -1875,6 +1875,19 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os, } if (intrinsic == MLIRDerivatives) { + const auto &actpatterns = + recordKeeper.getAllDerivedDefinitions("InactiveOp"); + for (auto &pattern : actpatterns) { + auto opName = pattern->getValueAsString("opName"); + auto dialect = pattern->getValueAsString("dialect"); + os << "struct " << opName << "Activity : \n"; + os << " public ActivityOpInterface::ExternalModel<" + << opName << "Activity, " << dialect << "::" << opName << "> {\n"; + os << " bool isInactive(mlir::Operation*) const { return true; }\n"; + os << " bool isArgInactive(mlir::Operation*, mlir::Value) const { " + "return true; }\n"; + os << "};\n"; + } os << "void registerInterfaces(MLIRContext* context) {\n"; for (Record *pattern : patterns) { auto opName = pattern->getValueAsString("opName"); @@ -1884,6 +1897,12 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os, os << " " << dialect << "::" << opName << "::attachInterface<" << opName << "RevDerivative>(*context);\n"; } + for (Record *pattern : actpatterns) { + auto opName = pattern->getValueAsString("opName"); + auto dialect = pattern->getValueAsString("dialect"); + os << " " << dialect << "::" << opName << "::attachInterface<" << opName + << "Activity>(*context);\n"; + } os << "}\n"; } } From 53a0bd8aef1dd49808420f19ef16847ad2d0ef6f Mon Sep 17 00:00:00 2001 From: William Moses Date: Wed, 31 Jan 2024 13:10:00 -0500 Subject: [PATCH 002/106] [WIP] Simplify MLIR (#1646) --- enzyme/Enzyme/MLIR/Interfaces/GradientUtils.h | 4 +- .../MLIR/Interfaces/GradientUtilsReverse.cpp | 64 ++----------------- .../MLIR/Interfaces/GradientUtilsReverse.h | 29 +-------- 3 files changed, 10 insertions(+), 87 deletions(-) diff --git a/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.h b/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.h index 9b9509b3ec22..094228232346 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.h +++ b/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.h @@ -74,14 +74,14 @@ class MDiffeGradientUtils : public MGradientUtils { FunctionOpInterface oldFunc_, MTypeAnalysis &TA, MTypeResults TR, IRMapping &invertedPointers_, const SmallPtrSetImpl &constantvalues_, - const SmallPtrSetImpl &returnvals_, + const SmallPtrSetImpl &activevals_, DIFFE_TYPE ActiveReturn, ArrayRef constant_values, IRMapping &origToNew_, std::map &origToNewOps_, DerivativeMode mode, unsigned width, bool omp) : MGradientUtils(Logic, newFunc_, oldFunc_, TA, TR, invertedPointers_, - constantvalues_, returnvals_, ActiveReturn, + constantvalues_, activevals_, ActiveReturn, constant_values, origToNew_, origToNewOps_, mode, width, omp) {} diff --git a/enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.cpp b/enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.cpp index 02a3d3d956d6..65d0349f1476 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.cpp +++ b/enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.cpp @@ -36,10 +36,11 @@ mlir::enzyme::MGradientUtilsReverse::MGradientUtilsReverse( ArrayRef ArgDiffeTypes_, IRMapping &originalToNewFn_, std::map &originalToNewFnOps_, DerivativeMode mode_, unsigned width, SymbolTableCollection &symbolTable_) - : newFunc(newFunc_), oldFunc(oldFunc_), Logic(Logic), mode(mode_), - originalToNewFn(originalToNewFn_), - originalToNewFnOps(originalToNewFnOps_), TA(TA_), width(width), - ArgDiffeTypes(ArgDiffeTypes_), symbolTable(symbolTable_) { + : MDiffeGradientUtils(Logic, newFunc_, oldFunc_, TA_, /*MTypeResults*/ {}, + invertedPointers_, constantvalues_, activevals_, + ReturnActivity, ArgDiffeTypes_, originalToNewFn_, + originalToNewFnOps_, mode_, width, /*omp*/ false), + symbolTable(symbolTable_) { initInitializationBlock(invertedPointers_, ArgDiffeTypes_); } @@ -136,43 +137,6 @@ Value mlir::enzyme::MGradientUtilsReverse::insertInitShadowedGradient( return gradient; } -Value mlir::enzyme::MGradientUtilsReverse::getNewFromOriginal( - const mlir::Value originst) const { - if (!originalToNewFn.contains(originst)) { - llvm::errs() << oldFunc << "\n"; - llvm::errs() << newFunc << "\n"; - llvm::errs() << originst << "\n"; - llvm_unreachable("Could not get new val from original"); - } - return originalToNewFn.lookupOrNull(originst); -} - -Block *mlir::enzyme::MGradientUtilsReverse::getNewFromOriginal( - mlir::Block *originst) const { - if (!originalToNewFn.contains(originst)) { - llvm::errs() << oldFunc << "\n"; - llvm::errs() << newFunc << "\n"; - llvm::errs() << originst << "\n"; - llvm_unreachable("Could not get new blk from original"); - } - return originalToNewFn.lookupOrNull(originst); -} - -Operation *mlir::enzyme::MGradientUtilsReverse::getNewFromOriginal( - Operation *originst) const { - auto found = originalToNewFnOps.find(originst); - if (found == originalToNewFnOps.end()) { - llvm::errs() << oldFunc << "\n"; - llvm::errs() << newFunc << "\n"; - for (auto &pair : originalToNewFnOps) { - llvm::errs() << " map[" << pair.first << "] = " << pair.second << "\n"; - } - llvm::errs() << originst << " - " << *originst << "\n"; - llvm_unreachable("Could not get new op from original"); - } - return found->second; -} - Operation * mlir::enzyme::MGradientUtilsReverse::cloneWithNewOperands(OpBuilder &B, Operation *op) { @@ -182,24 +146,6 @@ mlir::enzyme::MGradientUtilsReverse::cloneWithNewOperands(OpBuilder &B, return B.clone(*op, map); } -bool mlir::enzyme::MGradientUtilsReverse::isConstantInstruction( - Operation *op) const { - return false; -} - -bool mlir::enzyme::MGradientUtilsReverse::isConstantValue(Value v) const { - if (isa(v.getType())) - return true; - if (isa(v.getType())) - return true; - - if (matchPattern(v, m_Constant())) - return true; - - // TODO - return false; -} - bool mlir::enzyme::MGradientUtilsReverse::requiresShadow(Type t) { if (auto iface = dyn_cast(t)) { return iface.requiresShadow(); diff --git a/enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.h b/enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.h index 474badb034c9..3d5d4d147269 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.h +++ b/enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.h @@ -15,10 +15,12 @@ #include +#include "GradientUtils.h" + namespace mlir { namespace enzyme { -class MGradientUtilsReverse { +class MGradientUtilsReverse : public MDiffeGradientUtils { public: MGradientUtilsReverse(MEnzymeLogic &Logic, FunctionOpInterface newFunc_, FunctionOpInterface oldFunc_, MTypeAnalysis &TA_, @@ -32,13 +34,6 @@ class MGradientUtilsReverse { DerivativeMode mode_, unsigned width, SymbolTableCollection &symbolTable_); - // From CacheUtility - FunctionOpInterface newFunc; - FunctionOpInterface oldFunc; - - MEnzymeLogic &Logic; - bool AtomicAdd; - DerivativeMode mode; IRMapping invertedPointersGlobal; IRMapping invertedPointersShadow; IRMapping shadowValues; @@ -47,26 +42,8 @@ class MGradientUtilsReverse { IRMapping mapReverseModeBlocks; DenseMap>> mapBlockArguments; - IRMapping originalToNewFn; - std::map originalToNewFnOps; - - MTypeAnalysis &TA; - - unsigned width; - ArrayRef ArgDiffeTypes; - SymbolTableCollection &symbolTable; - mlir::Value getNewFromOriginal(const mlir::Value originst) const; - mlir::Block *getNewFromOriginal(mlir::Block *originst) const; - Operation *getNewFromOriginal(Operation *originst) const; - - void erase(Operation *op) { op->erase(); } - void eraseIfUnused(Operation *op, bool erase = true, bool check = true) { - // TODO - } - bool isConstantValue(mlir::Value v) const; - bool isConstantInstruction(mlir::Operation *v) const; bool hasInvertPointer(mlir::Value v); mlir::Value invertPointerM(mlir::Value v, OpBuilder &builder); mlir::Value diffe(mlir::Value v, OpBuilder &builder); From 724756e5d770f3cb7cee6dc17ea292443a55eb05 Mon Sep 17 00:00:00 2001 From: "Oleksandr \"Alex\" Zinenko" Date: Wed, 31 Jan 2024 19:55:18 +0100 Subject: [PATCH 003/106] Reimplement MLIR tangent for scf.for generically (#1649) * Reimplement MLIR tangent for scf.for generically Introduce an interface for ops with region control flow and use that interface inside the tangent implementation for scf.for. This interface can also work for other kinds of control flow ops as long as they don't transfer control flow between regions (e.g. like scf.while does). Further cleanup is necessary to clearly delegate implementation of AutoDiffOpInterface to this new interface for any compatible op via traits. * Generalize * fixup * Decrease the amount of tablegen'd code * Add some documentation and use better types * add forward diff support for affine.for and affine.if * bazel * add forward diff rule for scf.while Fix a minor bug discovered in the process: the interface was designed for constant folding and expects a list of null attributes of the appropriate size... * add executeop * add cf * fixup --------- Co-authored-by: William S. Moses --- enzyme/BUILD | 32 +++ .../AffineAutoDiffOpInterfaceImpl.cpp | 64 +++++ .../MLIR/Implementations/AffineDerivatives.td | 24 ++ .../CFAutoDiffOpInterfaceImpl.cpp | 41 +++ .../MLIR/Implementations/CFDerivatives.td | 4 + .../MLIR/Implementations/CMakeLists.txt | 17 ++ enzyme/Enzyme/MLIR/Implementations/Common.td | 16 ++ .../CoreDialectsAutoDiffImplementations.cpp | 246 ++++++++++++++++++ .../CoreDialectsAutoDiffImplementations.h | 87 +++++++ .../SCFAutoDiffOpInterfaceImpl.cpp | 94 +------ .../MLIR/Implementations/SCFDerivatives.td | 50 ++++ .../MLIR/Interfaces/AutoDiffOpInterface.td | 36 +++ enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.cpp | 37 +-- .../Enzyme/MLIR/Interfaces/GradientUtils.cpp | 4 +- enzyme/Enzyme/MLIR/Interfaces/GradientUtils.h | 8 + enzyme/Enzyme/MLIR/enzymemlir-opt.cpp | 2 + enzyme/test/MLIR/ForwardMode/affine.mlir | 104 ++++++++ .../branch-self-recursive.mlir | 0 .../test/MLIR/{ => ForwardMode}/branch.mlir | 0 enzyme/test/MLIR/ForwardMode/executeop.mlir | 60 +++++ enzyme/test/MLIR/{ => ForwardMode}/for.mlir | 0 enzyme/test/MLIR/ForwardMode/for2.mlir | 30 +++ enzyme/test/MLIR/ForwardMode/if1.mlir | 41 +++ .../test/MLIR/{ => ForwardMode}/inactive.mlir | 0 .../test/MLIR/{ => ForwardMode}/invalid.mlir | 0 enzyme/test/MLIR/{ => ForwardMode}/llvm.mlir | 0 .../test/MLIR/{ => ForwardMode}/memref.mlir | 0 enzyme/test/MLIR/{ => ForwardMode}/test.mlir | 0 enzyme/test/MLIR/ForwardMode/while.mlir | 44 ++++ enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp | 39 +++ 30 files changed, 957 insertions(+), 123 deletions(-) create mode 100644 enzyme/Enzyme/MLIR/Implementations/AffineAutoDiffOpInterfaceImpl.cpp create mode 100644 enzyme/Enzyme/MLIR/Implementations/AffineDerivatives.td create mode 100644 enzyme/Enzyme/MLIR/Implementations/CFAutoDiffOpInterfaceImpl.cpp create mode 100644 enzyme/Enzyme/MLIR/Implementations/CFDerivatives.td create mode 100644 enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp create mode 100644 enzyme/Enzyme/MLIR/Implementations/SCFDerivatives.td create mode 100644 enzyme/test/MLIR/ForwardMode/affine.mlir rename enzyme/test/MLIR/{ => ForwardMode}/branch-self-recursive.mlir (100%) rename enzyme/test/MLIR/{ => ForwardMode}/branch.mlir (100%) create mode 100644 enzyme/test/MLIR/ForwardMode/executeop.mlir rename enzyme/test/MLIR/{ => ForwardMode}/for.mlir (100%) create mode 100644 enzyme/test/MLIR/ForwardMode/for2.mlir create mode 100644 enzyme/test/MLIR/ForwardMode/if1.mlir rename enzyme/test/MLIR/{ => ForwardMode}/inactive.mlir (100%) rename enzyme/test/MLIR/{ => ForwardMode}/invalid.mlir (100%) rename enzyme/test/MLIR/{ => ForwardMode}/llvm.mlir (100%) rename enzyme/test/MLIR/{ => ForwardMode}/memref.mlir (100%) rename enzyme/test/MLIR/{ => ForwardMode}/test.mlir (100%) create mode 100644 enzyme/test/MLIR/ForwardMode/while.mlir diff --git a/enzyme/BUILD b/enzyme/BUILD index 8954d5433e0a..f091b06bd8f5 100644 --- a/enzyme/BUILD +++ b/enzyme/BUILD @@ -384,6 +384,22 @@ gentbl_cc_library( deps = [":EnzymeDialectTdFiles"], ) +gentbl( + name = "affine-derivatives", + tbl_outs = [( + "-gen-mlir-derivatives", + "Enzyme/MLIR/Implementations/AffineDerivatives.inc", + )], + tblgen = ":enzyme-tblgen", + td_file = "Enzyme/MLIR/Implementations/AffineDerivatives.td", + td_srcs = [ + "Enzyme/MLIR/Implementations/AffineDerivatives.td", + "Enzyme/MLIR/Implementations/Common.td"], + deps = [ + ":enzyme-tblgen", + ], +) + gentbl( name = "arith-derivatives", tbl_outs = [( @@ -426,6 +442,20 @@ gentbl( ], ) +gentbl( + name = "scf-derivatives", + tbl_outs = [( + "-gen-mlir-derivatives", + "Enzyme/MLIR/Implementations/SCFDerivatives.inc", + )], + tblgen = ":enzyme-tblgen", + td_file = "Enzyme/MLIR/Implementations/SCFDerivatives.td", + td_srcs = ["Enzyme/MLIR/Implementations/SCFDerivatives.td", "Enzyme/MLIR/Implementations/Common.td"], + deps = [ + ":enzyme-tblgen", + ], +) + cc_library( name = "EnzymeMLIR", srcs = glob([ @@ -447,9 +477,11 @@ cc_library( includes = ["Enzyme/MLIR", "Enzyme"], visibility = ["//visibility:public"], deps = [ + ":affine-derivatives", ":arith-derivatives", ":llvm-derivatives", ":nvvm-derivatives", + ":scf-derivatives", ":EnzymeOpsIncGen", ":EnzymePassesIncGen", ":EnzymeTypesIncGen", diff --git a/enzyme/Enzyme/MLIR/Implementations/AffineAutoDiffOpInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/AffineAutoDiffOpInterfaceImpl.cpp new file mode 100644 index 000000000000..c27f0d60d129 --- /dev/null +++ b/enzyme/Enzyme/MLIR/Implementations/AffineAutoDiffOpInterfaceImpl.cpp @@ -0,0 +1,64 @@ +//===- AffineAutoDiffOpInterfaceImpl.cpp - Interface external model -------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains the external model implementation of the automatic +// differentiation op interfaces for the upstream MLIR Affine dialect. +// +//===----------------------------------------------------------------------===// + +#include "Implementations/CoreDialectsAutoDiffImplementations.h" +#include "Interfaces/AutoDiffOpInterface.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/IR/IntegerSet.h" + +using namespace mlir; +using namespace mlir::enzyme; + +namespace { +affine::AffineForOp +createAffineForWithShadows(Operation *op, OpBuilder &builder, + MGradientUtils *gutils, Operation *original, + ValueRange remappedOperands, TypeRange rettys) { + affine::AffineForOpAdaptor adaptor(remappedOperands, + cast(original)); + auto repFor = builder.create( + original->getLoc(), adaptor.getLowerBoundOperands(), + adaptor.getLowerBoundMap(), adaptor.getUpperBoundOperands(), + adaptor.getUpperBoundMap(), adaptor.getStep().getZExtValue(), + // This dance is necessary because the adaptor accessors are based on the + // internal attribute containing the number of operands associated with + // each named operand group. This attribute is carried over from the + // original operation and does not account for the shadow-related iter + // args. Instead, assume lower/upper bound operands must not have shadows + // since they are integer-typed and take the result of operands as iter + // args. + remappedOperands.drop_front(adaptor.getLowerBoundOperands().size() + + adaptor.getUpperBoundOperands().size())); + return repFor; +} + +affine::AffineIfOp createAffineIfWithShadows(Operation *op, OpBuilder &builder, + MGradientUtils *gutils, + affine::AffineIfOp original, + ValueRange remappedOperands, + TypeRange rettys) { + affine::AffineIfOpAdaptor adaptor(remappedOperands, original); + return builder.create( + original->getLoc(), rettys, original.getIntegerSet(), + adaptor.getOperands(), !original.getElseRegion().empty()); +} + +#include "Implementations/AffineDerivatives.inc" +} // namespace + +void mlir::enzyme::registerAffineDialectAutoDiffInterface( + DialectRegistry ®istry) { + registry.addExtension(+[](MLIRContext *context, affine::AffineDialect *) { + registerInterfaces(context); + }); +} diff --git a/enzyme/Enzyme/MLIR/Implementations/AffineDerivatives.td b/enzyme/Enzyme/MLIR/Implementations/AffineDerivatives.td new file mode 100644 index 000000000000..d6e866f3089e --- /dev/null +++ b/enzyme/Enzyme/MLIR/Implementations/AffineDerivatives.td @@ -0,0 +1,24 @@ +include "Common.td" + +def : ControlFlowOp<"affine", "AffineForOp", [{ + Operation *createWithShadows(Operation *op, OpBuilder &builder, + MGradientUtils *gutils, Operation *original, + ValueRange remappedOperands, + TypeRange rettys) const { + return createAffineForWithShadows(op, builder, gutils, original, + remappedOperands, rettys); + } +}]>; + +def : ControlFlowOp<"affine", "AffineIfOp", [{ + Operation *createWithShadows(Operation *op, OpBuilder &builder, + MGradientUtils *gutils, Operation *original, + ValueRange remappedOperands, + TypeRange rettys) const { + return createAffineIfWithShadows(op, builder, gutils, + cast(original), + remappedOperands, rettys); + } +}]>; + +def : RegionTerminatorOp<"affine", "AffineYieldOp">; diff --git a/enzyme/Enzyme/MLIR/Implementations/CFAutoDiffOpInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/CFAutoDiffOpInterfaceImpl.cpp new file mode 100644 index 000000000000..8f40db9d834d --- /dev/null +++ b/enzyme/Enzyme/MLIR/Implementations/CFAutoDiffOpInterfaceImpl.cpp @@ -0,0 +1,41 @@ +//===- SCFAutoDiffOpInterfaceImpl.cpp - Interface external model ----------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains the external model implementation of the automatic +// differentiation op interfaces for the upstream MLIR SCF dialect. +// +//===----------------------------------------------------------------------===// + +#include "Implementations/CoreDialectsAutoDiffImplementations.h" +#include "Interfaces/AutoDiffOpInterface.h" +#include "Interfaces/AutoDiffTypeInterface.h" +#include "Interfaces/EnzymeLogic.h" +#include "Interfaces/GradientUtils.h" +#include "Interfaces/GradientUtilsReverse.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" +#include "mlir/IR/DialectRegistry.h" +#include "mlir/IR/Types.h" +#include "mlir/Interfaces/ControlFlowInterfaces.h" +#include "mlir/Support/LogicalResult.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/STLExtras.h" +#include + +using namespace mlir; +using namespace mlir::enzyme; + +namespace { +#include "Implementations/CFDerivatives.inc" +} // namespace + +void mlir::enzyme::registerCFDialectAutoDiffInterface( + DialectRegistry ®istry) { + registry.addExtension(+[](MLIRContext *context, cf::ControlFlowDialect *) { + registerInterfaces(context); + }); +} diff --git a/enzyme/Enzyme/MLIR/Implementations/CFDerivatives.td b/enzyme/Enzyme/MLIR/Implementations/CFDerivatives.td new file mode 100644 index 000000000000..8b4f41696a9d --- /dev/null +++ b/enzyme/Enzyme/MLIR/Implementations/CFDerivatives.td @@ -0,0 +1,4 @@ +include "Common.td" + +def : BranchOp<"cf", "BranchOp">; +def : BranchOp<"cf", "SwitchOp">; diff --git a/enzyme/Enzyme/MLIR/Implementations/CMakeLists.txt b/enzyme/Enzyme/MLIR/Implementations/CMakeLists.txt index 521ba76c22bc..8e62f5849be1 100644 --- a/enzyme/Enzyme/MLIR/Implementations/CMakeLists.txt +++ b/enzyme/Enzyme/MLIR/Implementations/CMakeLists.txt @@ -1,3 +1,6 @@ +set(LLVM_TARGET_DEFINITIONS AffineDerivatives.td) +enzyme_tablegen(AffineDerivatives.inc -gen-mlir-derivatives) +add_public_tablegen_target(AffineDerivativesIncGen) set(LLVM_TARGET_DEFINITIONS ArithDerivatives.td) enzyme_tablegen(ArithDerivatives.inc -gen-mlir-derivatives) @@ -11,20 +14,34 @@ set(LLVM_TARGET_DEFINITIONS NVVMDerivatives.td) enzyme_tablegen(NVVMDerivatives.inc -gen-mlir-derivatives) add_public_tablegen_target(NVVMDerivativesIncGen) +set(LLVM_TARGET_DEFINITIONS SCFDerivatives.td) +enzyme_tablegen(SCFDerivatives.inc -gen-mlir-derivatives) +add_public_tablegen_target(SCFDerivativesIncGen) + +set(LLVM_TARGET_DEFINITIONS CFDerivatives.td) +enzyme_tablegen(CFDerivatives.inc -gen-mlir-derivatives) +add_public_tablegen_target(CFDerivativesIncGen) + add_mlir_library(MLIREnzymeImplementations + AffineAutoDiffOpInterfaceImpl.cpp ArithAutoDiffOpInterfaceImpl.cpp + CoreDialectsAutoDiffImplementations.cpp LLVMAutoDiffOpInterfaceImpl.cpp NVVMAutoDiffOpInterfaceImpl.cpp MemRefAutoDiffOpInterfaceImpl.cpp LinalgAutoDiffOpInterfaceImpl.cpp BuiltinAutoDiffTypeInterfaceImpl.cpp SCFAutoDiffOpInterfaceImpl.cpp + CFAutoDiffOpInterfaceImpl.cpp DEPENDS MLIRAutoDiffOpInterfaceIncGen + AffineDerivativesIncGen ArithDerivativesIncGen LLVMDerivativesIncGen NVVMDerivativesIncGen + SCFDerivativesIncGen + CFDerivativesIncGen LINK_LIBS PUBLIC MLIRArithDialect diff --git a/enzyme/Enzyme/MLIR/Implementations/Common.td b/enzyme/Enzyme/MLIR/Implementations/Common.td index 3909405320d1..c09cac32f734 100644 --- a/enzyme/Enzyme/MLIR/Implementations/Common.td +++ b/enzyme/Enzyme/MLIR/Implementations/Common.td @@ -3,6 +3,22 @@ class InactiveOp { string opName = opName_; } +class ControlFlowOp { + string dialect = dialect_; + string opName = opName_; + string impl = impl_; +} + +class BranchOp { + string dialect = dialect_; + string opName = opName_; +} + +class RegionTerminatorOp { + string dialect = dialect_; + string opName = opName_; +} + class MLIRDerivative resultOps> { string dialect = dialect_; string opName = opName_; diff --git a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp new file mode 100644 index 000000000000..2a50d1c481c0 --- /dev/null +++ b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp @@ -0,0 +1,246 @@ +//===- CoreDialectsAutoDiffImplementations.cpp ----------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains common utilities for the external model implementation of +// the automatic differentiation op interfaces for upstream MLIR dialects. +// +//===----------------------------------------------------------------------===// + +#include "Implementations/CoreDialectsAutoDiffImplementations.h" +#include "Interfaces/AutoDiffOpInterface.h" +#include "Interfaces/AutoDiffTypeInterface.h" +#include "Interfaces/GradientUtils.h" + +using namespace mlir; +using namespace mlir::enzyme; + +void mlir::enzyme::detail::branchingForwardHandler(Operation *inst, + OpBuilder &builder, + MGradientUtils *gutils) { + auto newInst = gutils->getNewFromOriginal(inst); + + auto binst = cast(inst); + + // TODO generalize to cloneWithNewBlockArgs interface + SmallVector newVals; + + SmallVector segSizes; + // Keep non-differentiated, non-forwarded operands + size_t non_forwarded = 0; + for (size_t i = 0; i < newInst->getNumSuccessors(); i++) { + auto ops = binst.getSuccessorOperands(i).getForwardedOperands(); + if (ops.empty()) + continue; + non_forwarded = ops.getBeginOperandIndex(); + break; + } + + for (size_t i = 0; i < non_forwarded; i++) + newVals.push_back(gutils->getNewFromOriginal(binst->getOperand(i))); + + segSizes.push_back(newVals.size()); + for (size_t i = 0; i < newInst->getNumSuccessors(); i++) { + size_t cur = newVals.size(); + auto ops = binst.getSuccessorOperands(i).getForwardedOperands(); + for (auto &&[idx, op] : llvm::enumerate(ops)) { + auto arg = + *binst.getSuccessorBlockArgument(ops.getBeginOperandIndex() + idx); + newVals.push_back(gutils->getNewFromOriginal(op)); + if (!gutils->isConstantValue(arg)) { + if (!gutils->isConstantValue(op)) { + newVals.push_back(gutils->invertPointerM(op, builder)); + } else { + Type retTy = + arg.getType().cast().getShadowType(); + auto toret = retTy.cast().createNullValue( + builder, op.getLoc()); + newVals.push_back(toret); + } + } + } + cur = newVals.size() - cur; + segSizes.push_back(cur); + } + + SmallVector attrs(newInst->getAttrs()); + bool has_cases = false; + for (auto &attr : attrs) { + if (attr.getName() == "case_operand_segments") { + has_cases = true; + } + } + for (auto &attr : attrs) { + if (attr.getName() == "operandSegmentSizes") { + if (!has_cases) { + attr.setValue(builder.getDenseI32ArrayAttr(segSizes)); + } else { + SmallVector segSlices2(segSizes.begin(), segSizes.begin() + 2); + segSlices2.push_back(0); + for (size_t i = 2; i < segSizes.size(); i++) + segSlices2[2] += segSizes[i]; + attr.setValue(builder.getDenseI32ArrayAttr(segSlices2)); + } + } + if (attr.getName() == "case_operand_segments") { + SmallVector segSlices2(segSizes.begin() + 2, segSizes.end()); + attr.setValue(builder.getDenseI32ArrayAttr(segSlices2)); + } + } + + gutils->getNewFromOriginal(inst->getBlock()) + ->push_back( + newInst->create(newInst->getLoc(), newInst->getName(), TypeRange(), + newVals, attrs, OpaqueProperties(nullptr), + newInst->getSuccessors(), newInst->getNumRegions())); + gutils->erase(newInst); + return; +} + +void mlir::enzyme::detail::regionTerminatorForwardHandler( + Operation *origTerminator, OpBuilder &builder, MGradientUtils *gutils) { + auto termIface = cast(origTerminator); + auto parentOp = termIface->getParentOp(); + + SmallVector successors; + termIface.getSuccessorRegions( + SmallVector(termIface->getNumOperands(), Attribute()), + successors); + + llvm::SmallDenseSet operandsToShadow; + for (auto &successor : successors) { + OperandRange operandRange = termIface.getSuccessorOperands(successor); + ValueRange targetValues = successor.isParent() + ? parentOp->getResults() + : successor.getSuccessorInputs(); + assert(operandRange.size() == targetValues.size()); + for (auto &&[i, target] : llvm::enumerate(targetValues)) { + if (!gutils->isConstantValue(target)) + operandsToShadow.insert(operandRange.getBeginOperandIndex() + i); + } + } + SmallVector newOperands; + newOperands.reserve(termIface->getNumOperands() + operandsToShadow.size()); + for (OpOperand &operand : termIface->getOpOperands()) { + newOperands.push_back(gutils->getNewFromOriginal(operand.get())); + if (operandsToShadow.contains(operand.getOperandNumber())) + newOperands.push_back(gutils->invertPointerM(operand.get(), builder)); + } + + // Assuming shadows following the originals are fine. + // TODO: consider extending to have a ShadowableTerminatorOpInterface + Operation *replTerminator = gutils->getNewFromOriginal(origTerminator); + Operation *newTerminator = builder.clone(*replTerminator); + newTerminator->setOperands(newOperands); + gutils->erase(replTerminator); +} + +LogicalResult mlir::enzyme::detail::controlFlowForwardHandler( + Operation *op, OpBuilder &builder, MGradientUtils *gutils) { + // For all active results, add shadow types. + // For now, assuming all results are relevant. + Operation *newOp = gutils->getNewFromOriginal(op); + SmallVector newOpResultTypes; + newOpResultTypes.reserve(op->getNumResults() * 2); + for (Value result : op->getResults()) { + // TODO only if used (can we DCE the primal after having done the + // derivative). + newOpResultTypes.push_back(result.getType()); + if (gutils->isConstantValue(result)) + continue; + auto typeIface = dyn_cast(result.getType()); + if (!typeIface) + return failure(); + newOpResultTypes.push_back(typeIface.getShadowType()); + } + + // For all operands that are forwarded to the body, if they are active, also + // add the shadow as operand. + auto regionBranchOp = dyn_cast(op); + if (!regionBranchOp) + return failure(); + + SmallVector successors; + // TODO: we may need to record, for every successor, which of its inputs + // need a shadow to recreate the body correctly. + llvm::SmallDenseSet operandPositionsToShadow; + regionBranchOp.getEntrySuccessorRegions( + SmallVector(op->getNumOperands(), Attribute()), successors); + for (const RegionSuccessor &successor : successors) { + if (!successor.isParent() && successor.getSuccessor()->empty()) + continue; + + OperandRange operandRange = + regionBranchOp.getEntrySuccessorOperands(successor); + + // Need to know which of the arguments are being forwarded to from + // operands. + for (auto &&[i, regionValue, operand] : + llvm::enumerate(successor.getSuccessorInputs(), operandRange)) { + if (gutils->isConstantValue(regionValue)) + continue; + operandPositionsToShadow.insert(operandRange.getBeginOperandIndex() + i); + } + } + SmallVector newOperands; + newOperands.reserve(op->getNumOperands() + operandPositionsToShadow.size()); + for (OpOperand &operand : op->getOpOperands()) { + newOperands.push_back(gutils->getNewFromOriginal(operand.get())); + if (operandPositionsToShadow.contains(operand.getOperandNumber())) + newOperands.push_back(gutils->invertPointerM(operand.get(), builder)); + } + // We are assuming the op can forward additional operands, listed + // immediately after the original operands, to the same regions. + // ^^ + // Our interface guarantees this. + // We also assume that the region-holding op returns all of the values + // yielded by terminators, and only those values. + + auto iface = dyn_cast(op); + if (!iface) + return failure(); + Operation *replacement = iface.createWithShadows( + builder, gutils, op, newOperands, newOpResultTypes); + for (auto &&[region, replacementRegion] : + llvm::zip(newOp->getRegions(), replacement->getRegions())) { + replacementRegion.takeBody(region); + } + + // Inject the mapping for the new results into GradientUtil's shadow + // table. + SmallVector reps; + size_t idx = 0; + for (Value r : op->getResults()) { + // TODO only if used + reps.push_back(replacement->getResult(idx)); + idx++; + if (!gutils->isConstantValue(r)) { + auto inverted = gutils->invertedPointers.lookupOrNull(r); + assert(inverted); + gutils->invertedPointers.map(r, replacement->getResult(idx)); + inverted.replaceAllUsesWith(replacement->getResult(idx)); + gutils->erase(inverted.getDefiningOp()); + idx++; + } + } + + // Differentiate body. + for (auto &origRegion : op->getRegions()) { + for (auto &origBlock : origRegion) { + for (Operation &o : origBlock) { + if (failed(gutils->visitChild(&o))) + return failure(); + } + } + } + + // Replace all uses of original results + gutils->replaceOrigOpWith(op, reps); + gutils->erase(newOp); + + return success(); +} diff --git a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h index 56af04b30133..b98a6ef10c39 100644 --- a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h +++ b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h @@ -12,16 +12,103 @@ // //===----------------------------------------------------------------------===// +#include "Interfaces/AutoDiffOpInterface.h" +#include "mlir/Support/LogicalResult.h" + namespace mlir { class DialectRegistry; +class Operation; +class OpBuilder; namespace enzyme { +class MGradientUtils; + +namespace detail { +// Non-template implementation of +// AutoDiffUsingControlFlow::createForwardModeTangent. +LogicalResult controlFlowForwardHandler(Operation *op, OpBuilder &builder, + MGradientUtils *gutils); + +// Implements forward-mode differentiation of branching operations. +// Assumes that successive shadows are legal +void branchingForwardHandler(Operation *op, OpBuilder &builder, + MGradientUtils *gutils); + +// Implements forward-mode differentiation of region-terminator operations. +// Assumes that successive shadows are legal +void regionTerminatorForwardHandler(Operation *op, OpBuilder &builder, + MGradientUtils *gutils); + +// Implements the forward autodiff interface for operations whose derivatives +// are can be inferred by analyzing their control flow and differentiating the +// nested operations. +template +class AutoDiffUsingControlFlow + : public AutoDiffOpInterface::ExternalModel, + OpTy> { +public: + LogicalResult createForwardModeTangent(Operation *op, OpBuilder &builder, + MGradientUtils *gutils) const { + return controlFlowForwardHandler(op, builder, gutils); + } +}; + +// Implements the forward autodiff interface for operations whose derivatives +// are can be inferred by analyzing their branching properties. +template +class AutoDiffUsingBranch + : public AutoDiffOpInterface::ExternalModel, + OpTy> { +public: + LogicalResult createForwardModeTangent(Operation *op, OpBuilder &builder, + MGradientUtils *gutils) const { + branchingForwardHandler(op, builder, gutils); + return success(); + } +}; + +// Implements the forward autodiff interface for operations whose derivatives +// are can be inferred by analyzing their region terminator properties. +template +class AutoDiffUsingRegionTerminator + : public AutoDiffOpInterface::ExternalModel< + AutoDiffUsingRegionTerminator, OpTy> { +public: + LogicalResult createForwardModeTangent(Operation *op, OpBuilder &builder, + MGradientUtils *gutils) const { + regionTerminatorForwardHandler(op, builder, gutils); + return success(); + } +}; +} // namespace detail + +// Registers AutoDiffUsingControlFlow for the given op. +template +void registerAutoDiffUsingControlFlowInterface(MLIRContext &context) { + OpTy::template attachInterface>( + context); +} +// Registers AutoDiffUsingBranch for the given op. +template +void registerAutoDiffUsingBranchInterface(MLIRContext &context) { + OpTy::template attachInterface>(context); +} +// Registers AutoDiffUsingRegionTerminator for the given op. +template +void registerAutoDiffUsingRegionTerminatorInterface(MLIRContext &context) { + OpTy::template attachInterface>( + context); +} + +// Interface registration hooks for individual upstream dialects. +void registerAffineDialectAutoDiffInterface(DialectRegistry ®istry); void registerArithDialectAutoDiffInterface(DialectRegistry ®istry); void registerBuiltinDialectAutoDiffInterface(DialectRegistry ®istry); void registerLLVMDialectAutoDiffInterface(DialectRegistry ®istry); void registerNVVMDialectAutoDiffInterface(DialectRegistry ®istry); void registerMemRefDialectAutoDiffInterface(DialectRegistry ®istry); void registerSCFDialectAutoDiffInterface(DialectRegistry ®istry); +void registerCFDialectAutoDiffInterface(DialectRegistry ®istry); void registerLinalgDialectAutoDiffInterface(DialectRegistry ®istry); } // namespace enzyme } // namespace mlir diff --git a/enzyme/Enzyme/MLIR/Implementations/SCFAutoDiffOpInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/SCFAutoDiffOpInterfaceImpl.cpp index 71d3db0c2905..bab415ff247f 100644 --- a/enzyme/Enzyme/MLIR/Implementations/SCFAutoDiffOpInterfaceImpl.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/SCFAutoDiffOpInterfaceImpl.cpp @@ -19,99 +19,18 @@ #include "Interfaces/GradientUtilsReverse.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/IR/DialectRegistry.h" +#include "mlir/IR/Types.h" +#include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Support/LogicalResult.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/STLExtras.h" #include using namespace mlir; using namespace mlir::enzyme; namespace { -struct ForOpInterface - : public AutoDiffOpInterface::ExternalModel { - LogicalResult createForwardModeTangent(Operation *op, OpBuilder &builder, - MGradientUtils *gutils) const { - auto forOp = cast(op); - auto nFor = cast(gutils->getNewFromOriginal(op)); - // Get a list of all the return types, which is the original return types - // alongside any shadow return types - SmallVector nTypes; - for (auto r : forOp->getResults()) { - // TODO only if used - nTypes.push_back(r.getType()); - if (!gutils->isConstantValue(r)) { - auto adTypeIface = r.getType().dyn_cast(); - if (!adTypeIface) - return failure(); - nTypes.push_back(adTypeIface.getShadowType()); - } - } - - // Get a list of all args, which is original args, and any shadows - SmallVector nArgs; - for (const auto &[initVal, iterArg] : - llvm::zip(forOp.getInitArgs(), forOp.getRegionIterArgs())) { - // TODO only if used - nArgs.push_back(gutils->getNewFromOriginal(initVal)); - if (!gutils->isConstantValue(iterArg)) - nArgs.push_back(gutils->invertPointerM(initVal, builder)); - } - - // Create the new modified for loop - auto repFor = builder.create( - forOp.getLoc(), gutils->getNewFromOriginal(forOp.getLowerBound()), - gutils->getNewFromOriginal(forOp.getUpperBound()), - gutils->getNewFromOriginal(forOp.getStep()), nArgs); - repFor.getRegion().takeBody(nFor.getRegion()); - - // Inject the mapping for the new results into GradientUtil's shadow - // table - SmallVector reps; - size_t idx = 0; - for (Value r : forOp.getResults()) { - // TODO only if used - reps.push_back(repFor.getResult(idx)); - idx++; - if (!gutils->isConstantValue(r)) { - auto inverted = gutils->invertedPointers.lookupOrNull(r); - assert(inverted); - gutils->invertedPointers.map(r, repFor.getResult(idx)); - inverted.replaceAllUsesWith(repFor.getResult(idx)); - gutils->erase(inverted.getDefiningOp()); - idx++; - } - } - - // Replace all uses of original results - nFor.replaceAllUsesWith(reps); - gutils->erase(nFor); - - // differentiate body - for (Operation &o : - llvm::make_early_inc_range(forOp.getBody()->without_terminator())) { - if (failed(gutils->visitChild(&o))) - return failure(); - } - - // Fix terminator (yield) operations - Operation *oldYield = repFor.getBody()->getTerminator(); - builder.setInsertionPointToEnd(repFor.getBody()); - SmallVector nYields; - for (const auto &[result, yieldOperand] : - llvm::zip(forOp.getResults(), - forOp.getBody()->getTerminator()->getOperands())) { - // TODO only if used - nYields.push_back(gutils->getNewFromOriginal(yieldOperand)); - if (!gutils->isConstantValue(result)) - nYields.push_back(gutils->invertPointerM(yieldOperand, builder)); - } - Operation *newYield = builder.clone(*oldYield); - newYield->setOperands(nYields); - gutils->erase(oldYield); - - // Done - return success(); - } -}; +#include "Implementations/SCFDerivatives.inc" struct ForOpInterfaceReverse : public ReverseAutoDiffOpInterface::ExternalModel(*context); - + registerInterfaces(context); scf::ForOp::attachInterface(*context); }); } diff --git a/enzyme/Enzyme/MLIR/Implementations/SCFDerivatives.td b/enzyme/Enzyme/MLIR/Implementations/SCFDerivatives.td new file mode 100644 index 000000000000..4c9ee09abcd2 --- /dev/null +++ b/enzyme/Enzyme/MLIR/Implementations/SCFDerivatives.td @@ -0,0 +1,50 @@ +include "Common.td" + +def : ControlFlowOp<"scf", "ForOp", [{ + Operation *createWithShadows(Operation *op, OpBuilder &builder, + MGradientUtils *gutils, Operation *original, + ValueRange remappedOperands, + TypeRange rettys) const { + scf::ForOpAdaptor adaptor(remappedOperands); + auto repFor = builder.create( + op->getLoc(), adaptor.getLowerBound(), adaptor.getUpperBound(), + adaptor.getStep(), adaptor.getInitArgs()); + return repFor; + } +}]>; + +def : ControlFlowOp<"scf", "IfOp", [{ + Operation *createWithShadows(Operation *op, OpBuilder &builder, + MGradientUtils *gutils, Operation *original, + ValueRange remappedOperands, + TypeRange rettys) const { + scf::IfOpAdaptor adaptor(remappedOperands); + auto repIf = builder.create( + op->getLoc(), rettys, adaptor.getCondition()); + return repIf; + } +}]>; + +def : ControlFlowOp<"scf", "WhileOp", [{ + Operation *createWithShadows(Operation *op, OpBuilder &builder, + MGradientUtils *gutils, Operation *original, + ValueRange remappedOperands, + TypeRange rettys) const { + return builder.create(original->getLoc(), rettys, + remappedOperands, original->getAttrs()); + } +}]>; + +def : ControlFlowOp<"scf", "ExecuteRegionOp", [{ + Operation *createWithShadows(Operation *op, OpBuilder &builder, + MGradientUtils *gutils, Operation *original, + ValueRange remappedOperands, + TypeRange rettys) const { + auto repIf = builder.create( + op->getLoc(), rettys); + return repIf; + } +}]>; + +def : RegionTerminatorOp<"scf", "YieldOp">; +def : RegionTerminatorOp<"scf", "ConditionOp">; diff --git a/enzyme/Enzyme/MLIR/Interfaces/AutoDiffOpInterface.td b/enzyme/Enzyme/MLIR/Interfaces/AutoDiffOpInterface.td index f45b641f9cb8..5d764c8dc7bb 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/AutoDiffOpInterface.td +++ b/enzyme/Enzyme/MLIR/Interfaces/AutoDiffOpInterface.td @@ -40,6 +40,42 @@ def AutoDiffOpInterface : OpInterface<"AutoDiffOpInterface"> { ]; } +def ControlFlowAutoDiffOpInterface + : OpInterface<"ControlFlowAutoDiffOpInterface"> { + let description = [{ + A differentiable MLIR operation whose forward differentiation rules are + driven by how control flows through the operation. + + There are two general assumptions: + - the operation can communicate additional values along the control flow + edges, which is used to put shadow values immediately after the primal + values; + - all values returned by the operation are yielded by all region-exiting + terminators. + }]; + let cppNamespace = "::mlir::enzyme"; + let methods = [ + InterfaceMethod< + /*desc=*/[{ + Creates a copy of this operation additionally carrying required shadow + values along control flow edges using the given builder. The `original` + is the operation in the original primal code prior to differentiation, + and this method is supposed to be called on the operation in the cloned + function being constructed. Remapped operands contains a flat list of + operands usable in the cloned function and can be fed to the Adaptor + constructor. + }], + /*retTy=*/"::mlir::Operation *", + /*methodName=*/"createWithShadows", + /*args=*/(ins "::mlir::OpBuilder &":$builder, + "::mlir::enzyme::MGradientUtils *":$gutils, + "::mlir::Operation *":$original, + "::mlir::ValueRange":$remappedOperands, + "::mlir::TypeRange":$returnTypes) + > + ]; +} + def ReverseAutoDiffOpInterface : OpInterface<"ReverseAutoDiffOpInterface"> { let description = [{ A differentiable MLIR operation that is able to emit reverse mode adjoints for itself. diff --git a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.cpp b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.cpp index a92a5f3386ff..1da6904893f8 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.cpp +++ b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.cpp @@ -1,4 +1,5 @@ #include "Dialect/Ops.h" +#include "Implementations/CoreDialectsAutoDiffImplementations.h" #include "Interfaces/AutoDiffOpInterface.h" #include "Interfaces/AutoDiffTypeInterface.h" #include "Interfaces/GradientUtils.h" @@ -21,7 +22,7 @@ using namespace mlir; using namespace mlir::enzyme; -void createTerminator(MDiffeGradientUtils *gutils, mlir::Block *oBB, +void createTerminator(MGradientUtils *gutils, mlir::Block *oBB, DIFFE_TYPE retType, ReturnType retVal) { auto inst = oBB->getTerminator(); @@ -33,39 +34,7 @@ void createTerminator(MDiffeGradientUtils *gutils, mlir::Block *oBB, nBuilder.setInsertionPointToEnd(nBB); if (auto binst = dyn_cast(inst)) { - // TODO generalize to cloneWithNewBlockArgs interface - SmallVector newVals; - - SmallVector segSizes; - for (size_t i = 0, len = binst.getSuccessorOperands(0) - .getForwardedOperands() - .getBeginOperandIndex(); - i < len; i++) - newVals.push_back(gutils->getNewFromOriginal(binst->getOperand(i))); - segSizes.push_back(newVals.size()); - for (size_t i = 0; i < newInst->getNumSuccessors(); i++) { - size_t cur = newVals.size(); - for (auto op : binst.getSuccessorOperands(i).getForwardedOperands()) { - newVals.push_back(gutils->getNewFromOriginal(op)); - if (!gutils->isConstantValue(op)) { - newVals.push_back(gutils->invertPointerM(op, nBuilder)); - } - } - cur = newVals.size() - cur; - segSizes.push_back(cur); - } - - SmallVector attrs(newInst->getAttrs()); - for (auto &attr : attrs) { - if (attr.getName() == "operandSegmentSizes") - attr.setValue(nBuilder.getDenseI32ArrayAttr(segSizes)); - } - - nBB->push_back( - newInst->create(newInst->getLoc(), newInst->getName(), TypeRange(), - newVals, attrs, OpaqueProperties(nullptr), - newInst->getSuccessors(), newInst->getNumRegions())); - gutils->erase(newInst); + mlir::enzyme::detail::branchingForwardHandler(inst, nBuilder, gutils); return; } diff --git a/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp b/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp index 286456b2d039..20ed156d312d 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp +++ b/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp @@ -306,7 +306,9 @@ LogicalResult MGradientUtils::visitChild(Operation *op) { // In absence of a proper activity analysis, approximate it by treating any // side effect-free operation producing constants as inactive. // if (auto iface = dyn_cast(op)) { - if (llvm::all_of(op->getResults(), + if (!isa(op) && + !isa(op) && + llvm::all_of(op->getResults(), [this](Value v) { return isConstantValue(v); }) && /*iface.hasNoEffect()*/ activityAnalyzer->isConstantOperation(TR, op)) { return success(); diff --git a/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.h b/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.h index 094228232346..7ef99a1014fd 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.h +++ b/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.h @@ -54,6 +54,14 @@ class MGradientUtils { std::map &originalToNewFnOps_, DerivativeMode mode, unsigned width, bool omp); void erase(Operation *op) { op->erase(); } + void replaceOrigOpWith(Operation *op, ValueRange vals) { + for (auto &&[res, rep] : llvm::zip(op->getResults(), vals)) { + originalToNewFn.map(res, rep); + } + auto newOp = getNewFromOriginal(op); + newOp->replaceAllUsesWith(vals); + originalToNewFnOps.erase(op); + } void eraseIfUnused(Operation *op, bool erase = true, bool check = true) { // TODO } diff --git a/enzyme/Enzyme/MLIR/enzymemlir-opt.cpp b/enzyme/Enzyme/MLIR/enzymemlir-opt.cpp index b6bb33c51df9..e7af0b01267f 100644 --- a/enzyme/Enzyme/MLIR/enzymemlir-opt.cpp +++ b/enzyme/Enzyme/MLIR/enzymemlir-opt.cpp @@ -97,12 +97,14 @@ int main(int argc, char **argv) { }); // Register the autodiff interface implementations for upstream dialects. + enzyme::registerAffineDialectAutoDiffInterface(registry); enzyme::registerArithDialectAutoDiffInterface(registry); enzyme::registerBuiltinDialectAutoDiffInterface(registry); enzyme::registerLLVMDialectAutoDiffInterface(registry); enzyme::registerNVVMDialectAutoDiffInterface(registry); enzyme::registerMemRefDialectAutoDiffInterface(registry); enzyme::registerSCFDialectAutoDiffInterface(registry); + enzyme::registerCFDialectAutoDiffInterface(registry); enzyme::registerLinalgDialectAutoDiffInterface(registry); return mlir::asMainReturnCode(mlir::MlirOptMain( diff --git a/enzyme/test/MLIR/ForwardMode/affine.mlir b/enzyme/test/MLIR/ForwardMode/affine.mlir new file mode 100644 index 000000000000..a091d6d774ac --- /dev/null +++ b/enzyme/test/MLIR/ForwardMode/affine.mlir @@ -0,0 +1,104 @@ +// RUN: %eopt --enzyme %s | FileCheck %s + +module { + func.func @loop(%x : f64) -> f64 { + %cst = arith.constant 10.000000e+00 : f64 + %r = affine.for %arg1 = 0 to 10 step 1 iter_args(%arg2 = %cst) -> (f64) { + %n = arith.addf %arg2, %x : f64 + affine.yield %n : f64 + } + return %r : f64 + } + func.func @dloop(%x : f64, %dx : f64) -> f64 { + %r = enzyme.fwddiff @loop(%x, %dx) { activity=[#enzyme] } : (f64, f64) -> (f64) + return %r : f64 + } + // CHECK: @fwddiffeloop + // CHECK: (%[[arg0:.+]]: f64, %[[arg1:.+]]: f64) + // CHECK: %[[cst:.+]] = arith.constant 0.000000e+00 : f64 + // CHECK: %[[cst_0:.+]] = arith.constant 1.000000e+01 : f64 + // CHECK: %[[r0:.+]]:2 = affine.for %{{.*}} = 0 to 10 iter_args(%[[arg3:.+]] = %[[cst_0]], %[[arg4:.+]] = %[[cst]]) -> (f64, f64) { + // CHECK: %[[v1:.+]] = arith.addf %[[arg4]], %[[arg1]] : f64 + // CHECK: %[[v2:.+]] = arith.addf %[[arg3]], %[[arg0]] : f64 + // CHECK: affine.yield %[[v2]], %[[v1]] : f64, f64 + // CHECK: } + // CHECK: return %[[r0]]#1 : f64 + + func.func @if_then_else(%x : f64, %c : i1) -> f64 { + %c2 = arith.constant 2.000000e+00 : f64 + %c10 = arith.constant 10.000000e+00 : f64 + %r:2 = scf.if %c -> (f64, f64) { + %mul = arith.mulf %x, %x : f64 + scf.yield %mul, %c2 : f64, f64 + } else { + %add = arith.addf %x, %x : f64 + scf.yield %add, %c10 : f64, f64 + } + %res = arith.mulf %r#0, %r#1 : f64 + return %res : f64 + } + func.func @dif_then_else(%x : f64, %dx : f64, %c : i1) -> f64 { + %r = enzyme.fwddiff @if_then_else(%x, %dx, %c) { activity=[#enzyme, #enzyme] } : (f64, f64, i1) -> (f64) + return %r : f64 + } + // CHECK: @fwddiffeif_then_else + // CHECK: (%[[arg0:.+]]: f64, %[[arg1:.+]]: f64, %[[arg2:.+]]: i1) + // CHECK: %[[cst:.+]] = arith.constant 2.000000e+00 : f64 + // CHECK: %[[cst_0:.+]] = arith.constant 1.000000e+01 : f64 + // CHECK: %[[r0:.+]]:3 = scf.if %[[arg2]] -> (f64, f64, f64) { + // CHECK: %[[v3:.+]] = arith.mulf %[[arg1]], %[[arg0]] : f64 + // CHECK: %[[v4:.+]] = arith.mulf %[[arg1]], %[[arg0]] : f64 + // CHECK: %[[v5:.+]] = arith.addf %[[v3]], %[[v4]] : f64 + // CHECK: %[[v6:.+]] = arith.mulf %[[arg0]], %[[arg0]] : f64 + // CHECK: scf.yield %[[v6]], %[[v5]], %[[cst]] : f64, f64, f64 + // CHECK: } else { + // CHECK: %[[v3:.+]] = arith.addf %[[arg1]], %[[arg1]] : f64 + // CHECK: %[[v4:.+]] = arith.addf %[[arg0]], %[[arg0]] : f64 + // CHECK: scf.yield %[[v4]], %[[v3]], %[[cst_0]] : f64, f64, f64 + // CHECK: } + // CHECK: %[[v1:.+]] = arith.mulf %[[r0]]#1, %[[r0]]#2 : f64 + // CHECK: %[[v2:.+]] = arith.mulf %[[r0]]#0, %[[r0]]#2 : f64 + // CHECK: return %[[v1]] : f64 + + func.func @if_then(%x : f64, %c : i1) -> f64 { + %c0 = arith.constant 0 : index + %c2 = arith.constant 2.000000e+00 : f64 + %c10 = arith.constant 10.000000e+00 : f64 + %mem = memref.alloc() : memref<1xf64> + memref.store %c2, %mem[%c0] : memref<1xf64> + scf.if %c { + %mul = arith.mulf %x, %x : f64 + memref.store %mul, %mem[%c0] : memref<1xf64> + } + %r = memref.load %mem[%c0] : memref<1xf64> + %res = arith.mulf %c2, %r : f64 + return %res : f64 + } + func.func @dif_then(%x : f64, %dx : f64, %c : i1) -> f64 { + %r = enzyme.fwddiff @if_then(%x, %dx, %c) { activity=[#enzyme, #enzyme] } : (f64, f64, i1) -> (f64) + return %r : f64 + } + // CHECK: @fwddiffeif_then + // CHECK: (%[[arg0:.+]]: f64, %[[arg1:.+]]: f64, %[[arg2:.+]]: i1) -> f64 { + // CHECK: %[[c0:.+]] = arith.constant 0 : index + // CHECK: %[[cst:.+]] = arith.constant 0.000000e+00 : f64 + // CHECK: %[[cst_0:.+]] = arith.constant 2.000000e+00 : f64 + // CHECK: %[[cst_1:.+]] = arith.constant 1.000000e+01 : f64 + // CHECK: %[[alloc:.+]] = memref.alloc() : memref<1xf64> + // CHECK: %[[alloc_2:.+]] = memref.alloc() : memref<1xf64> + // CHECK: memref.store %[[cst]], %[[alloc]][%[[c0]]] : memref<1xf64> + // CHECK: memref.store %[[cst_0]], %[[alloc_2]][%[[c0]]] : memref<1xf64> + // CHECK: scf.if %[[arg2]] { + // CHECK: %[[v4:.+]] = arith.mulf %[[arg1]], %[[arg0]] : f64 + // CHECK: %[[v5:.+]] = arith.mulf %[[arg1]], %[[arg0]] : f64 + // CHECK: %[[v6:.+]] = arith.addf %[[v4]], %[[v5]] : f64 + // CHECK: %[[v7:.+]] = arith.mulf %[[arg0]], %[[arg0]] : f64 + // CHECK: memref.store %[[v6]], %[[alloc]][%[[c0]]] : memref<1xf64> + // CHECK: memref.store %[[v7]], %[[alloc_2]][%[[c0]]] : memref<1xf64> + // CHECK: } + // CHECK: %[[v0:.+]] = memref.load %[[alloc]][%[[c0]]] : memref<1xf64> + // CHECK: %[[v1:.+]] = memref.load %[[alloc_2]][%[[c0]]] : memref<1xf64> + // CHECK: %[[v2:.+]] = arith.mulf %[[v0]], %[[cst_0]] : f64 + // CHECK: %[[v3:.+]] = arith.mulf %[[cst_0]], %[[v1]] : f64 + // CHECK: return %[[v2]] : f64 +} diff --git a/enzyme/test/MLIR/branch-self-recursive.mlir b/enzyme/test/MLIR/ForwardMode/branch-self-recursive.mlir similarity index 100% rename from enzyme/test/MLIR/branch-self-recursive.mlir rename to enzyme/test/MLIR/ForwardMode/branch-self-recursive.mlir diff --git a/enzyme/test/MLIR/branch.mlir b/enzyme/test/MLIR/ForwardMode/branch.mlir similarity index 100% rename from enzyme/test/MLIR/branch.mlir rename to enzyme/test/MLIR/ForwardMode/branch.mlir diff --git a/enzyme/test/MLIR/ForwardMode/executeop.mlir b/enzyme/test/MLIR/ForwardMode/executeop.mlir new file mode 100644 index 000000000000..696b07b5f7ea --- /dev/null +++ b/enzyme/test/MLIR/ForwardMode/executeop.mlir @@ -0,0 +1,60 @@ +// RUN: %eopt --enzyme %s | FileCheck %s + +module { + func.func @square(%x : f64, %c : i32) -> f64 { + %c2 = arith.constant 2.000000e+00 : f64 + %c10 = arith.constant 10.000000e+00 : f64 + %r:2 = scf.execute_region -> (f64, f64) { + cf.switch %c : i32, [ + default: ^bb5, + 104: ^bb3, + 113: ^bb4(%c10 : f64) + ] + ^bb4(%z : f64): // pred: ^bb2 + %x2 = arith.mulf %x, %x : f64 + scf.yield %x2, %z : f64, f64 + ^bb3: + %x3 = arith.addf %x, %x : f64 + scf.yield %x3, %c2 : f64, f64 + ^bb5: + cf.br ^bb4(%x : f64) + } + %res = arith.mulf %r#0, %r#1 : f64 + return %res : f64 + } + func.func @dsq(%x : f64, %dx : f64, %c : i32) -> f64 { + %r = enzyme.fwddiff @square(%x, %dx, %c) { activity=[#enzyme, #enzyme] } : (f64, f64, i32) -> (f64) + return %r : f64 + } +} + +// CHECK: func.func private @fwddiffesquare(%[[x:.+]]: f64, %[[dx:.+]]: f64, %[[c:.+]]: i32) -> f64 { +// CHECK-DAG: %[[cst2:.+]] = arith.constant 2.000000e+00 : f64 +// CHECK-DAG: %[[cst10:.+]] = arith.constant 1.000000e+01 : f64 +// CHECK-DAG: %[[cst0:.+]] = arith.constant 0.000000e+00 : f64 +// CHECK-NEXT: %[[r0:.+]]:4 = scf.execute_region -> (f64, f64, f64, f64) { +// CHECK-NEXT: %[[cst02:.+]] = arith.constant 0.000000e+00 : f64 +// CHECK-NEXT: cf.switch %[[c]] : i32, [ +// CHECK-NEXT: default: ^bb3, +// CHECK-NEXT: 104: ^bb2, +// CHECK-NEXT: 113: ^bb1(%[[cst10]], %[[cst02]] : f64, f64) +// CHECK-NEXT: ] +// CHECK-NEXT: ^bb1(%[[a3:.+]]: f64, %[[da3:.+]]: f64): // 2 preds: ^bb0, ^bb3 +// CHECK-NEXT: %[[a4:.+]] = arith.mulf %[[dx]], %[[x]] : f64 +// CHECK-NEXT: %[[a5:.+]] = arith.mulf %[[dx]], %[[x]] : f64 +// CHECK-NEXT: %[[a6:.+]] = arith.addf %[[a4]], %[[a5]] : f64 +// CHECK-NEXT: %[[a7:.+]] = arith.mulf %[[x]], %[[x]] : f64 +// CHECK-NEXT: scf.yield %[[a7]], %[[a6]], %[[a3]], %[[da3]] : f64, f64, f64, f64 +// CHECK-NEXT: ^bb2: // pred: ^bb0 +// CHECK-NEXT: %[[b8:.+]] = arith.addf %[[dx]], %[[dx]] : f64 +// CHECK-NEXT: %[[b9:.+]] = arith.addf %[[x]], %[[x]] : f64 +// CHECK-NEXT: scf.yield %[[b9]], %[[b8]], %[[cst2]], %[[cst0]] : f64, f64, f64, f64 +// CHECK-NEXT: ^bb3: // pred: ^bb0 +// CHECK-NEXT: cf.br ^bb1(%[[x]], %[[dx]] : f64, f64) +// CHECK-NEXT: } +// CHECK-NEXT: %[[r1:.+]] = arith.mulf %[[r0]]#1, %[[r0]]#2 : f64 +// CHECK-NEXT: %[[r2:.+]] = arith.mulf %[[r0]]#3, %[[r0]]#0 : f64 +// CHECK-NEXT: %[[r3:.+]] = arith.addf %[[r1]], %[[r2]] : f64 +// CHECK-NEXT: %[[r4:.+]] = arith.mulf %[[r0]]#0, %[[r0]]#2 : f64 +// CHECK-NEXT: return %[[r3]] : f64 +// CHECK-NEXT: } diff --git a/enzyme/test/MLIR/for.mlir b/enzyme/test/MLIR/ForwardMode/for.mlir similarity index 100% rename from enzyme/test/MLIR/for.mlir rename to enzyme/test/MLIR/ForwardMode/for.mlir diff --git a/enzyme/test/MLIR/ForwardMode/for2.mlir b/enzyme/test/MLIR/ForwardMode/for2.mlir new file mode 100644 index 000000000000..7d4e7608b98e --- /dev/null +++ b/enzyme/test/MLIR/ForwardMode/for2.mlir @@ -0,0 +1,30 @@ +// RUN: %eopt --enzyme %s | FileCheck %s + +module { + func.func @square(%x : f64) -> f64 { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c10 = arith.constant 10 : index + %r = scf.for %arg1 = %c0 to %c10 step %c1 iter_args(%arg2 = %x) -> (f64) { + %n = arith.addf %arg2, %x : f64 + scf.yield %n : f64 + } + return %r : f64 + } + func.func @dsq(%x : f64, %dx : f64) -> f64 { + %r = enzyme.fwddiff @square(%x, %dx) { activity=[#enzyme] } : (f64, f64) -> (f64) + return %r : f64 + } +} + +// CHECK: func.func private @fwddiffesquare(%[[arg0:.+]]: f64, %[[arg1:.+]]: f64) -> f64 { +// CHECK-DAG: %[[c0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[c1:.+]] = arith.constant 1 : index +// CHECK-DAG: %[[c10:.+]] = arith.constant 10 : index +// CHECK-NEXT: %[[i0:.+]]:2 = scf.for %[[arg2:.+]] = %[[c0]] to %[[c10]] step %[[c1]] iter_args(%[[arg3:.+]] = %[[arg0]], %[[arg4:.+]] = %[[arg1]]) -> (f64, f64) { +// CHECK-NEXT: %[[i1:.+]] = arith.addf %[[arg4]], %[[arg1]] : f64 +// CHECK-NEXT: %[[i2:.+]] = arith.addf %[[arg3]], %[[arg0]] : f64 +// CHECK-NEXT: scf.yield %[[i2]], %[[i1]] : f64, f64 +// CHECK-NEXT: } +// CHECK-NEXT: return %[[i0]]#1 : f64 +// CHECK-NEXT: } diff --git a/enzyme/test/MLIR/ForwardMode/if1.mlir b/enzyme/test/MLIR/ForwardMode/if1.mlir new file mode 100644 index 000000000000..3187bac56fa7 --- /dev/null +++ b/enzyme/test/MLIR/ForwardMode/if1.mlir @@ -0,0 +1,41 @@ +// RUN: %eopt --enzyme %s | FileCheck %s + +module { + func.func @square(%x : f64, %c : i1) -> f64 { + %c2 = arith.constant 2.000000e+00 : f64 + %c10 = arith.constant 10.000000e+00 : f64 + %r:2 = scf.if %c -> (f64, f64) { + %mul = arith.mulf %x, %x : f64 + scf.yield %mul, %c2 : f64, f64 + } else { + %add = arith.addf %x, %x : f64 + scf.yield %add, %c10 : f64, f64 + } + %res = arith.mulf %r#0, %r#1 : f64 + return %res : f64 + } + func.func @dsq(%x : f64, %dx : f64, %c : i1) -> f64 { + %r = enzyme.fwddiff @square(%x, %dx, %c) { activity=[#enzyme, #enzyme] } : (f64, f64, i1) -> (f64) + return %r : f64 + } +} + + +// CHECK: func.func private @fwddiffesquare(%[[arg0:.+]]: f64, %[[arg1:.+]]: f64, %[[arg2:.+]]: i1) -> f64 { +// CHECK-DAG: %[[cst2:.+]] = arith.constant 2.000000e+00 : f64 +// CHECK-DAG: %[[cst10:.+]] = arith.constant 1.000000e+01 : f64 +// CHECK-NEXT: %[[r0:.+]]:3 = scf.if %[[arg2]] -> (f64, f64, f64) { +// CHECK-NEXT: %[[t3:.+]] = arith.mulf %[[arg1]], %[[arg0]] : f64 +// CHECK-NEXT: %[[t4:.+]] = arith.mulf %[[arg1]], %[[arg0]] : f64 +// CHECK-NEXT: %[[t5:.+]] = arith.addf %[[t3]], %[[t4]] : f64 +// CHECK-NEXT: %[[t6:.+]] = arith.mulf %[[arg0]], %[[arg0]] : f64 +// CHECK-NEXT: scf.yield %[[t6]], %[[t5]], %[[cst2]] : f64, f64, f64 +// CHECK-NEXT: } else { +// CHECK-NEXT: %[[e3:.+]] = arith.addf %arg1, %arg1 : f64 +// CHECK-NEXT: %[[e4:.+]] = arith.addf %arg0, %arg0 : f64 +// CHECK-NEXT: scf.yield %[[e4]], %[[e3]], %[[cst10]] : f64, f64, f64 +// CHECK-NEXT: } +// CHECK-NEXT: %[[r1:.+]] = arith.mulf %[[r0]]#1, %[[r0]]#2 : f64 +// CHECK-NEXT: %[[r2:.+]] = arith.mulf %[[r0]]#0, %[[r0]]#2 : f64 +// CHECK-NEXT: return %[[r1]] : f64 +// CHECK-NEXT: } diff --git a/enzyme/test/MLIR/inactive.mlir b/enzyme/test/MLIR/ForwardMode/inactive.mlir similarity index 100% rename from enzyme/test/MLIR/inactive.mlir rename to enzyme/test/MLIR/ForwardMode/inactive.mlir diff --git a/enzyme/test/MLIR/invalid.mlir b/enzyme/test/MLIR/ForwardMode/invalid.mlir similarity index 100% rename from enzyme/test/MLIR/invalid.mlir rename to enzyme/test/MLIR/ForwardMode/invalid.mlir diff --git a/enzyme/test/MLIR/llvm.mlir b/enzyme/test/MLIR/ForwardMode/llvm.mlir similarity index 100% rename from enzyme/test/MLIR/llvm.mlir rename to enzyme/test/MLIR/ForwardMode/llvm.mlir diff --git a/enzyme/test/MLIR/memref.mlir b/enzyme/test/MLIR/ForwardMode/memref.mlir similarity index 100% rename from enzyme/test/MLIR/memref.mlir rename to enzyme/test/MLIR/ForwardMode/memref.mlir diff --git a/enzyme/test/MLIR/test.mlir b/enzyme/test/MLIR/ForwardMode/test.mlir similarity index 100% rename from enzyme/test/MLIR/test.mlir rename to enzyme/test/MLIR/ForwardMode/test.mlir diff --git a/enzyme/test/MLIR/ForwardMode/while.mlir b/enzyme/test/MLIR/ForwardMode/while.mlir new file mode 100644 index 000000000000..cbe7da34769b --- /dev/null +++ b/enzyme/test/MLIR/ForwardMode/while.mlir @@ -0,0 +1,44 @@ +// RUN: %eopt --enzyme %s | FileCheck %s + +module { + func.func @while(%x : f64) -> f64 { + %cst = arith.constant 10.000000e+00 : f64 + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c10 = arith.constant 10 : index + + %r:2 = scf.while (%arg1 = %c0, %arg2 = %cst) : (index, f64) -> (index, f64) { + %1 = arith.cmpi slt, %arg1, %c10 : index + scf.condition(%1) %arg1, %arg2 : index, f64 + } do { + ^bb0(%arg1: index, %arg2: f64): + %1 = arith.addi %arg1, %c1 : index + %2 = arith.addf %arg2, %x : f64 + scf.yield %1, %2 : index, f64 + } + return %r#1 : f64 + } + func.func @dwhile(%x : f64, %dx : f64) -> f64 { + %r = enzyme.fwddiff @while(%x, %dx) { activity=[#enzyme] } : (f64, f64) -> (f64) + return %r : f64 + } + // CHECK: @fwddiffewhile + // CHECK: (%[[arg0:.+]]: f64, %[[arg1:.+]]: f64) -> f64 { + // CHECK: %[[cst:.+]] = arith.constant 0.000000e+00 : f64 + // CHECK: %[[cst_0:.+]] = arith.constant 1.000000e+01 : f64 + // CHECK: %[[c0:.+]] = arith.constant 0 : index + // CHECK: %[[c1:.+]] = arith.constant 1 : index + // CHECK: %[[c10:.+]] = arith.constant 10 : index + // CHECK: %[[r0:.+]]:3 = scf.while (%[[arg2:.+]] = %[[c0]], %[[arg3:.+]] = %[[cst_0]], %[[arg4:.+]] = %[[cst]]) : (index, f64, f64) -> (index, f64, f64) { + // CHECK: %[[v1:.+]] = arith.cmpi slt, %[[arg2]], %[[c10]] : index + // CHECK: scf.condition(%[[v1]]) %[[arg2]], %[[arg3]], %[[arg4]] : index, f64, f64 + // CHECK: } do { + // CHECK: ^bb0(%[[arg2:.+]]: index, %[[arg3:.+]]: f64, %[[arg4:.+]]: f64): + // CHECK: %[[v1:.+]] = arith.addi %[[arg2]], %[[c1]] : index + // CHECK: %[[v2:.+]] = arith.addf %[[arg4]], %[[arg1]] : f64 + // CHECK: %[[v3:.+]] = arith.addf %[[arg3]], %[[arg0]] : f64 + // CHECK: scf.yield %[[v1]], %[[v3]], %[[v2]] : index, f64, f64 + // CHECK: } + // CHECK: return %[[r0]]#2 : f64 + // CHECK: } +} diff --git a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp index 1f4a3dc1b892..58802117bf5b 100644 --- a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp +++ b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp @@ -1888,6 +1888,25 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os, "return true; }\n"; os << "};\n"; } + const auto &cfpatterns = + recordKeeper.getAllDerivedDefinitions("ControlFlowOp"); + for (auto &pattern : cfpatterns) { + auto opName = pattern->getValueAsString("opName"); + auto dialect = pattern->getValueAsString("dialect"); + auto impl = pattern->getValueAsString("impl"); + os << "struct " << opName << "CF : \n"; + os << " public " + "ControlFlowAutoDiffOpInterface::ExternalModel<" + << opName << "CF, " << dialect << "::" << opName << "> {\n"; + os << impl << "\n"; + os << "};\n"; + } + + const auto &brpatterns = recordKeeper.getAllDerivedDefinitions("BranchOp"); + + const auto ®tpatterns = + recordKeeper.getAllDerivedDefinitions("RegionTerminatorOp"); + os << "void registerInterfaces(MLIRContext* context) {\n"; for (Record *pattern : patterns) { auto opName = pattern->getValueAsString("opName"); @@ -1903,6 +1922,26 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os, os << " " << dialect << "::" << opName << "::attachInterface<" << opName << "Activity>(*context);\n"; } + for (Record *pattern : cfpatterns) { + auto opName = pattern->getValueAsString("opName"); + auto dialect = pattern->getValueAsString("dialect"); + os << " " << dialect << "::" << opName << "::attachInterface<" << opName + << "CF>(*context);\n"; + os << " registerAutoDiffUsingControlFlowInterface<" << dialect + << "::" << opName << ">(*context);\n"; + } + for (Record *pattern : brpatterns) { + auto opName = pattern->getValueAsString("opName"); + auto dialect = pattern->getValueAsString("dialect"); + os << " registerAutoDiffUsingBranchInterface<" << dialect + << "::" << opName << ">(*context);\n"; + } + for (Record *pattern : regtpatterns) { + auto opName = pattern->getValueAsString("opName"); + auto dialect = pattern->getValueAsString("dialect"); + os << " registerAutoDiffUsingRegionTerminatorInterface<" << dialect + << "::" << opName << ">(*context);\n"; + } os << "}\n"; } } From c8b777ef79a4151685bc0bbd148b149c49ceacfc Mon Sep 17 00:00:00 2001 From: William Moses Date: Thu, 1 Feb 2024 00:26:45 -0500 Subject: [PATCH 004/106] [MLIR] Add read interface fwd (#1652) --- enzyme/BUILD | 30 ++++++++++++++++ .../Enzyme/MLIR/Analysis/ActivityAnalysis.cpp | 24 ++++++++++--- .../MLIR/Implementations/AffineDerivatives.td | 2 ++ .../MLIR/Implementations/CMakeLists.txt | 5 +++ enzyme/Enzyme/MLIR/Implementations/Common.td | 6 ++++ .../CoreDialectsAutoDiffImplementations.cpp | 31 +++++++++++++++++ .../CoreDialectsAutoDiffImplementations.h | 24 +++++++++++++ .../LLVMAutoDiffOpInterfaceImpl.cpp | 23 +------------ .../MLIR/Implementations/LLVMDerivatives.td | 8 +++++ .../MemRefAutoDiffOpInterfaceImpl.cpp | 22 ++---------- .../MLIR/Implementations/MemRefDerivatives.td | 13 +++++++ .../MLIR/Interfaces/AutoDiffOpInterface.td | 2 +- enzyme/test/MLIR/ForwardMode/affine.mlir | 6 ++-- enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp | 34 ++++++++++++++++++- 14 files changed, 179 insertions(+), 51 deletions(-) create mode 100644 enzyme/Enzyme/MLIR/Implementations/MemRefDerivatives.td diff --git a/enzyme/BUILD b/enzyme/BUILD index f091b06bd8f5..7645c3a7ecf9 100644 --- a/enzyme/BUILD +++ b/enzyme/BUILD @@ -456,6 +456,34 @@ gentbl( ], ) +gentbl( + name = "cf-derivatives", + tbl_outs = [( + "-gen-mlir-derivatives", + "Enzyme/MLIR/Implementations/CFDerivatives.inc", + )], + tblgen = ":enzyme-tblgen", + td_file = "Enzyme/MLIR/Implementations/CFDerivatives.td", + td_srcs = ["Enzyme/MLIR/Implementations/CFDerivatives.td", "Enzyme/MLIR/Implementations/Common.td"], + deps = [ + ":enzyme-tblgen", + ], +) + +gentbl( + name = "memref-derivatives", + tbl_outs = [( + "-gen-mlir-derivatives", + "Enzyme/MLIR/Implementations/MemRefDerivatives.inc", + )], + tblgen = ":enzyme-tblgen", + td_file = "Enzyme/MLIR/Implementations/MemRefDerivatives.td", + td_srcs = ["Enzyme/MLIR/Implementations/MemRefDerivatives.td", "Enzyme/MLIR/Implementations/Common.td"], + deps = [ + ":enzyme-tblgen", + ], +) + cc_library( name = "EnzymeMLIR", srcs = glob([ @@ -482,6 +510,8 @@ cc_library( ":llvm-derivatives", ":nvvm-derivatives", ":scf-derivatives", + ":cf-derivatives", + ":memref-derivatives", ":EnzymeOpsIncGen", ":EnzymePassesIncGen", ":EnzymeTypesIncGen", diff --git a/enzyme/Enzyme/MLIR/Analysis/ActivityAnalysis.cpp b/enzyme/Enzyme/MLIR/Analysis/ActivityAnalysis.cpp index bb8c8f2ccc9c..5f3cd22db72d 100644 --- a/enzyme/Enzyme/MLIR/Analysis/ActivityAnalysis.cpp +++ b/enzyme/Enzyme/MLIR/Analysis/ActivityAnalysis.cpp @@ -2820,8 +2820,16 @@ bool mlir::enzyme::ActivityAnalyzer::isValueInactiveFromUsers( if (UA != UseActivity::AllStores) { if (auto ifaceOp = dyn_cast(a)) { - if (ifaceOp.isArgInactive(parent)) - return true; + bool allInactive = true; + for (OpOperand &operand : a->getOpOperands()) { + if (parent == operand.get() && + !ifaceOp.isArgInactive(operand.getOperandNumber())) { + allInactive = false; + break; + } + } + if (allInactive) + continue; } } @@ -3394,8 +3402,16 @@ bool mlir::enzyme::ActivityAnalyzer::isValueActivelyStoredOrReturned( } if (auto ifaceOp = dyn_cast(a)) { - if (ifaceOp.isArgInactive(val)) - return true; + bool allInactive = true; + for (OpOperand &operand : a->getOpOperands()) { + if (operand.get() == val && + !ifaceOp.isArgInactive(operand.getOperandNumber())) { + allInactive = false; + break; + } + } + if (allInactive) + continue; } if (isa(a)) { diff --git a/enzyme/Enzyme/MLIR/Implementations/AffineDerivatives.td b/enzyme/Enzyme/MLIR/Implementations/AffineDerivatives.td index d6e866f3089e..aaad67ff367a 100644 --- a/enzyme/Enzyme/MLIR/Implementations/AffineDerivatives.td +++ b/enzyme/Enzyme/MLIR/Implementations/AffineDerivatives.td @@ -22,3 +22,5 @@ def : ControlFlowOp<"affine", "AffineIfOp", [{ }]>; def : RegionTerminatorOp<"affine", "AffineYieldOp">; +def : ReadOnlyIdentityOp<"affine", "AffineLoadOp", [0]>; +def : ReadOnlyIdentityOp<"affine", "AffineVectorLoadOp", [0]>; \ No newline at end of file diff --git a/enzyme/Enzyme/MLIR/Implementations/CMakeLists.txt b/enzyme/Enzyme/MLIR/Implementations/CMakeLists.txt index 8e62f5849be1..464740d48d4b 100644 --- a/enzyme/Enzyme/MLIR/Implementations/CMakeLists.txt +++ b/enzyme/Enzyme/MLIR/Implementations/CMakeLists.txt @@ -22,6 +22,10 @@ set(LLVM_TARGET_DEFINITIONS CFDerivatives.td) enzyme_tablegen(CFDerivatives.inc -gen-mlir-derivatives) add_public_tablegen_target(CFDerivativesIncGen) +set(LLVM_TARGET_DEFINITIONS MemRefDerivatives.td) +enzyme_tablegen(MemRefDerivatives.inc -gen-mlir-derivatives) +add_public_tablegen_target(MemRefDerivativesIncGen) + add_mlir_library(MLIREnzymeImplementations AffineAutoDiffOpInterfaceImpl.cpp ArithAutoDiffOpInterfaceImpl.cpp @@ -42,6 +46,7 @@ add_mlir_library(MLIREnzymeImplementations NVVMDerivativesIncGen SCFDerivativesIncGen CFDerivativesIncGen + MemRefDerivativesIncGen LINK_LIBS PUBLIC MLIRArithDialect diff --git a/enzyme/Enzyme/MLIR/Implementations/Common.td b/enzyme/Enzyme/MLIR/Implementations/Common.td index c09cac32f734..4eff61d7380f 100644 --- a/enzyme/Enzyme/MLIR/Implementations/Common.td +++ b/enzyme/Enzyme/MLIR/Implementations/Common.td @@ -9,6 +9,12 @@ class ControlFlowOp { string impl = impl_; } +class ReadOnlyIdentityOp diffargs_> { + string dialect = dialect_; + string opName = opName_; + list diffargs = diffargs_; +} + class BranchOp { string dialect = dialect_; string opName = opName_; diff --git a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp index 2a50d1c481c0..0ee33c44f01c 100644 --- a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp @@ -101,6 +101,37 @@ void mlir::enzyme::detail::branchingForwardHandler(Operation *inst, return; } +LogicalResult mlir::enzyme::detail::readOnlyIdentityForwardHandler( + Operation *orig, OpBuilder &builder, MGradientUtils *gutils) { + + auto iface = cast(orig); + + SmallVector newOperands; + newOperands.reserve(orig->getNumOperands()); + for (OpOperand &operand : orig->getOpOperands()) { + if (iface.isArgInactive(operand.getOperandNumber())) { + newOperands.push_back(gutils->getNewFromOriginal(operand.get())); + } else { + if (gutils->isConstantValue(operand.get())) + return failure(); + newOperands.push_back(gutils->invertPointerM(operand.get(), builder)); + } + } + + // Assuming shadows following the originals are fine. + // TODO: consider extending to have a ShadowableTerminatorOpInterface + Operation *primal = gutils->getNewFromOriginal(orig); + Operation *shadow = builder.clone(*primal); + shadow->setOperands(newOperands); + for (auto &&[oval, sval] : + llvm::zip(orig->getResults(), shadow->getResults())) { + gutils->setDiffe(oval, sval, builder); + } + llvm::errs() << " shadow load: " << *shadow << "\n"; + + return success(); +} + void mlir::enzyme::detail::regionTerminatorForwardHandler( Operation *origTerminator, OpBuilder &builder, MGradientUtils *gutils) { auto termIface = cast(origTerminator); diff --git a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h index b98a6ef10c39..4c54baa24ed2 100644 --- a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h +++ b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h @@ -39,6 +39,11 @@ void branchingForwardHandler(Operation *op, OpBuilder &builder, void regionTerminatorForwardHandler(Operation *op, OpBuilder &builder, MGradientUtils *gutils); +// Implements forward-mode differentiation of read-only (including read-none) +// operations which do not perform computatoin +LogicalResult readOnlyIdentityForwardHandler(Operation *op, OpBuilder &builder, + MGradientUtils *gutils); + // Implements the forward autodiff interface for operations whose derivatives // are can be inferred by analyzing their control flow and differentiating the // nested operations. @@ -80,6 +85,19 @@ class AutoDiffUsingRegionTerminator return success(); } }; + +// Implements the forward autodiff interface for operations which are +// read only and identity like (aka not computing sin of mem read). +template +class AutoDiffUsingReadOnlyIdentity + : public AutoDiffOpInterface::ExternalModel< + AutoDiffUsingReadOnlyIdentity, OpTy> { +public: + LogicalResult createForwardModeTangent(Operation *op, OpBuilder &builder, + MGradientUtils *gutils) const { + return readOnlyIdentityForwardHandler(op, builder, gutils); + } +}; } // namespace detail // Registers AutoDiffUsingControlFlow for the given op. @@ -99,6 +117,12 @@ void registerAutoDiffUsingRegionTerminatorInterface(MLIRContext &context) { OpTy::template attachInterface>( context); } +// Registers AutoDiffUsingRegionTerminator for the given op. +template +void registerAutoDiffUsingReadOnlyIdentityInterface(MLIRContext &context) { + OpTy::template attachInterface>( + context); +} // Interface registration hooks for individual upstream dialects. void registerAffineDialectAutoDiffInterface(DialectRegistry ®istry); diff --git a/enzyme/Enzyme/MLIR/Implementations/LLVMAutoDiffOpInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/LLVMAutoDiffOpInterfaceImpl.cpp index 079dd1cb64e9..d7884fcc305e 100644 --- a/enzyme/Enzyme/MLIR/Implementations/LLVMAutoDiffOpInterfaceImpl.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/LLVMAutoDiffOpInterfaceImpl.cpp @@ -35,27 +35,7 @@ struct InlineAsmActivityInterface auto str = asmOp.getAsmString(); return str.contains("cpuid") || str.contains("exit"); } - bool isArgInactive(Operation *op, mlir::Value) const { - return isInactive(op); - } -}; - -struct LoadOpInterface - : public AutoDiffOpInterface::ExternalModel { - LogicalResult createForwardModeTangent(Operation *op, OpBuilder &builder, - MGradientUtils *gutils) const { - auto loadOp = cast(op); - if (!gutils->isConstantValue(loadOp)) { - Type shadowType = - cast(loadOp.getType()).getShadowType(); - mlir::Value res = builder.create( - loadOp.getLoc(), shadowType, - gutils->invertPointerM(loadOp.getAddr(), builder)); - gutils->setDiffe(loadOp, res, builder); - } - gutils->eraseIfUnused(op); - return success(); - } + bool isArgInactive(Operation *op, size_t) const { return isInactive(op); } }; struct StoreOpInterface @@ -115,7 +95,6 @@ class PointerTypeInterface void mlir::enzyme::registerLLVMDialectAutoDiffInterface( DialectRegistry ®istry) { registry.addExtension(+[](MLIRContext *context, LLVM::LLVMDialect *) { - LLVM::LoadOp::attachInterface(*context); LLVM::StoreOp::attachInterface(*context); LLVM::AllocaOp::attachInterface(*context); LLVM::LLVMPointerType::attachInterface(*context); diff --git a/enzyme/Enzyme/MLIR/Implementations/LLVMDerivatives.td b/enzyme/Enzyme/MLIR/Implementations/LLVMDerivatives.td index 9e5f28e41665..c9b8fe76ca56 100644 --- a/enzyme/Enzyme/MLIR/Implementations/LLVMDerivatives.td +++ b/enzyme/Enzyme/MLIR/Implementations/LLVMDerivatives.td @@ -15,3 +15,11 @@ def : InactiveOp<"LLVM", "MemsetOp">; def : InactiveOp<"LLVM", "UndefOp">; def : InactiveOp<"LLVM", "ConstantOp">; def : InactiveOp<"LLVM", "UnreachableOp">; + + +def : ReadOnlyIdentityOp<"LLVM", "LoadOp", [0]>; +def : ReadOnlyIdentityOp<"LLVM", "AddrSpaceCastOp", [0]>; +def : ReadOnlyIdentityOp<"LLVM", "BitcastOp", [0]>; +def : ReadOnlyIdentityOp<"LLVM", "GEPOp", [0]>; +def : ReadOnlyIdentityOp<"LLVM", "PtrToIntOp", [0]>; +def : ReadOnlyIdentityOp<"LLVM", "IntToPtrOp", [0]>; \ No newline at end of file diff --git a/enzyme/Enzyme/MLIR/Implementations/MemRefAutoDiffOpInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/MemRefAutoDiffOpInterfaceImpl.cpp index 010f9d997005..04f8c3a60e33 100644 --- a/enzyme/Enzyme/MLIR/Implementations/MemRefAutoDiffOpInterfaceImpl.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/MemRefAutoDiffOpInterfaceImpl.cpp @@ -28,25 +28,7 @@ using namespace mlir; using namespace mlir::enzyme; namespace { -struct LoadOpInterface - : public AutoDiffOpInterface::ExternalModel { - LogicalResult createForwardModeTangent(Operation *op, OpBuilder &builder, - MGradientUtils *gutils) const { - auto loadOp = cast(op); - if (!gutils->isConstantValue(loadOp)) { - SmallVector inds; - for (auto ind : loadOp.getIndices()) - inds.push_back(gutils->getNewFromOriginal(ind)); - mlir::Value res = builder.create( - loadOp.getLoc(), gutils->invertPointerM(loadOp.getMemref(), builder), - inds); - gutils->setDiffe(loadOp, res, builder); - } - gutils->eraseIfUnused(op); - return success(); - } -}; +#include "Implementations/MemRefDerivatives.inc" struct StoreOpInterface : public AutoDiffOpInterface::ExternalModel(*context); + registerInterfaces(context); memref::StoreOp::attachInterface(*context); memref::AllocOp::attachInterface(*context); MemRefType::attachInterface(*context); diff --git a/enzyme/Enzyme/MLIR/Implementations/MemRefDerivatives.td b/enzyme/Enzyme/MLIR/Implementations/MemRefDerivatives.td new file mode 100644 index 000000000000..7bf269199ca4 --- /dev/null +++ b/enzyme/Enzyme/MLIR/Implementations/MemRefDerivatives.td @@ -0,0 +1,13 @@ +include "Common.td" + +def : ReadOnlyIdentityOp<"memref", "LoadOp", [0]>; +def : ReadOnlyIdentityOp<"memref", "CastOp", [0]>; +def : ReadOnlyIdentityOp<"memref", "CollapseShapeOp", [0]>; +def : ReadOnlyIdentityOp<"memref", "ExpandShapeOp", [0]>; +def : ReadOnlyIdentityOp<"memref", "ReinterpretCastOp", [0]>; +def : ReadOnlyIdentityOp<"memref", "ReshapeOp", [0]>; +def : ReadOnlyIdentityOp<"memref", "TransposeOp", [0]>; +def : ReadOnlyIdentityOp<"memref", "ViewOp", [0]>; +def : ReadOnlyIdentityOp<"memref", "SubViewOp", [0]>; + +def : InactiveOp<"memref", "DimOp">; \ No newline at end of file diff --git a/enzyme/Enzyme/MLIR/Interfaces/AutoDiffOpInterface.td b/enzyme/Enzyme/MLIR/Interfaces/AutoDiffOpInterface.td index 5d764c8dc7bb..9123dfcaa22e 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/AutoDiffOpInterface.td +++ b/enzyme/Enzyme/MLIR/Interfaces/AutoDiffOpInterface.td @@ -128,7 +128,7 @@ def ActivityOpInterface }], /*retTy=*/"bool", /*methodName=*/"isArgInactive", - /*args=*/(ins "::mlir::Value":$val) + /*args=*/(ins "size_t":$opidx) > ]; } diff --git a/enzyme/test/MLIR/ForwardMode/affine.mlir b/enzyme/test/MLIR/ForwardMode/affine.mlir index a091d6d774ac..2dcb61441251 100644 --- a/enzyme/test/MLIR/ForwardMode/affine.mlir +++ b/enzyme/test/MLIR/ForwardMode/affine.mlir @@ -70,7 +70,7 @@ module { %mul = arith.mulf %x, %x : f64 memref.store %mul, %mem[%c0] : memref<1xf64> } - %r = memref.load %mem[%c0] : memref<1xf64> + %r = affine.load %mem[0] : memref<1xf64> %res = arith.mulf %c2, %r : f64 return %res : f64 } @@ -96,8 +96,8 @@ module { // CHECK: memref.store %[[v6]], %[[alloc]][%[[c0]]] : memref<1xf64> // CHECK: memref.store %[[v7]], %[[alloc_2]][%[[c0]]] : memref<1xf64> // CHECK: } - // CHECK: %[[v0:.+]] = memref.load %[[alloc]][%[[c0]]] : memref<1xf64> - // CHECK: %[[v1:.+]] = memref.load %[[alloc_2]][%[[c0]]] : memref<1xf64> + // CHECK: %[[v0:.+]] = affine.load %[[alloc]][0] : memref<1xf64> + // CHECK: %[[v1:.+]] = affine.load %[[alloc_2]][0] : memref<1xf64> // CHECK: %[[v2:.+]] = arith.mulf %[[v0]], %[[cst_0]] : f64 // CHECK: %[[v3:.+]] = arith.mulf %[[cst_0]], %[[v1]] : f64 // CHECK: return %[[v2]] : f64 diff --git a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp index 58802117bf5b..f3f8f5940d27 100644 --- a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp +++ b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp @@ -1884,12 +1884,16 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os, os << " public ActivityOpInterface::ExternalModel<" << opName << "Activity, " << dialect << "::" << opName << "> {\n"; os << " bool isInactive(mlir::Operation*) const { return true; }\n"; - os << " bool isArgInactive(mlir::Operation*, mlir::Value) const { " + os << " bool isArgInactive(mlir::Operation*, size_t) const { " "return true; }\n"; os << "};\n"; } const auto &cfpatterns = recordKeeper.getAllDerivedDefinitions("ControlFlowOp"); + + const auto &ropatterns = + recordKeeper.getAllDerivedDefinitions("ReadOnlyIdentityOp"); + for (auto &pattern : cfpatterns) { auto opName = pattern->getValueAsString("opName"); auto dialect = pattern->getValueAsString("dialect"); @@ -1902,6 +1906,26 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os, os << "};\n"; } + for (auto &pattern : ropatterns) { + auto opName = pattern->getValueAsString("opName"); + auto dialect = pattern->getValueAsString("dialect"); + auto diffargs = pattern->getValueAsListOfInts("diffargs"); + os << "struct " << opName << "ROActivity : \n"; + os << " public ActivityOpInterface::ExternalModel<" << opName + << "ROActivity, " << dialect << "::" << opName << "> {\n"; + os << " bool isInactive(mlir::Operation* op) const {\n"; + os << " for (size_t i=0, len=op->getNumOperands(); i(*context);\n"; } + for (Record *pattern : ropatterns) { + auto opName = pattern->getValueAsString("opName"); + auto dialect = pattern->getValueAsString("dialect"); + os << " " << dialect << "::" << opName << "::attachInterface<" << opName + << "ROActivity>(*context);\n"; + os << " registerAutoDiffUsingReadOnlyIdentityInterface<" << dialect + << "::" << opName << ">(*context);\n"; + } for (Record *pattern : brpatterns) { auto opName = pattern->getValueAsString("opName"); auto dialect = pattern->getValueAsString("dialect"); From 9ff56bc24fa1d483ac1abb08bfaaddfa6272d366 Mon Sep 17 00:00:00 2001 From: William Moses Date: Fri, 2 Feb 2024 23:26:01 -0500 Subject: [PATCH 005/106] [MLIR] General store / allocation / tensor / math interfaces (#1657) * Generalize mem to include stores * alloca * Add tensor/math --- enzyme/BUILD | 15 ++++ .../MLIR/Implementations/AffineDerivatives.td | 4 +- .../MLIR/Implementations/ArithDerivatives.td | 12 --- .../BuiltinAutoDiffTypeInterfaceImpl.cpp | 37 ++++++++ .../MLIR/Implementations/CMakeLists.txt | 6 ++ enzyme/Enzyme/MLIR/Implementations/Common.td | 30 ++++++- .../CoreDialectsAutoDiffImplementations.cpp | 59 +++++++++++- .../CoreDialectsAutoDiffImplementations.h | 73 ++++++++++++--- .../LLVMAutoDiffOpInterfaceImpl.cpp | 41 ++------- .../MLIR/Implementations/LLVMDerivatives.td | 5 +- .../MathAutoDiffOpInterfaceImpl.cpp | 38 ++++++++ .../MLIR/Implementations/MathDerivatives.td | 17 ++++ .../MemRefAutoDiffOpInterfaceImpl.cpp | 90 ++++--------------- .../MLIR/Implementations/MemRefDerivatives.td | 5 +- .../MLIR/Interfaces/AutoDiffTypeInterface.h | 1 + .../MLIR/Interfaces/AutoDiffTypeInterface.td | 8 ++ .../Enzyme/MLIR/Interfaces/CloneFunction.cpp | 7 +- enzyme/Enzyme/MLIR/enzymemlir-opt.cpp | 1 + enzyme/test/MLIR/ForwardMode/affine.mlir | 24 +++-- enzyme/test/MLIR/ForwardMode/tensorsin.mlir | 19 ++++ enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp | 43 ++++++--- 21 files changed, 370 insertions(+), 165 deletions(-) create mode 100644 enzyme/Enzyme/MLIR/Implementations/MathAutoDiffOpInterfaceImpl.cpp create mode 100644 enzyme/Enzyme/MLIR/Implementations/MathDerivatives.td create mode 100644 enzyme/test/MLIR/ForwardMode/tensorsin.mlir diff --git a/enzyme/BUILD b/enzyme/BUILD index 7645c3a7ecf9..3b4d03c001e0 100644 --- a/enzyme/BUILD +++ b/enzyme/BUILD @@ -484,6 +484,20 @@ gentbl( ], ) +gentbl( + name = "math-derivatives", + tbl_outs = [( + "-gen-mlir-derivatives", + "Enzyme/MLIR/Implementations/MathDerivatives.inc", + )], + tblgen = ":enzyme-tblgen", + td_file = "Enzyme/MLIR/Implementations/MathDerivatives.td", + td_srcs = ["Enzyme/MLIR/Implementations/MathDerivatives.td", "Enzyme/MLIR/Implementations/Common.td"], + deps = [ + ":enzyme-tblgen", + ], +) + cc_library( name = "EnzymeMLIR", srcs = glob([ @@ -512,6 +526,7 @@ cc_library( ":scf-derivatives", ":cf-derivatives", ":memref-derivatives", + ":math-derivatives", ":EnzymeOpsIncGen", ":EnzymePassesIncGen", ":EnzymeTypesIncGen", diff --git a/enzyme/Enzyme/MLIR/Implementations/AffineDerivatives.td b/enzyme/Enzyme/MLIR/Implementations/AffineDerivatives.td index aaad67ff367a..9f22d00cecdb 100644 --- a/enzyme/Enzyme/MLIR/Implementations/AffineDerivatives.td +++ b/enzyme/Enzyme/MLIR/Implementations/AffineDerivatives.td @@ -23,4 +23,6 @@ def : ControlFlowOp<"affine", "AffineIfOp", [{ def : RegionTerminatorOp<"affine", "AffineYieldOp">; def : ReadOnlyIdentityOp<"affine", "AffineLoadOp", [0]>; -def : ReadOnlyIdentityOp<"affine", "AffineVectorLoadOp", [0]>; \ No newline at end of file +def : ReadOnlyIdentityOp<"affine", "AffineVectorLoadOp", [0]>; +def : MemoryIdentityOp<"affine", "AffineStoreOp", [1], [0]>; +def : MemoryIdentityOp<"affine", "AffineVectorStoreOp", [1], [0]>; diff --git a/enzyme/Enzyme/MLIR/Implementations/ArithDerivatives.td b/enzyme/Enzyme/MLIR/Implementations/ArithDerivatives.td index fb7f113f16fe..3ed038fa4cb6 100644 --- a/enzyme/Enzyme/MLIR/Implementations/ArithDerivatives.td +++ b/enzyme/Enzyme/MLIR/Implementations/ArithDerivatives.td @@ -1,17 +1,5 @@ include "Common.td" -class ArithInst : Inst; - -def AddF : ArithInst<"arith::AddFOp">; -def SubF : ArithInst<"arith::SubFOp">; -def NegF : ArithInst<"arith::NegFOp">; -def MulF : ArithInst<"arith::MulFOp">; -def DivF : ArithInst<"arith::DivFOp">; -def RemF : ArithInst<"arith::RemFOp">; - -def CheckedMulF : ArithInst<"arith::MulFOp">; -def CheckedDivF : ArithInst<"arith::DivFOp">; - def : MLIRDerivative<"arith", "AddFOp", (Op $x, $y), [ (DiffeRet), diff --git a/enzyme/Enzyme/MLIR/Implementations/BuiltinAutoDiffTypeInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/BuiltinAutoDiffTypeInterfaceImpl.cpp index 30feca7ba82d..613e0ef2ad7e 100644 --- a/enzyme/Enzyme/MLIR/Implementations/BuiltinAutoDiffTypeInterfaceImpl.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/BuiltinAutoDiffTypeInterfaceImpl.cpp @@ -45,6 +45,37 @@ class FloatTypeInterface } bool requiresShadow(Type self) const { return false; } + LogicalResult zeroInPlace(Type self, OpBuilder &builder, Location loc, + Value val) const { + return failure(); + } +}; + +class TensorTypeInterface + : public AutoDiffTypeInterface::ExternalModel { +public: + Value createNullValue(Type self, OpBuilder &builder, Location loc) const { + auto tenType = self.cast(); + auto attr = DenseElementsAttr::get(tenType, 0); + return builder.create(loc, tenType, attr); + } + + Value createAddOp(Type self, OpBuilder &builder, Location loc, Value a, + Value b) const { + return builder.create(loc, a, b); + } + + Type getShadowType(Type self, unsigned width) const { + assert(width == 1 && "unsupported width != 1"); + return self; + } + + bool requiresShadow(Type self) const { return false; } + LogicalResult zeroInPlace(Type self, OpBuilder &builder, Location loc, + Value val) const { + return failure(); + } }; template @@ -69,6 +100,10 @@ class IntegerTypeInterface } bool requiresShadow(Type self) const { return false; } + LogicalResult zeroInPlace(Type self, OpBuilder &builder, Location loc, + Value val) const { + return failure(); + } }; } // namespace @@ -81,5 +116,7 @@ void mlir::enzyme::registerBuiltinDialectAutoDiffInterface( Float64Type::attachInterface(*context); IntegerType::attachInterface>(*context); IndexType::attachInterface>(*context); + UnrankedTensorType::attachInterface(*context); + RankedTensorType::attachInterface(*context); }); } diff --git a/enzyme/Enzyme/MLIR/Implementations/CMakeLists.txt b/enzyme/Enzyme/MLIR/Implementations/CMakeLists.txt index 464740d48d4b..ab156fa1b36f 100644 --- a/enzyme/Enzyme/MLIR/Implementations/CMakeLists.txt +++ b/enzyme/Enzyme/MLIR/Implementations/CMakeLists.txt @@ -26,6 +26,10 @@ set(LLVM_TARGET_DEFINITIONS MemRefDerivatives.td) enzyme_tablegen(MemRefDerivatives.inc -gen-mlir-derivatives) add_public_tablegen_target(MemRefDerivativesIncGen) +set(LLVM_TARGET_DEFINITIONS MathDerivatives.td) +enzyme_tablegen(MathDerivatives.inc -gen-mlir-derivatives) +add_public_tablegen_target(MathDerivativesIncGen) + add_mlir_library(MLIREnzymeImplementations AffineAutoDiffOpInterfaceImpl.cpp ArithAutoDiffOpInterfaceImpl.cpp @@ -37,6 +41,7 @@ add_mlir_library(MLIREnzymeImplementations BuiltinAutoDiffTypeInterfaceImpl.cpp SCFAutoDiffOpInterfaceImpl.cpp CFAutoDiffOpInterfaceImpl.cpp + MathAutoDiffOpInterfaceImpl.cpp DEPENDS MLIRAutoDiffOpInterfaceIncGen @@ -47,6 +52,7 @@ add_mlir_library(MLIREnzymeImplementations SCFDerivativesIncGen CFDerivativesIncGen MemRefDerivativesIncGen + MathDerivativesIncGen LINK_LIBS PUBLIC MLIRArithDialect diff --git a/enzyme/Enzyme/MLIR/Implementations/Common.td b/enzyme/Enzyme/MLIR/Implementations/Common.td index 4eff61d7380f..1451d99f17ca 100644 --- a/enzyme/Enzyme/MLIR/Implementations/Common.td +++ b/enzyme/Enzyme/MLIR/Implementations/Common.td @@ -3,18 +3,26 @@ class InactiveOp { string opName = opName_; } +class AllocationOp { + string dialect = dialect_; + string opName = opName_; +} + class ControlFlowOp { string dialect = dialect_; string opName = opName_; string impl = impl_; } -class ReadOnlyIdentityOp diffargs_> { +class MemoryIdentityOp ptrargs_, list storedargs_ = []> { string dialect = dialect_; string opName = opName_; - list diffargs = diffargs_; + list ptrargs = ptrargs_; + list storedargs = storedargs_; } +class ReadOnlyIdentityOp ptrargs_> : MemoryIdentityOp; + class BranchOp { string dialect = dialect_; string opName = opName_; @@ -50,3 +58,21 @@ class Inst : Operation : Inst; +class MathInst : Inst; + +def AddF : ArithInst<"arith::AddFOp">; +def SubF : ArithInst<"arith::SubFOp">; +def NegF : ArithInst<"arith::NegFOp">; +def MulF : ArithInst<"arith::MulFOp">; +def DivF : ArithInst<"arith::DivFOp">; +def RemF : ArithInst<"arith::RemFOp">; + +def CheckedMulF : ArithInst<"arith::MulFOp">; +def CheckedDivF : ArithInst<"arith::DivFOp">; + +def CosF : MathInst<"math::CosOp">; +def SinF : MathInst<"math::SinOp">; +def ExpF : MathInst<"math::ExpOp">; +def SqrtF : MathInst<"math::SqrtOp">; diff --git a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp index 0ee33c44f01c..857784472211 100644 --- a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp @@ -101,9 +101,18 @@ void mlir::enzyme::detail::branchingForwardHandler(Operation *inst, return; } -LogicalResult mlir::enzyme::detail::readOnlyIdentityForwardHandler( - Operation *orig, OpBuilder &builder, MGradientUtils *gutils) { +static bool contains(ArrayRef ar, int v) { + for (auto a : ar) { + if (a == v) { + return true; + } + } + return false; +} +LogicalResult mlir::enzyme::detail::memoryIdentityForwardHandler( + Operation *orig, OpBuilder &builder, MGradientUtils *gutils, + ArrayRef storedVals) { auto iface = cast(orig); SmallVector newOperands; @@ -112,8 +121,27 @@ LogicalResult mlir::enzyme::detail::readOnlyIdentityForwardHandler( if (iface.isArgInactive(operand.getOperandNumber())) { newOperands.push_back(gutils->getNewFromOriginal(operand.get())); } else { - if (gutils->isConstantValue(operand.get())) + if (gutils->isConstantValue(operand.get())) { + + if (contains(storedVals, operand.getOperandNumber())) { + if (auto iface = + dyn_cast(operand.get().getType())) { + if (!iface.requiresShadow()) { + // TODO only do if mutable + Type retTy = iface.getShadowType(); + auto toret = retTy.cast().createNullValue( + builder, operand.get().getLoc()); + newOperands.push_back(toret); + continue; + } + } + } + orig->emitWarning() + << "Unsupported constant arg to memory identity forward " + "handler(opidx=" + << operand.getOperandNumber() << ", op=" << operand.get() << ")\n"; return failure(); + } newOperands.push_back(gutils->invertPointerM(operand.get(), builder)); } } @@ -127,11 +155,34 @@ LogicalResult mlir::enzyme::detail::readOnlyIdentityForwardHandler( llvm::zip(orig->getResults(), shadow->getResults())) { gutils->setDiffe(oval, sval, builder); } - llvm::errs() << " shadow load: " << *shadow << "\n"; return success(); } +LogicalResult mlir::enzyme::detail::allocationForwardHandler( + Operation *orig, OpBuilder &builder, MGradientUtils *gutils, bool zero) { + + Operation *primal = gutils->getNewFromOriginal(orig); + Operation *shadow = builder.clone(*primal); + + Value shadowRes = shadow->getResult(0); + + gutils->setDiffe(orig->getResult(0), shadowRes, builder); + gutils->eraseIfUnused(orig); + + if (zero) { + // Fill with zeros + if (auto iface = dyn_cast(shadowRes.getType())) { + return iface.zeroInPlace(builder, orig->getLoc(), shadowRes); + } else { + orig->emitWarning() << "memref.alloc element type does not implement " + "AutoDiffTypeInterface"; + return failure(); + } + } + return success(); +} + void mlir::enzyme::detail::regionTerminatorForwardHandler( Operation *origTerminator, OpBuilder &builder, MGradientUtils *gutils) { auto termIface = cast(origTerminator); diff --git a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h index 4c54baa24ed2..5a1fbc136708 100644 --- a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h +++ b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h @@ -22,6 +22,7 @@ class OpBuilder; namespace enzyme { class MGradientUtils; +class MGradientUtilsReverse; namespace detail { // Non-template implementation of @@ -40,9 +41,14 @@ void regionTerminatorForwardHandler(Operation *op, OpBuilder &builder, MGradientUtils *gutils); // Implements forward-mode differentiation of read-only (including read-none) -// operations which do not perform computatoin -LogicalResult readOnlyIdentityForwardHandler(Operation *op, OpBuilder &builder, - MGradientUtils *gutils); +// operations which do not perform computation +LogicalResult memoryIdentityForwardHandler(Operation *op, OpBuilder &builder, + MGradientUtils *gutils, + ArrayRef storedVals); + +// Implements shadow initialization differentiation of allocation +LogicalResult allocationForwardHandler(Operation *op, OpBuilder &builder, + MGradientUtils *gutils, bool zero); // Implements the forward autodiff interface for operations whose derivatives // are can be inferred by analyzing their control flow and differentiating the @@ -88,14 +94,52 @@ class AutoDiffUsingRegionTerminator // Implements the forward autodiff interface for operations which are // read only and identity like (aka not computing sin of mem read). -template -class AutoDiffUsingReadOnlyIdentity +template +class AutoDiffUsingMemoryIdentity : public AutoDiffOpInterface::ExternalModel< - AutoDiffUsingReadOnlyIdentity, OpTy> { + AutoDiffUsingMemoryIdentity, OpTy> { +public: + LogicalResult createForwardModeTangent(Operation *op, OpBuilder &builder, + MGradientUtils *gutils) const { + + return memoryIdentityForwardHandler( + op, builder, gutils, std::initializer_list{storedvals...}); + } +}; + +// Implements the forward autodiff interface for operations which are +// allocation like +template +class AutoDiffUsingAllocationFwd : public AutoDiffOpInterface::ExternalModel< + AutoDiffUsingAllocationFwd, OpTy> { public: LogicalResult createForwardModeTangent(Operation *op, OpBuilder &builder, MGradientUtils *gutils) const { - return readOnlyIdentityForwardHandler(op, builder, gutils); + + return allocationForwardHandler(op, builder, gutils, /*zero*/ false); + } +}; + +// Implements the reverse autodiff interface for operations which are +// allocation like +template +class AutoDiffUsingAllocationRev + : public ReverseAutoDiffOpInterface::ExternalModel< + AutoDiffUsingAllocationRev, OpTy> { +public: + void createReverseModeAdjoint(Operation *op, OpBuilder &builder, + MGradientUtilsReverse *gutils, + SmallVector caches) const {} + + SmallVector cacheValues(Operation *op, + MGradientUtilsReverse *gutils) const { + return SmallVector(); + } + + void createShadowValues(Operation *op, OpBuilder &builder, + MGradientUtilsReverse *gutils) const { + (void)allocationForwardHandler(op, builder, (MGradientUtils *)gutils, + /*zero*/ true); } }; } // namespace detail @@ -117,10 +161,18 @@ void registerAutoDiffUsingRegionTerminatorInterface(MLIRContext &context) { OpTy::template attachInterface>( context); } -// Registers AutoDiffUsingRegionTerminator for the given op. +// Registers AutoDiffUsingMemoryIdentity for the given op. +template +void registerAutoDiffUsingMemoryIdentityInterface(MLIRContext &context) { + OpTy::template attachInterface< + detail::AutoDiffUsingMemoryIdentity>(context); +} +// Registers AutoDiffUsingAllocation for the given op. template -void registerAutoDiffUsingReadOnlyIdentityInterface(MLIRContext &context) { - OpTy::template attachInterface>( +void registerAutoDiffUsingAllocationInterface(MLIRContext &context) { + OpTy::template attachInterface>( + context); + OpTy::template attachInterface>( context); } @@ -134,5 +186,6 @@ void registerMemRefDialectAutoDiffInterface(DialectRegistry ®istry); void registerSCFDialectAutoDiffInterface(DialectRegistry ®istry); void registerCFDialectAutoDiffInterface(DialectRegistry ®istry); void registerLinalgDialectAutoDiffInterface(DialectRegistry ®istry); +void registerMathDialectAutoDiffInterface(DialectRegistry ®istry); } // namespace enzyme } // namespace mlir diff --git a/enzyme/Enzyme/MLIR/Implementations/LLVMAutoDiffOpInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/LLVMAutoDiffOpInterfaceImpl.cpp index d7884fcc305e..472cd0a43d61 100644 --- a/enzyme/Enzyme/MLIR/Implementations/LLVMAutoDiffOpInterfaceImpl.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/LLVMAutoDiffOpInterfaceImpl.cpp @@ -24,9 +24,7 @@ using namespace mlir::enzyme; namespace { #include "Implementations/LLVMDerivatives.inc" -} // namespace -namespace { struct InlineAsmActivityInterface : public ActivityOpInterface::ExternalModel { @@ -38,37 +36,6 @@ struct InlineAsmActivityInterface bool isArgInactive(Operation *op, size_t) const { return isInactive(op); } }; -struct StoreOpInterface - : public AutoDiffOpInterface::ExternalModel { - LogicalResult createForwardModeTangent(Operation *op, OpBuilder &builder, - MGradientUtils *gutils) const { - auto storeOp = cast(op); - if (!gutils->isConstantValue(storeOp.getAddr())) { - builder.create( - storeOp.getLoc(), gutils->invertPointerM(storeOp.getValue(), builder), - gutils->invertPointerM(storeOp.getAddr(), builder)); - } - gutils->eraseIfUnused(op); - return success(); - } -}; - -struct AllocaOpInterface - : public AutoDiffOpInterface::ExternalModel { - LogicalResult createForwardModeTangent(Operation *op, OpBuilder &builder, - MGradientUtils *gutils) const { - auto allocOp = cast(op); - if (!gutils->isConstantValue(allocOp)) { - Operation *nop = gutils->cloneWithNewOperands(builder, op); - gutils->setDiffe(allocOp, nop->getResult(0), builder); - } - gutils->eraseIfUnused(op); - return success(); - } -}; - class PointerTypeInterface : public AutoDiffTypeInterface::ExternalModel { @@ -89,14 +56,18 @@ class PointerTypeInterface } bool requiresShadow(Type self) const { return true; } + + LogicalResult zeroInPlace(Type self, OpBuilder &builder, Location loc, + Value val) const { + // TODO inspect val and memset corresponding size + return failure(); + } }; } // namespace void mlir::enzyme::registerLLVMDialectAutoDiffInterface( DialectRegistry ®istry) { registry.addExtension(+[](MLIRContext *context, LLVM::LLVMDialect *) { - LLVM::StoreOp::attachInterface(*context); - LLVM::AllocaOp::attachInterface(*context); LLVM::LLVMPointerType::attachInterface(*context); registerInterfaces(context); }); diff --git a/enzyme/Enzyme/MLIR/Implementations/LLVMDerivatives.td b/enzyme/Enzyme/MLIR/Implementations/LLVMDerivatives.td index c9b8fe76ca56..e77e88aea47f 100644 --- a/enzyme/Enzyme/MLIR/Implementations/LLVMDerivatives.td +++ b/enzyme/Enzyme/MLIR/Implementations/LLVMDerivatives.td @@ -1,5 +1,6 @@ include "Common.td" +def : MemoryIdentityOp<"LLVM", "StoreOp", [1], [0]>; def : InactiveOp<"LLVM", "SIToFPOp">; def : InactiveOp<"LLVM", "UIToFPOp">; def : InactiveOp<"LLVM", "FPToSIOp">; @@ -22,4 +23,6 @@ def : ReadOnlyIdentityOp<"LLVM", "AddrSpaceCastOp", [0]>; def : ReadOnlyIdentityOp<"LLVM", "BitcastOp", [0]>; def : ReadOnlyIdentityOp<"LLVM", "GEPOp", [0]>; def : ReadOnlyIdentityOp<"LLVM", "PtrToIntOp", [0]>; -def : ReadOnlyIdentityOp<"LLVM", "IntToPtrOp", [0]>; \ No newline at end of file +def : ReadOnlyIdentityOp<"LLVM", "IntToPtrOp", [0]>; + +def : AllocationOp<"LLVM", "AllocaOp">; diff --git a/enzyme/Enzyme/MLIR/Implementations/MathAutoDiffOpInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/MathAutoDiffOpInterfaceImpl.cpp new file mode 100644 index 000000000000..2833eeb44726 --- /dev/null +++ b/enzyme/Enzyme/MLIR/Implementations/MathAutoDiffOpInterfaceImpl.cpp @@ -0,0 +1,38 @@ +//===- ArithAutoDiffOpInterfaceImpl.cpp - Interface external model --------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains the external model implementation of the automatic +// differentiation op interfaces for the upstream MLIR arithmetic dialect. +// +//===----------------------------------------------------------------------===// + +#include "Implementations/CoreDialectsAutoDiffImplementations.h" +#include "Interfaces/AutoDiffOpInterface.h" +#include "Interfaces/GradientUtils.h" +#include "Interfaces/GradientUtilsReverse.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/IR/DialectRegistry.h" +#include "mlir/Support/LogicalResult.h" + +#include "Dialect/Ops.h" +#include "mlir/IR/TypeSupport.h" + +using namespace mlir; +using namespace mlir::enzyme; + +namespace { +#include "Implementations/MathDerivatives.inc" +} // namespace + +void mlir::enzyme::registerMathDialectAutoDiffInterface( + DialectRegistry ®istry) { + registry.addExtension(+[](MLIRContext *context, math::MathDialect *) { + registerInterfaces(context); + }); +} diff --git a/enzyme/Enzyme/MLIR/Implementations/MathDerivatives.td b/enzyme/Enzyme/MLIR/Implementations/MathDerivatives.td new file mode 100644 index 000000000000..71db1a8574ac --- /dev/null +++ b/enzyme/Enzyme/MLIR/Implementations/MathDerivatives.td @@ -0,0 +1,17 @@ +include "Common.td" + +def : MLIRDerivative<"math", "CosOp", (Op $x), + [ + (CheckedMulF (DiffeRet), (NegF (SinF $x))) + ] + >; +def : MLIRDerivative<"math", "ExpOp", (Op $x), + [ + (CheckedMulF (DiffeRet), (ExpF $x)) + ] + >; +def : MLIRDerivative<"math", "SinOp", (Op $x), + [ + (CheckedMulF (DiffeRet), (CosF $x)) + ] + >; diff --git a/enzyme/Enzyme/MLIR/Implementations/MemRefAutoDiffOpInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/MemRefAutoDiffOpInterfaceImpl.cpp index 04f8c3a60e33..1b811caf1b81 100644 --- a/enzyme/Enzyme/MLIR/Implementations/MemRefAutoDiffOpInterfaceImpl.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/MemRefAutoDiffOpInterfaceImpl.cpp @@ -17,38 +17,20 @@ #include "Interfaces/GradientUtils.h" #include "Interfaces/GradientUtilsReverse.h" -// TODO: We need a way to zero out a memref (which linalg.fill does), but -// ideally we wouldn't depend on the linalg dialect. -#include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/IR/DialectRegistry.h" #include "mlir/Support/LogicalResult.h" +// TODO: We need a way to zero out a memref (which linalg.fill does), but +// ideally we wouldn't depend on the linalg dialect. +#include "mlir/Dialect/Linalg/IR/Linalg.h" + using namespace mlir; using namespace mlir::enzyme; namespace { #include "Implementations/MemRefDerivatives.inc" -struct StoreOpInterface - : public AutoDiffOpInterface::ExternalModel { - LogicalResult createForwardModeTangent(Operation *op, OpBuilder &builder, - MGradientUtils *gutils) const { - auto storeOp = cast(op); - if (!gutils->isConstantValue(storeOp.getMemref())) { - SmallVector inds; - for (auto ind : storeOp.getIndices()) - inds.push_back(gutils->getNewFromOriginal(ind)); - builder.create( - storeOp.getLoc(), gutils->invertPointerM(storeOp.getValue(), builder), - gutils->invertPointerM(storeOp.getMemref(), builder), inds); - } - gutils->eraseIfUnused(op); - return success(); - } -}; - struct LoadOpInterfaceReverse : public ReverseAutoDiffOpInterface::ExternalModel { @@ -177,38 +159,6 @@ struct StoreOpInterfaceReverse } }; -struct AllocOpInterfaceReverse - : public ReverseAutoDiffOpInterface::ExternalModel { - void createReverseModeAdjoint(Operation *op, OpBuilder &builder, - MGradientUtilsReverse *gutils, - SmallVector caches) const {} - - SmallVector cacheValues(Operation *op, - MGradientUtilsReverse *gutils) const { - return SmallVector(); - } - - void createShadowValues(Operation *op, OpBuilder &builder, - MGradientUtilsReverse *gutils) const { - auto allocOp = cast(op); - auto newAllocOp = cast(gutils->getNewFromOriginal(op)); - - Value shadow = builder.create( - op->getLoc(), newAllocOp.getType(), newAllocOp.getDynamicSizes()); - // Fill with zeros - if (auto iface = dyn_cast( - allocOp.getType().getElementType())) { - Value zero = iface.createNullValue(builder, op->getLoc()); - builder.create(op->getLoc(), zero, shadow); - } else { - op->emitWarning() << "memref.alloc element type does not implement " - "AutoDiffTypeInterface"; - } - gutils->mapShadowValue(allocOp, shadow, builder); - } -}; - struct SubViewOpInterfaceReverse : public ReverseAutoDiffOpInterface::ExternalModel< SubViewOpInterfaceReverse, memref::SubViewOp> { @@ -236,21 +186,6 @@ struct SubViewOpInterfaceReverse } }; -struct AllocOpInterface - : public AutoDiffOpInterface::ExternalModel { - LogicalResult createForwardModeTangent(Operation *op, OpBuilder &builder, - MGradientUtils *gutils) const { - auto allocOp = cast(op); - if (!gutils->isConstantValue(allocOp)) { - Operation *nop = gutils->cloneWithNewOperands(builder, op); - gutils->setDiffe(allocOp, nop->getResult(0), builder); - } - gutils->eraseIfUnused(op); - return success(); - } -}; - class MemRefTypeInterface : public AutoDiffTypeInterface::ExternalModel { @@ -271,6 +206,20 @@ class MemRefTypeInterface } bool requiresShadow(Type self) const { return true; } + + LogicalResult zeroInPlace(Type self, OpBuilder &builder, Location loc, + Value val) const { + auto MT = cast(self); + if (auto iface = dyn_cast(MT.getElementType())) { + if (!iface.requiresShadow()) { + Value zero = iface.createNullValue(builder, loc); + builder.create(loc, zero, val); + } + } else { + return failure(); + } + return success(); + } }; } // namespace @@ -278,13 +227,10 @@ void mlir::enzyme::registerMemRefDialectAutoDiffInterface( DialectRegistry ®istry) { registry.addExtension(+[](MLIRContext *context, memref::MemRefDialect *) { registerInterfaces(context); - memref::StoreOp::attachInterface(*context); - memref::AllocOp::attachInterface(*context); MemRefType::attachInterface(*context); memref::LoadOp::attachInterface(*context); memref::StoreOp::attachInterface(*context); - memref::AllocOp::attachInterface(*context); memref::SubViewOp::attachInterface(*context); }); } diff --git a/enzyme/Enzyme/MLIR/Implementations/MemRefDerivatives.td b/enzyme/Enzyme/MLIR/Implementations/MemRefDerivatives.td index 7bf269199ca4..173a22a6e2f1 100644 --- a/enzyme/Enzyme/MLIR/Implementations/MemRefDerivatives.td +++ b/enzyme/Enzyme/MLIR/Implementations/MemRefDerivatives.td @@ -1,5 +1,6 @@ include "Common.td" +def : MemoryIdentityOp<"memref", "StoreOp", [1], [0]>; def : ReadOnlyIdentityOp<"memref", "LoadOp", [0]>; def : ReadOnlyIdentityOp<"memref", "CastOp", [0]>; def : ReadOnlyIdentityOp<"memref", "CollapseShapeOp", [0]>; @@ -10,4 +11,6 @@ def : ReadOnlyIdentityOp<"memref", "TransposeOp", [0]>; def : ReadOnlyIdentityOp<"memref", "ViewOp", [0]>; def : ReadOnlyIdentityOp<"memref", "SubViewOp", [0]>; -def : InactiveOp<"memref", "DimOp">; \ No newline at end of file +def : InactiveOp<"memref", "DimOp">; +def : AllocationOp<"memref", "AllocOp">; +def : AllocationOp<"memref", "AllocaOp">; diff --git a/enzyme/Enzyme/MLIR/Interfaces/AutoDiffTypeInterface.h b/enzyme/Enzyme/MLIR/Interfaces/AutoDiffTypeInterface.h index 7c405bde0a99..04d0186c4b1e 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/AutoDiffTypeInterface.h +++ b/enzyme/Enzyme/MLIR/Interfaces/AutoDiffTypeInterface.h @@ -16,6 +16,7 @@ #include "mlir/IR/Types.h" #include "mlir/IR/Value.h" +#include "mlir/Support/LogicalResult.h" namespace mlir { class OpBuilder; diff --git a/enzyme/Enzyme/MLIR/Interfaces/AutoDiffTypeInterface.td b/enzyme/Enzyme/MLIR/Interfaces/AutoDiffTypeInterface.td index 5fec2643a98c..956f616751dc 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/AutoDiffTypeInterface.td +++ b/enzyme/Enzyme/MLIR/Interfaces/AutoDiffTypeInterface.td @@ -41,6 +41,14 @@ def AutoDiffTypeInterface : TypeInterface<"AutoDiffTypeInterface"> { /*methodName=*/"createAddOp", /*args=*/(ins "::mlir::OpBuilder &":$builder, "::mlir::Location":$loc, "::mlir::Value":$a, "::mlir::Value":$b) >, + InterfaceMethod< + /*desc=*/[{ + Zero the operation in place + }], + /*retTy=*/"::mlir::LogicalResult", + /*methodName=*/"zeroInPlace", + /*args=*/(ins "::mlir::OpBuilder &":$builder, "::mlir::Location":$loc, "::mlir::Value":$val) + >, InterfaceMethod< /*desc=*/[{ Returns the type that can contain the adjoint value for this type. If diff --git a/enzyme/Enzyme/MLIR/Interfaces/CloneFunction.cpp b/enzyme/Enzyme/MLIR/Interfaces/CloneFunction.cpp index 84b8b8b5d3b3..2b2760dfa8ff 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/CloneFunction.cpp +++ b/enzyme/Enzyme/MLIR/Interfaces/CloneFunction.cpp @@ -4,7 +4,10 @@ using namespace mlir; using namespace mlir::enzyme; Type getShadowType(Type type, unsigned width) { - return type.cast().getShadowType(width); + if (auto iface = type.dyn_cast()) + return iface.getShadowType(width); + llvm::errs() << " type does not have autodifftypeinterface: " << type << "\n"; + exit(1); } mlir::FunctionType getFunctionTypeForClone( @@ -263,4 +266,4 @@ FunctionOpInterface CloneFunctionWithReturns( } return NewF; -} \ No newline at end of file +} diff --git a/enzyme/Enzyme/MLIR/enzymemlir-opt.cpp b/enzyme/Enzyme/MLIR/enzymemlir-opt.cpp index e7af0b01267f..90038c1e3236 100644 --- a/enzyme/Enzyme/MLIR/enzymemlir-opt.cpp +++ b/enzyme/Enzyme/MLIR/enzymemlir-opt.cpp @@ -102,6 +102,7 @@ int main(int argc, char **argv) { enzyme::registerBuiltinDialectAutoDiffInterface(registry); enzyme::registerLLVMDialectAutoDiffInterface(registry); enzyme::registerNVVMDialectAutoDiffInterface(registry); + enzyme::registerMathDialectAutoDiffInterface(registry); enzyme::registerMemRefDialectAutoDiffInterface(registry); enzyme::registerSCFDialectAutoDiffInterface(registry); enzyme::registerCFDialectAutoDiffInterface(registry); diff --git a/enzyme/test/MLIR/ForwardMode/affine.mlir b/enzyme/test/MLIR/ForwardMode/affine.mlir index 2dcb61441251..d0e587409713 100644 --- a/enzyme/test/MLIR/ForwardMode/affine.mlir +++ b/enzyme/test/MLIR/ForwardMode/affine.mlir @@ -61,14 +61,13 @@ module { // CHECK: return %[[v1]] : f64 func.func @if_then(%x : f64, %c : i1) -> f64 { - %c0 = arith.constant 0 : index %c2 = arith.constant 2.000000e+00 : f64 %c10 = arith.constant 10.000000e+00 : f64 %mem = memref.alloc() : memref<1xf64> - memref.store %c2, %mem[%c0] : memref<1xf64> + affine.store %c2, %mem[0] : memref<1xf64> scf.if %c { %mul = arith.mulf %x, %x : f64 - memref.store %mul, %mem[%c0] : memref<1xf64> + affine.store %mul, %mem[0] : memref<1xf64> } %r = affine.load %mem[0] : memref<1xf64> %res = arith.mulf %c2, %r : f64 @@ -80,25 +79,24 @@ module { } // CHECK: @fwddiffeif_then // CHECK: (%[[arg0:.+]]: f64, %[[arg1:.+]]: f64, %[[arg2:.+]]: i1) -> f64 { - // CHECK: %[[c0:.+]] = arith.constant 0 : index - // CHECK: %[[cst:.+]] = arith.constant 0.000000e+00 : f64 - // CHECK: %[[cst_0:.+]] = arith.constant 2.000000e+00 : f64 - // CHECK: %[[cst_1:.+]] = arith.constant 1.000000e+01 : f64 + // CHECK-DAG: %[[cst2:.+]] = arith.constant 2.000000e+00 : f64 + // CHECK-DAG: %[[cst1:.+]] = arith.constant 1.000000e+01 : f64 // CHECK: %[[alloc:.+]] = memref.alloc() : memref<1xf64> // CHECK: %[[alloc_2:.+]] = memref.alloc() : memref<1xf64> - // CHECK: memref.store %[[cst]], %[[alloc]][%[[c0]]] : memref<1xf64> - // CHECK: memref.store %[[cst_0]], %[[alloc_2]][%[[c0]]] : memref<1xf64> + // CHECK-DAG: %[[cst0:.+]] = arith.constant 0.000000e+00 : f64 + // CHECK: affine.store %[[cst0]], %[[alloc]][0] : memref<1xf64> + // CHECK: affine.store %[[cst2]], %[[alloc_2]][0] : memref<1xf64> // CHECK: scf.if %[[arg2]] { // CHECK: %[[v4:.+]] = arith.mulf %[[arg1]], %[[arg0]] : f64 // CHECK: %[[v5:.+]] = arith.mulf %[[arg1]], %[[arg0]] : f64 // CHECK: %[[v6:.+]] = arith.addf %[[v4]], %[[v5]] : f64 // CHECK: %[[v7:.+]] = arith.mulf %[[arg0]], %[[arg0]] : f64 - // CHECK: memref.store %[[v6]], %[[alloc]][%[[c0]]] : memref<1xf64> - // CHECK: memref.store %[[v7]], %[[alloc_2]][%[[c0]]] : memref<1xf64> + // CHECK: affine.store %[[v6]], %[[alloc]][0] : memref<1xf64> + // CHECK: affine.store %[[v7]], %[[alloc_2]][0] : memref<1xf64> // CHECK: } // CHECK: %[[v0:.+]] = affine.load %[[alloc]][0] : memref<1xf64> // CHECK: %[[v1:.+]] = affine.load %[[alloc_2]][0] : memref<1xf64> - // CHECK: %[[v2:.+]] = arith.mulf %[[v0]], %[[cst_0]] : f64 - // CHECK: %[[v3:.+]] = arith.mulf %[[cst_0]], %[[v1]] : f64 + // CHECK: %[[v2:.+]] = arith.mulf %[[v0]], %[[cst2]] : f64 + // CHECK: %[[v3:.+]] = arith.mulf %[[cst2]], %[[v1]] : f64 // CHECK: return %[[v2]] : f64 } diff --git a/enzyme/test/MLIR/ForwardMode/tensorsin.mlir b/enzyme/test/MLIR/ForwardMode/tensorsin.mlir new file mode 100644 index 000000000000..5ab208712b73 --- /dev/null +++ b/enzyme/test/MLIR/ForwardMode/tensorsin.mlir @@ -0,0 +1,19 @@ +// RUN: %eopt --enzyme %s | FileCheck %s + +module { + func.func @square(%x : tensor<2xf64>) -> tensor<2xf64> { + %y = math.sin %x : tensor<2xf64> + return %y : tensor<2xf64> + } + func.func @dsq(%x : tensor<2xf64>, %dx : tensor<2xf64>) -> tensor<2xf64> { + %r = enzyme.fwddiff @square(%x, %dx) { activity=[#enzyme] } : (tensor<2xf64>, tensor<2xf64>) -> (tensor<2xf64>) + return %r : tensor<2xf64> + } +} + +// CHECK: func.func private @fwddiffesquare(%arg0: tensor<2xf64>, %arg1: tensor<2xf64>) -> tensor<2xf64> { +// CHECK-NEXT: %[[a0:.+]] = math.cos %arg0 : tensor<2xf64> +// CHECK-NEXT: %[[a1:.+]] = arith.mulf %arg1, %[[a0]] : tensor<2xf64> +// CHECK-NEXT: %[[a2:.+]] = math.sin %arg0 : tensor<2xf64> +// CHECK-NEXT: return %[[a1]] : tensor<2xf64> +// CHECK-NEXT: } diff --git a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp index f3f8f5940d27..c30f8870063e 100644 --- a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp +++ b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp @@ -1398,8 +1398,11 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os, insert(dg, next); if (ptree->getArgNameStr(i).size()) { - auto op = - (origName + ".getOperand(" + Twine(next[0]) + ")").str(); + std::string op; + if (intrinsic != MLIRDerivatives) + op = (origName + ".getOperand(" + Twine(next[0]) + ")").str(); + else + op = (origName + "->getOperand(" + Twine(next[0]) + ")").str(); if (prev.size() > 0) { op = "gutils->extractMeta(Builder2, " + op + ", ArrayRef({"; @@ -1891,8 +1894,8 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os, const auto &cfpatterns = recordKeeper.getAllDerivedDefinitions("ControlFlowOp"); - const auto &ropatterns = - recordKeeper.getAllDerivedDefinitions("ReadOnlyIdentityOp"); + const auto &mempatterns = + recordKeeper.getAllDerivedDefinitions("MemoryIdentityOp"); for (auto &pattern : cfpatterns) { auto opName = pattern->getValueAsString("opName"); @@ -1906,13 +1909,14 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os, os << "};\n"; } - for (auto &pattern : ropatterns) { + for (auto &pattern : mempatterns) { auto opName = pattern->getValueAsString("opName"); auto dialect = pattern->getValueAsString("dialect"); - auto diffargs = pattern->getValueAsListOfInts("diffargs"); - os << "struct " << opName << "ROActivity : \n"; + auto diffargs = pattern->getValueAsListOfInts("ptrargs"); + auto storedargs = pattern->getValueAsListOfInts("storedargs"); + os << "struct " << opName << "MemActivity : \n"; os << " public ActivityOpInterface::ExternalModel<" << opName - << "ROActivity, " << dialect << "::" << opName << "> {\n"; + << "MemActivity, " << dialect << "::" << opName << "> {\n"; os << " bool isInactive(mlir::Operation* op) const {\n"; os << " for (size_t i=0, len=op->getNumOperands(); igetValueAsString("opName"); @@ -1954,13 +1964,16 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os, os << " registerAutoDiffUsingControlFlowInterface<" << dialect << "::" << opName << ">(*context);\n"; } - for (Record *pattern : ropatterns) { + for (Record *pattern : mempatterns) { auto opName = pattern->getValueAsString("opName"); auto dialect = pattern->getValueAsString("dialect"); os << " " << dialect << "::" << opName << "::attachInterface<" << opName - << "ROActivity>(*context);\n"; - os << " registerAutoDiffUsingReadOnlyIdentityInterface<" << dialect - << "::" << opName << ">(*context);\n"; + << "MemActivity>(*context);\n"; + os << " registerAutoDiffUsingMemoryIdentityInterface<" << dialect + << "::" << opName; + for (auto storedarg : pattern->getValueAsListOfInts("storedargs")) + os << ", " << storedarg; + os << ">(*context);\n"; } for (Record *pattern : brpatterns) { auto opName = pattern->getValueAsString("opName"); @@ -1974,6 +1987,12 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os, os << " registerAutoDiffUsingRegionTerminatorInterface<" << dialect << "::" << opName << ">(*context);\n"; } + for (Record *pattern : allocpatterns) { + auto opName = pattern->getValueAsString("opName"); + auto dialect = pattern->getValueAsString("dialect"); + os << " registerAutoDiffUsingAllocationInterface<" << dialect + << "::" << opName << ">(*context);\n"; + } os << "}\n"; } } From d8ec31545895507d2e04e3650c5a912d2f20a52e Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 5 Feb 2024 22:00:34 +0100 Subject: [PATCH 006/106] Fix differential use analysis of ivi (#1661) --- enzyme/Enzyme/DifferentialUseAnalysis.cpp | 41 +++++++++++++++++++ .../test/Enzyme/ForwardMode/insertdiffuse.ll | 25 +++++++++++ enzyme/test/lit.site.cfg.py.in | 2 +- 3 files changed, 67 insertions(+), 1 deletion(-) create mode 100644 enzyme/test/Enzyme/ForwardMode/insertdiffuse.ll diff --git a/enzyme/Enzyme/DifferentialUseAnalysis.cpp b/enzyme/Enzyme/DifferentialUseAnalysis.cpp index 4a2cff5d79a9..3be0cdfc2895 100644 --- a/enzyme/Enzyme/DifferentialUseAnalysis.cpp +++ b/enzyme/Enzyme/DifferentialUseAnalysis.cpp @@ -293,6 +293,47 @@ bool DifferentialUseAnalysis::is_use_directly_needed_in_reverse( return false; } + if (shadow) { + if (auto IVI = dyn_cast(user)) + if (isa(IVI) || isa(IVI)) { + if (IVI->getOperand(1) == val) { + SmallVector todo; + todo.push_back(IVI); + SmallVector, 1> + done; + while (todo.size()) { + auto cur = todo.pop_back_val(); + for (auto u : cur->users()) { + if (auto IVI2 = dyn_cast(u)) { + todo.push_back(IVI2); + continue; + } + if (auto IVI2 = dyn_cast(u)) { + todo.push_back(IVI2); + continue; + } + done.emplace_back(cast(u), cur); + } + } + for (auto &pair : done) { + + bool direct = is_use_directly_needed_in_reverse( + gutils, pair.second, mode, pair.first, oldUnreachable, + QueryType::Shadow, recursiveUse); + if (direct) { + + if (EnzymePrintDiffUse) + llvm::errs() + << " Need (partial) direct " << to_string(qtype) << " of " + << *val << " in reverse from insertelem " << *user + << " via " << *pair.second << " in " << *pair.first << "\n"; + return true; + } + } + } + } + } + if (!shadow) if (auto IEI = dyn_cast(user)) { // Only need the index in the reverse, so if the value is not diff --git a/enzyme/test/Enzyme/ForwardMode/insertdiffuse.ll b/enzyme/test/Enzyme/ForwardMode/insertdiffuse.ll new file mode 100644 index 000000000000..44f246cdfbe0 --- /dev/null +++ b/enzyme/test/Enzyme/ForwardMode/insertdiffuse.ll @@ -0,0 +1,25 @@ +; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -S | FileCheck %s; fi +; RUN: %opt < %s %newLoadEnzyme -passes="enzyme" -enzyme-preopt=false -S | FileCheck %s + +define { double, i64 } @julia_logabsgamma_3264_inner.1(double %x, i64 %z) { +entry: + %iadd = add i64 %z, 1 + %.fca.0.insert = insertvalue { double, i64 } undef, double %x, 0 + %.fca.1.insert = insertvalue { double, i64 } %.fca.0.insert, i64 %iadd, 1 + ret { double, i64 } %.fca.1.insert +} + +declare { double, i64 } @__enzyme_fwddiff(...) + +define { double, i64 } @ad(double %x, double %dx) { + %m = call { double, i64 } (...) @__enzyme_fwddiff({ double, i64 } (double, i64)* @julia_logabsgamma_3264_inner.1, double %x, double %dx, i64 1) + ret { double, i64 } %m +} + +; CHECK: define internal { double, i64 } @fwddiffejulia_logabsgamma_3264_inner.1(double %x, double %"x'", i64 %z) +; CHECK-NEXT: entry: +; CHECK-NEXT: %iadd = add i64 %z, 1 +; CHECK-NEXT: %".fca.0.insert'ipiv" = insertvalue { double, i64 } zeroinitializer, double %"x'", 0 +; CHECK-NEXT: %".fca.1.insert'ipiv" = insertvalue { double, i64 } %".fca.0.insert'ipiv", i64 %iadd, 1 +; CHECK-NEXT: ret { double, i64 } %".fca.1.insert'ipiv" +; CHECK-NEXT: } \ No newline at end of file diff --git a/enzyme/test/lit.site.cfg.py.in b/enzyme/test/lit.site.cfg.py.in index 481fce924bcc..0b8a0f831d6e 100644 --- a/enzyme/test/lit.site.cfg.py.in +++ b/enzyme/test/lit.site.cfg.py.in @@ -82,7 +82,7 @@ if len("@ENZYME_BINARY_DIR@") == 0: oldPMOP = oldPM newPMOP = newPM -if int(config.llvm_ver) >= 16: +if int(config.llvm_ver) == 16: newPM += " -opaque-pointers=0" oldPM += " -opaque-pointers=0" From 5b6b4ac44a343676aff768dcf0c50ee0a6055fd5 Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 6 Feb 2024 02:08:17 +0000 Subject: [PATCH 007/106] Fix integration tests on main (#1662) --- enzyme/Enzyme/AdjointGenerator.h | 24 +++++++++++++++++++ .../test/Integration/ReverseMode/forrealloc.c | 8 +++---- 2 files changed, 28 insertions(+), 4 deletions(-) diff --git a/enzyme/Enzyme/AdjointGenerator.h b/enzyme/Enzyme/AdjointGenerator.h index c6d558cf5ba0..80d6c851e2a2 100644 --- a/enzyme/Enzyme/AdjointGenerator.h +++ b/enzyme/Enzyme/AdjointGenerator.h @@ -3353,6 +3353,29 @@ class AdjointGenerator } } } + // If the type is known, but outside of the known range + // (but the memcpy size is a variable), attempt to use + // the first type out of range as the memcpy type. + if (size == 1 && !isa(new_size)) { + for (auto ptr : {orig_dst, orig_src}) { + vd = TR.query(ptr).Data0().ShiftIndices(DL, 0, -1, 0); + if (vd.isKnownPastPointer()) { + ConcreteType mv(BaseType::Unknown); + size_t minInt = 0xFFFFFFFF; + for (const auto &pair : vd.getMapping()) { + if (pair.first.size() != 1) + continue; + if (minInt < (size_t)pair.first[0]) + continue; + minInt = pair.first[0]; + mv = pair.second; + } + assert(mv != BaseType::Unknown); + vd.insert({0}, mv); + goto known; + } + } + } if (errorIfNoType) EmitWarning("CannotDeduceType", MTI, "failed to deduce type of copy ", MTI); @@ -3368,6 +3391,7 @@ class AdjointGenerator &TR.analyzer, nullptr, wrap(&BuilderZ)); } else { ss << "\n"; + ss << *gutils->oldFunc << "\n"; TR.dump(ss); EmitFailure("CannotDeduceType", MTI.getDebugLoc(), &MTI, ss.str()); } diff --git a/enzyme/test/Integration/ReverseMode/forrealloc.c b/enzyme/test/Integration/ReverseMode/forrealloc.c index 9fc9ec5798d7..9034ff68f0bb 100644 --- a/enzyme/test/Integration/ReverseMode/forrealloc.c +++ b/enzyme/test/Integration/ReverseMode/forrealloc.c @@ -1,11 +1,11 @@ // RUN: %clang -std=c11 -O0 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -S | %lli - // RUN: %clang -std=c11 -O1 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -S | %lli - -// RUN: %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -S | %lli - -// RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -S | %lli - +// RUN: %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -S -enzyme-loose-types | %lli - +// RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -S -enzyme-loose-types | %lli - // RUN: %clang -std=c11 -O0 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -O1 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - -// RUN: %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - -// RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - +// RUN: %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S -enzyme-loose-types | %lli - +// RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S -enzyme-loose-types | %lli - #include #include From dd02ac51a56d8761d87d4dd3ebef6f730d88e37f Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 6 Feb 2024 11:01:05 +0000 Subject: [PATCH 008/106] Fix DT error (#1665) --- enzyme/Enzyme/FunctionUtils.cpp | 28 +++++++++++------ enzyme/test/Enzyme/ReverseMode/logabsgamma.ll | 31 +++++++++++++++++++ .../test/Integration/ReverseMode/forrealloc.c | 16 +++++----- .../ReverseMode/integrateconst.cpp | 1 - enzyme/test/Integration/ReverseMode/sret.cpp | 3 -- 5 files changed, 57 insertions(+), 22 deletions(-) create mode 100644 enzyme/test/Enzyme/ReverseMode/logabsgamma.ll diff --git a/enzyme/Enzyme/FunctionUtils.cpp b/enzyme/Enzyme/FunctionUtils.cpp index 4c3370cf4cf9..71ceb9688eea 100644 --- a/enzyme/Enzyme/FunctionUtils.cpp +++ b/enzyme/Enzyme/FunctionUtils.cpp @@ -2495,19 +2495,23 @@ void ReplaceFunctionImplementation(Module &M) { } void PreProcessCache::optimizeIntermediate(Function *F) { - PromotePass().run(*F, FAM); + PreservedAnalyses PA; + PA = PromotePass().run(*F, FAM); + FAM.invalidate(*F, PA); #if LLVM_VERSION_MAJOR >= 14 && !defined(FLANG) - GVNPass().run(*F, FAM); + PA = GVNPass().run(*F, FAM); #else - GVN().run(*F, FAM); + PA = GVN().run(*F, FAM); #endif + FAM.invalidate(*F, PA); #if LLVM_VERSION_MAJOR >= 16 && !defined(FLANG) - SROAPass(llvm::SROAOptions::PreserveCFG).run(*F, FAM); + PA = SROAPass(llvm::SROAOptions::PreserveCFG).run(*F, FAM); #elif LLVM_VERSION_MAJOR >= 14 && !defined(FLANG) - SROAPass().run(*F, FAM); + PA = SROAPass().run(*F, FAM); #else - SROA().run(*F, FAM); + PA = SROA().run(*F, FAM); #endif + FAM.invalidate(*F, PA); if (EnzymeSelectOpt) { #if LLVM_VERSION_MAJOR >= 12 @@ -2518,8 +2522,10 @@ void PreProcessCache::optimizeIntermediate(Function *F) { /*bool SwitchToLookup=*/false, /*bool CanonicalLoops=*/true, /*bool SinkCommon=*/true, /*AssumptionCache *AssumpCache=*/nullptr); #endif - SimplifyCFGPass(scfgo).run(*F, FAM); - CorrelatedValuePropagationPass().run(*F, FAM); + PA = SimplifyCFGPass(scfgo).run(*F, FAM); + FAM.invalidate(*F, PA); + PA = CorrelatedValuePropagationPass().run(*F, FAM); + FAM.invalidate(*F, PA); SelectOptimization(F); } // EarlyCSEPass(/*memoryssa*/ true).run(*F, FAM); @@ -2529,8 +2535,10 @@ void PreProcessCache::optimizeIntermediate(Function *F) { ReplaceFunctionImplementation(*F->getParent()); - PreservedAnalyses PA; - FAM.invalidate(*F, PA); + { + PreservedAnalyses PA; + FAM.invalidate(*F, PA); + } #if LLVM_VERSION_MAJOR < 14 using OptimizationLevel = llvm::PassBuilder::OptimizationLevel; diff --git a/enzyme/test/Enzyme/ReverseMode/logabsgamma.ll b/enzyme/test/Enzyme/ReverseMode/logabsgamma.ll new file mode 100644 index 000000000000..02e0f40cd5bc --- /dev/null +++ b/enzyme/test/Enzyme/ReverseMode/logabsgamma.ll @@ -0,0 +1,31 @@ +; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme-preopt=false -enzyme -mem2reg -instsimplify -simplifycfg -S | FileCheck %s; fi +; RUN: %opt < %s %newLoadEnzyme -enzyme-preopt=false -passes="enzyme,function(mem2reg,instsimplify,%simplifycfg)" -S | FileCheck %s + +; XFAIL: * + +; Function Attrs: nounwind readnone uwtable +define double @tester(double %x) { +entry: + %a = call { double, i64 } @logabsgamma(double %x) + %b = extractvalue { double, i64 } %a, 0 + ret double %b +} + +define double @test_derivative(double %x) { +entry: + %0 = tail call double (double (double)*, ...) @__enzyme_autodiff(double (double)* nonnull @tester, double %x) + ret double %0 +} + +declare { double, i64 } @logabsgamma(double) + +; Function Attrs: nounwind +declare double @__enzyme_autodiff(double (double)*, ...) + +; CHECK: define internal { double } @diffetester(double %x, double %differeturn) +; CHECK-NEXT: entry: +; CHECK-NEXT: %0 = call fast double @cosh(double %x) +; CHECK-NEXT: %1 = fmul fast double %differeturn, %0 +; CHECK-NEXT: %2 = insertvalue { double } undef, double %1, 0 +; CHECK-NEXT: ret { double } %2 +; CHECK-NEXT: } diff --git a/enzyme/test/Integration/ReverseMode/forrealloc.c b/enzyme/test/Integration/ReverseMode/forrealloc.c index 9034ff68f0bb..924f7f04a9f1 100644 --- a/enzyme/test/Integration/ReverseMode/forrealloc.c +++ b/enzyme/test/Integration/ReverseMode/forrealloc.c @@ -1,11 +1,11 @@ -// RUN: %clang -std=c11 -O0 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -S | %lli - -// RUN: %clang -std=c11 -O1 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -S | %lli - -// RUN: %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -S -enzyme-loose-types | %lli - -// RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -S -enzyme-loose-types | %lli - -// RUN: %clang -std=c11 -O0 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - -// RUN: %clang -std=c11 -O1 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - -// RUN: %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S -enzyme-loose-types | %lli - -// RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S -enzyme-loose-types | %lli - +// RUN: if [ %llvmver -ge 12 ]; then %clang -std=c11 -O0 %loadClangEnzyme %s -S -emit-llvm -o - | %lli - ; fi +// RUN: if [ %llvmver -ge 12 ]; then %clang -std=c11 -O1 %loadClangEnzyme %s -S -emit-llvm -o - | %lli - ; fi +// RUN: if [ %llvmver -ge 12 ]; then %clang -std=c11 -O2 %loadClangEnzyme %s -S -emit-llvm -o - -mllvm -enzyme-loose-types | %lli - ; fi +// RUN: if [ %llvmver -ge 12 ]; then %clang -std=c11 -O3 %loadClangEnzyme %s -S -emit-llvm -o - -mllvm -enzyme-loose-types | %lli - ; fi +// RUN: if [ %llvmver -ge 12 ]; then %clang -std=c11 -O0 %loadClangEnzyme %s -S -emit-llvm -o - -mllvm -enzyme-inline=1 | %lli - ; fi +// RUN: if [ %llvmver -ge 12 ]; then %clang -std=c11 -O1 %loadClangEnzyme %s -S -emit-llvm -o - -mllvm -enzyme-inline=1 | %lli - ; fi +// RUN: if [ %llvmver -ge 12 ]; then %clang -std=c11 -O2 %loadClangEnzyme %s -S -emit-llvm -o - -mllvm -enzyme-inline=1 -mllvm -enzyme-loose-types | %lli - ; fi +// RUN: if [ %llvmver -ge 12 ]; then %clang -std=c11 -O3 %loadClangEnzyme %s -S -emit-llvm -o - -mllvm -enzyme-inline=1 -mllvm -enzyme-loose-types | %lli - ; fi #include #include diff --git a/enzyme/test/Integration/ReverseMode/integrateconst.cpp b/enzyme/test/Integration/ReverseMode/integrateconst.cpp index 2086ab8b525d..c55247d65e55 100644 --- a/enzyme/test/Integration/ReverseMode/integrateconst.cpp +++ b/enzyme/test/Integration/ReverseMode/integrateconst.cpp @@ -14,7 +14,6 @@ #define BOOST_MATH_NO_LONG_DOUBLE_MATH_FUNCTIONS #define BOOST_NO_EXCEPTIONS -#include #include #include diff --git a/enzyme/test/Integration/ReverseMode/sret.cpp b/enzyme/test/Integration/ReverseMode/sret.cpp index 2b5315d5d35c..16751fe3e318 100644 --- a/enzyme/test/Integration/ReverseMode/sret.cpp +++ b/enzyme/test/Integration/ReverseMode/sret.cpp @@ -8,9 +8,6 @@ // RUN: %clang++ -std=c++11 -fno-exceptions -ffast-math -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S #include "../test_utils.h" -#include -#include -#include typedef struct { double df[3]; From cf728efae5bab49fead2bd74c891361f73f408ca Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 6 Feb 2024 11:29:56 +0000 Subject: [PATCH 009/106] Fixups for external bazel integration tests (#1668) --- enzyme/BUILD | 184 +++++++++++++----- enzyme/Enzyme/MLIR/Analysis/AliasAnalysis.cpp | 28 +-- .../Analysis/DataFlowActivityAnalysis.cpp | 27 +-- .../LinalgAutoDiffOpInterfaceImpl.cpp | 1 - .../Enzyme/MLIR/Interfaces/GradientUtils.cpp | 10 +- .../Enzyme/MLIR/Passes/PrintAliasAnalysis.cpp | 2 +- enzyme/Enzyme/TypeAnalysis/BaseType.h | 8 +- enzyme/test/Integration/ForwardMode/loops.c | 4 - .../Integration/ForwardMode/loopsdouble.c | 4 - .../Integration/ForwardMode/loopstriple.c | 4 - enzyme/test/Integration/ForwardMode/rwrloop.c | 4 - enzyme/test/Integration/ForwardMode/sumtil.c | 4 - enzyme/test/Integration/ForwardMode/sumtil2.c | 4 - .../Integration/ForwardModeVector/binops.c | 17 +- .../ReverseMode/allocatedtape_err.c | 5 +- .../test/Integration/ReverseMode/boundissue.c | 4 - .../test/Integration/ReverseMode/cachefwd.c | 4 - .../Integration/ReverseMode/customcombined.c | 2 +- enzyme/test/Integration/ReverseMode/dbginfo.c | 2 - .../ReverseMode/differential_pointer_return.c | 5 - .../test/Integration/ReverseMode/forrealloc.c | 6 - enzyme/test/Integration/ReverseMode/frexp.c | 6 +- .../test/Integration/ReverseMode/fwdsolve.c | 9 +- .../ReverseMode/gradient-struct-return.c | 1 - .../Integration/ReverseMode/headerremat.c | 7 - .../Integration/ReverseMode/insertsort_sum.c | 4 - .../ReverseMode/insertsort_sum_alt.c | 5 - .../ReverseMode/insertsort_sum_min.c | 4 - enzyme/test/Integration/ReverseMode/loops.c | 4 - .../Integration/ReverseMode/loopsdouble.c | 4 - .../Integration/ReverseMode/loopstriple.c | 4 - enzyme/test/Integration/ReverseMode/manydiv.c | 4 - enzyme/test/Integration/ReverseMode/manymax.c | 4 - .../test/Integration/ReverseMode/metamalloc.c | 10 +- enzyme/test/Integration/ReverseMode/metarwr.c | 6 +- .../ReverseMode/mixedstruct1-old.c | 5 - .../ReverseMode/mixedstruct1-simple.c | 7 +- .../ReverseMode/mixedstruct1-simplefda.c | 5 - .../ReverseMode/mixedstruct1-simpleps.c | 5 - .../ReverseMode/mixedstruct1-simpler.c | 5 - .../ReverseMode/mixedstruct1-simplest.c | 5 - .../Integration/ReverseMode/mixedstruct1-sp.c | 5 - .../Integration/ReverseMode/mixedstruct1.c | 5 - .../Integration/ReverseMode/multivecmaxC.c | 5 +- .../Integration/ReverseMode/posix_memalign.c | 8 +- .../ReverseMode/posix_memalignfor.c | 8 +- .../Integration/ReverseMode/readwriteread.c | 5 - enzyme/test/Integration/ReverseMode/recurse.c | 4 - enzyme/test/Integration/ReverseMode/remat.c | 4 - .../Integration/ReverseMode/rematSimple.c | 4 - enzyme/test/Integration/ReverseMode/rwrloop.c | 5 - enzyme/test/Integration/ReverseMode/rwrmeta.c | 4 - .../Integration/ReverseMode/smallrealloc.c | 5 - .../Integration/ReverseMode/subdoublestore.c | 6 - enzyme/test/Integration/ReverseMode/sumtil.c | 4 - enzyme/test/Integration/ReverseMode/sumtil2.c | 4 - .../test/Integration/ReverseMode/taylorlog.c | 2 - enzyme/test/Integration/blas_inline.h | 3 + enzyme/test/Integration/test_utils.h | 23 ++- enzyme/test/MLIR/ForwardMode/inactive.mlir | 8 +- 60 files changed, 212 insertions(+), 333 deletions(-) diff --git a/enzyme/BUILD b/enzyme/BUILD index 3b4d03c001e0..0ad691e5aae0 100644 --- a/enzyme/BUILD +++ b/enzyme/BUILD @@ -1,6 +1,6 @@ +load("@llvm-project//llvm:lit_test.bzl", "lit_test", "package_path") load("@llvm-project//llvm:tblgen.bzl", "gentbl") load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") -load("@llvm-project//llvm:lit_test.bzl", "lit_test", "package_path") load("@bazel_skylib//rules:expand_template.bzl", "expand_template") licenses(["notice"]) @@ -145,11 +145,14 @@ gentbl( cc_library( name = "EnzymeStatic", - srcs = glob([ - "Enzyme/*.cpp", - "Enzyme/TypeAnalysis/*.cpp", - "Enzyme/Clang/EnzymeClang.cpp", - ], exclude=["Enzyme/eopt.cpp"]), + srcs = glob( + [ + "Enzyme/*.cpp", + "Enzyme/TypeAnalysis/*.cpp", + "Enzyme/Clang/EnzymeClang.cpp", + ], + exclude = ["Enzyme/eopt.cpp"], + ), hdrs = glob([ "Enzyme/*.h", "Enzyme/TypeAnalysis/*.h", @@ -194,7 +197,7 @@ cc_library( "@llvm-project//llvm:TransformUtils", "@llvm-project//llvm:config", ], - alwayslink = 1 + alwayslink = 1, ) cc_binary( @@ -223,6 +226,8 @@ cc_binary( srcs = ["Enzyme/eopt.cpp"], deps = [ ":EnzymeStatic", + "@llvm-project//llvm:Passes", + "@llvm-project//llvm:Support", "@llvm-project//llvm:opt-driver", ], ) @@ -230,16 +235,16 @@ cc_binary( td_library( name = "EnzymeDialectTdFiles", srcs = [ - "Enzyme/MLIR/Dialect/Dialect.td", + "Enzyme/MLIR/Dialect/Dialect.td", ], - deps = [ + deps = [ + "@llvm-project//mlir:ControlFlowInterfacesTdFiles", + "@llvm-project//mlir:FunctionInterfacesTdFiles", + "@llvm-project//mlir:LoopLikeInterfaceTdFiles", "@llvm-project//mlir:OpBaseTdFiles", "@llvm-project//mlir:SideEffectInterfacesTdFiles", "@llvm-project//mlir:ViewLikeInterfaceTdFiles", - "@llvm-project//mlir:FunctionInterfacesTdFiles", - "@llvm-project//mlir:ControlFlowInterfacesTdFiles", - "@llvm-project//mlir:LoopLikeInterfaceTdFiles", - ] + ], ) gentbl_cc_library( @@ -277,9 +282,9 @@ td_library( name = "EnzymePassesTdFiles", srcs = [ ], - deps = [ + deps = [ "@llvm-project//mlir:PassBaseTdFiles", - ] + ], ) gentbl_cc_library( @@ -349,7 +354,6 @@ gentbl_cc_library( deps = [":EnzymeDialectTdFiles"], ) - gentbl_cc_library( name = "EnzymeTypeInterfacesIncGen", tbl_outs = [ @@ -394,7 +398,8 @@ gentbl( td_file = "Enzyme/MLIR/Implementations/AffineDerivatives.td", td_srcs = [ "Enzyme/MLIR/Implementations/AffineDerivatives.td", - "Enzyme/MLIR/Implementations/Common.td"], + "Enzyme/MLIR/Implementations/Common.td", + ], deps = [ ":enzyme-tblgen", ], @@ -408,7 +413,10 @@ gentbl( )], tblgen = ":enzyme-tblgen", td_file = "Enzyme/MLIR/Implementations/ArithDerivatives.td", - td_srcs = ["Enzyme/MLIR/Implementations/ArithDerivatives.td", "Enzyme/MLIR/Implementations/Common.td"], + td_srcs = [ + "Enzyme/MLIR/Implementations/ArithDerivatives.td", + "Enzyme/MLIR/Implementations/Common.td", + ], deps = [ ":enzyme-tblgen", ], @@ -422,7 +430,10 @@ gentbl( )], tblgen = ":enzyme-tblgen", td_file = "Enzyme/MLIR/Implementations/LLVMDerivatives.td", - td_srcs = ["Enzyme/MLIR/Implementations/LLVMDerivatives.td", "Enzyme/MLIR/Implementations/Common.td"], + td_srcs = [ + "Enzyme/MLIR/Implementations/LLVMDerivatives.td", + "Enzyme/MLIR/Implementations/Common.td", + ], deps = [ ":enzyme-tblgen", ], @@ -436,7 +447,10 @@ gentbl( )], tblgen = ":enzyme-tblgen", td_file = "Enzyme/MLIR/Implementations/NVVMDerivatives.td", - td_srcs = ["Enzyme/MLIR/Implementations/NVVMDerivatives.td", "Enzyme/MLIR/Implementations/Common.td"], + td_srcs = [ + "Enzyme/MLIR/Implementations/NVVMDerivatives.td", + "Enzyme/MLIR/Implementations/Common.td", + ], deps = [ ":enzyme-tblgen", ], @@ -450,7 +464,10 @@ gentbl( )], tblgen = ":enzyme-tblgen", td_file = "Enzyme/MLIR/Implementations/SCFDerivatives.td", - td_srcs = ["Enzyme/MLIR/Implementations/SCFDerivatives.td", "Enzyme/MLIR/Implementations/Common.td"], + td_srcs = [ + "Enzyme/MLIR/Implementations/SCFDerivatives.td", + "Enzyme/MLIR/Implementations/Common.td", + ], deps = [ ":enzyme-tblgen", ], @@ -464,7 +481,10 @@ gentbl( )], tblgen = ":enzyme-tblgen", td_file = "Enzyme/MLIR/Implementations/CFDerivatives.td", - td_srcs = ["Enzyme/MLIR/Implementations/CFDerivatives.td", "Enzyme/MLIR/Implementations/Common.td"], + td_srcs = [ + "Enzyme/MLIR/Implementations/CFDerivatives.td", + "Enzyme/MLIR/Implementations/Common.td", + ], deps = [ ":enzyme-tblgen", ], @@ -478,7 +498,10 @@ gentbl( )], tblgen = ":enzyme-tblgen", td_file = "Enzyme/MLIR/Implementations/MemRefDerivatives.td", - td_srcs = ["Enzyme/MLIR/Implementations/MemRefDerivatives.td", "Enzyme/MLIR/Implementations/Common.td"], + td_srcs = [ + "Enzyme/MLIR/Implementations/MemRefDerivatives.td", + "Enzyme/MLIR/Implementations/Common.td", + ], deps = [ ":enzyme-tblgen", ], @@ -492,7 +515,10 @@ gentbl( )], tblgen = ":enzyme-tblgen", td_file = "Enzyme/MLIR/Implementations/MathDerivatives.td", - td_srcs = ["Enzyme/MLIR/Implementations/MathDerivatives.td", "Enzyme/MLIR/Implementations/Common.td"], + td_srcs = [ + "Enzyme/MLIR/Implementations/MathDerivatives.td", + "Enzyme/MLIR/Implementations/Common.td", + ], deps = [ ":enzyme-tblgen", ], @@ -514,48 +540,95 @@ cc_library( "Enzyme/MLIR/Analysis/*.h", "Enzyme/MLIR/Implementations/*.h", "Enzyme/Utils.h", - "Enzyme/TypeAnalysis/*.h" + "Enzyme/TypeAnalysis/*.h", ]), - includes = ["Enzyme/MLIR", "Enzyme"], + includes = [ + "Enzyme", + "Enzyme/MLIR", + ], visibility = ["//visibility:public"], deps = [ + ":EnzymeAttributesIncGen", + ":EnzymeEnumsIncGen", + ":EnzymeOpInterfacesIncGen", + ":EnzymeOpsIncGen", + ":EnzymePassesIncGen", + ":EnzymeTypeInterfacesIncGen", + ":EnzymeTypesIncGen", ":affine-derivatives", ":arith-derivatives", + ":cf-derivatives", ":llvm-derivatives", + ":math-derivatives", + ":memref-derivatives", ":nvvm-derivatives", ":scf-derivatives", - ":cf-derivatives", - ":memref-derivatives", - ":math-derivatives", - ":EnzymeOpsIncGen", - ":EnzymePassesIncGen", - ":EnzymeTypesIncGen", - ":EnzymeEnumsIncGen", - ":EnzymeAttributesIncGen", - ":EnzymeTypeInterfacesIncGen", - ":EnzymeOpInterfacesIncGen", + "@llvm-project//llvm:Analysis", + "@llvm-project//llvm:Core", + "@llvm-project//llvm:Demangle", + "@llvm-project//llvm:Passes", + "@llvm-project//llvm:Support", + "@llvm-project//llvm:TransformUtils", + "@llvm-project//llvm:config", "@llvm-project//mlir:AffineDialect", - "@llvm-project//mlir:LLVMCommonConversion", - "@llvm-project//mlir:ConversionPasses", - "@llvm-project//mlir:SCFDialect", - "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:Analysis", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:ArithUtils", "@llvm-project//mlir:AsyncDialect", + "@llvm-project//mlir:CastInterfaces", + "@llvm-project//mlir:ControlFlowDialect", + "@llvm-project//mlir:ControlFlowInterfaces", + "@llvm-project//mlir:ConversionPasses", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:FuncExtensions", + "@llvm-project//mlir:FunctionInterfaces", + "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:LLVMCommonConversion", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:LinalgDialect", + "@llvm-project//mlir:LinalgStructuredOpsIncGen", + "@llvm-project//mlir:LinalgTransforms", "@llvm-project//mlir:MathDialect", "@llvm-project//mlir:MemRefDialect", + "@llvm-project//mlir:NVVMDialect", + "@llvm-project//mlir:OpenMPDialect", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Rewrite", + "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:ShapeDialect", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:Transforms", + "@llvm-project//mlir:ViewLikeInterface", ], ) cc_binary( name = "enzymemlir-opt", srcs = ["Enzyme/MLIR/enzymemlir-opt.cpp"], - visibility = ["//visibility:public"], includes = ["Enzyme/MLIR"], + visibility = ["//visibility:public"], deps = [ ":EnzymeMLIR", - "@llvm-project//mlir:MlirOptLib", + "@llvm-project//mlir:AffineDialect", "@llvm-project//mlir:AllPassesAndDialects", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:AsyncDialect", + "@llvm-project//mlir:ControlFlowDialect", + "@llvm-project//mlir:ConversionPasses", + "@llvm-project//mlir:DLTIDialect", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:LinalgDialect", + "@llvm-project//mlir:MathDialect", + "@llvm-project//mlir:MemRefDialect", + "@llvm-project//mlir:MlirOptLib", + "@llvm-project//mlir:NVVMDialect", + "@llvm-project//mlir:OpenMPDialect", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:Transforms", ], ) @@ -586,18 +659,25 @@ expand_template( name = "%s.test" % src, srcs = [src], data = [ + ":enzyme-clang", + ":enzyme-clang++", + ":enzyme-opt", + ":enzymemlir-opt", ":test/lit.cfg.py", ":test/lit.site.cfg.py", - "@llvm-project//llvm:FileCheck", - "@llvm-project//llvm:count", - "@llvm-project//llvm:not", - "@llvm-project//llvm:lli", - ":enzyme-opt", "@llvm-project//clang:builtin_headers_gen", - ":enzyme-clang", - ":enzyme-clang++", - ":enzymemlir-opt" - ] + glob(["test/**/*.h"]) + "@llvm-project//llvm:FileCheck", + "@llvm-project//llvm:count", + "@llvm-project//llvm:lli", + "@llvm-project//llvm:not", + ] + glob(["test/**/*.h"]), + ) + for src in glob( + [ + "test/**/*.mlir", + "test/Integration/**/*.c", + "test/Integration/**/.cpp", + ], + exclude = ["test/**/*omp*.c"], ) - for src in glob(["test/**/*.mlir", "test/Integration/**/*.c", "test/Integration/**/.cpp"], exclude=["test/**/*omp*.c"]) ] diff --git a/enzyme/Enzyme/MLIR/Analysis/AliasAnalysis.cpp b/enzyme/Enzyme/MLIR/Analysis/AliasAnalysis.cpp index b584e8e64a1c..767a60a3ed2b 100644 --- a/enzyme/Enzyme/MLIR/Analysis/AliasAnalysis.cpp +++ b/enzyme/Enzyme/MLIR/Analysis/AliasAnalysis.cpp @@ -211,7 +211,7 @@ ChangeResult enzyme::PointsToSets::update(const AliasClassSet &keysToUpdate, // TODO: consider a stricter check that we only replace unknown // values or a value with itself, currently blocked by memalign. AliasClassSet valuesCopy(values); - valuesCopy.join(it->getSecond()); + (void)valuesCopy.join(it->getSecond()); values.print(llvm::errs()); llvm::errs() << "\n"; it->getSecond().print(llvm::errs()); @@ -572,7 +572,7 @@ void enzyme::PointsToPointerAnalysis::visitCallControlFlowTransfer( if (funcMayReadOther) { // If a function may read from other, it may be storing pointers from // unknown alias sets into any writable pointer. - functionMayCapture.markUnknown(); + (void)functionMayCapture.markUnknown(); } else { for (int pointerAsData : pointerLikeOperands) { // If not captured, it cannot be stored in anything. @@ -583,7 +583,7 @@ void enzyme::PointsToPointerAnalysis::visitCallControlFlowTransfer( const auto *srcClasses = getOrCreateFor( call, call.getArgOperands()[pointerAsData]); - functionMayCapture.join(srcClasses->getAliasClassesObject()); + (void)functionMayCapture.join(srcClasses->getAliasClassesObject()); } } @@ -598,16 +598,17 @@ void enzyme::PointsToPointerAnalysis::visitCallControlFlowTransfer( // If the argument cannot be stored into, just preserve it as is. if (!mayWriteArg(callee, pointerOperand, argModRef)) { - nonWritableOperandClasses.join(destClasses->getAliasClassesObject()); + (void)nonWritableOperandClasses.join( + destClasses->getAliasClassesObject()); continue; } - writableClasses.join(destClasses->getAliasClassesObject()); + (void)writableClasses.join(destClasses->getAliasClassesObject()); // If the destination class is unknown, mark all known classes // pessimistic (alias classes that have not beed analyzed and thus are // absent from pointsTo are treated as "undefined" at this point). if (destClasses->isUnknown()) { - writableClasses.markUnknown(); + (void)writableClasses.markUnknown(); changed |= after->markAllPointToUnknown(); break; } @@ -675,15 +676,15 @@ void enzyme::PointsToPointerAnalysis::visitCallControlFlowTransfer( AliasClassSet resultWithoutNonWritableOperands = AliasClassSet::getUndefined(); if (destClasses->isUnknown() || nonWritableOperandClasses.isUnknown()) { - resultWithoutNonWritableOperands.markUnknown(); + (void)resultWithoutNonWritableOperands.markUnknown(); } else if (!destClasses->isUndefined() && !nonWritableOperandClasses.isUndefined()) { DenseSet nonOperandClasses = llvm::set_difference(destClasses->getAliasClasses(), nonWritableOperandClasses.getAliasClasses()); - resultWithoutNonWritableOperands.insert(nonOperandClasses); + (void)resultWithoutNonWritableOperands.insert(nonOperandClasses); } else { - resultWithoutNonWritableOperands.join( + (void)resultWithoutNonWritableOperands.join( destClasses->getAliasClassesObject()); } @@ -973,7 +974,7 @@ void enzyme::AliasAnalysis::visitOperation( if (!isPointerLike(result.getType())) continue; - results[result.getResultNumber()]->markUnknown(); + (void)results[result.getResultNumber()]->markUnknown(); } return; } @@ -1027,7 +1028,7 @@ void enzyme::AliasAnalysis::visitExternalCall( continue; const AliasClassLattice *srcClasses = operands[operandNo]; - operandAliasClasses.join(srcClasses->getAliasClassesObject()); + (void)operandAliasClasses.join(srcClasses->getAliasClassesObject()); if (!mayReadArg(callee, operandNo, argModRef)) continue; @@ -1035,13 +1036,14 @@ void enzyme::AliasAnalysis::visitExternalCall( // If can read from argument, collect the alias classes that can this // argument may be pointing to. const auto *pointsToLattice = getOrCreateFor(call, call); - srcClasses->getAliasClassesObject().foreachClass( + (void)srcClasses->getAliasClassesObject().foreachClass( [&](DistinctAttr srcClass, AliasClassSet::State state) { // Nothing to do in top/bottom case. In the top case, we have already // set `operandAliasClasses` to top above. if (srcClass == nullptr) return ChangeResult::NoChange; - operandAliasClasses.join(pointsToLattice->getPointsTo(srcClass)); + (void)operandAliasClasses.join( + pointsToLattice->getPointsTo(srcClass)); return ChangeResult::NoChange; }); } diff --git a/enzyme/Enzyme/MLIR/Analysis/DataFlowActivityAnalysis.cpp b/enzyme/Enzyme/MLIR/Analysis/DataFlowActivityAnalysis.cpp index 4f922d6eb5d5..9c540064a49a 100644 --- a/enzyme/Enzyme/MLIR/Analysis/DataFlowActivityAnalysis.cpp +++ b/enzyme/Enzyme/MLIR/Analysis/DataFlowActivityAnalysis.cpp @@ -330,7 +330,7 @@ class MemoryActivity : public AbstractDenseLattice { const MemoryActivityState *rhsActivity = isKnownInRHS ? &rhsIt->getSecond() : &rhs.otherMemoryActivity; MemoryActivityState updatedActivity(*lhsActivity); - updatedActivity.merge(*rhsActivity); + (void)updatedActivity.merge(*rhsActivity); if ((lhsIt != activityStates.end() && updatedActivity != lhsIt->getSecond()) || (lhsIt == activityStates.end() && @@ -490,7 +490,7 @@ std::optional getCopySource(Operation *op) { /// If the classes are undefined, the callback will not be called at all. void forEachAliasedAlloc(const AliasClassLattice *ptrAliasClass, function_ref forEachFn) { - ptrAliasClass->getAliasClassesObject().foreachClass( + (void)ptrAliasClass->getAliasClassesObject().foreachClass( [&](DistinctAttr alloc, enzyme::AliasClassSet::State state) { if (state != enzyme::AliasClassSet::State::Undefined) forEachFn(alloc); @@ -854,15 +854,16 @@ void printActivityAnalysisResults(const DataFlowSolver &solver, std::deque frontier; DenseSet visited; auto scheduleVisit = [&](const enzyme::AliasClassSet &aliasClasses) { - aliasClasses.foreachClass([&](DistinctAttr neighbor, - enzyme::AliasClassSet::State state) { - assert(neighbor && "unhandled undefined/unknown case before visit"); - if (!visited.contains(neighbor)) { - visited.insert(neighbor); - frontier.push_back(neighbor); - } - return ChangeResult::NoChange; - }); + (void)aliasClasses.foreachClass( + [&](DistinctAttr neighbor, enzyme::AliasClassSet::State state) { + assert(neighbor && + "unhandled undefined/unknown case before visit"); + if (!visited.contains(neighbor)) { + visited.insert(neighbor); + frontier.push_back(neighbor); + } + return ChangeResult::NoChange; + }); }; // If this triggers, investigate why the alias classes weren't computed. @@ -1071,7 +1072,7 @@ void enzyme::runDataFlowActivityAnalysis( // analyses, enzyme_const is the default. if (activity == enzyme::Activity::enzyme_out) { auto *argLattice = solver.getOrCreateState(arg); - argLattice->join(ValueActivity::getActiveVal()); + (void)argLattice->join(ValueActivity::getActiveVal()); } } @@ -1086,7 +1087,7 @@ void enzyme::runDataFlowActivityAnalysis( solver.getOrCreateState(operand); // Very basic type inference of the type if (isa(operand.getType())) { - returnLattice->meet(ValueActivity::getActiveVal()); + (void)returnLattice->meet(ValueActivity::getActiveVal()); } } } diff --git a/enzyme/Enzyme/MLIR/Implementations/LinalgAutoDiffOpInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/LinalgAutoDiffOpInterfaceImpl.cpp index 93488a07fd6e..334d7325d7ec 100644 --- a/enzyme/Enzyme/MLIR/Implementations/LinalgAutoDiffOpInterfaceImpl.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/LinalgAutoDiffOpInterfaceImpl.cpp @@ -28,7 +28,6 @@ #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" -#include "mlir/Dialect/Shape/IR/ShapeOpsTypes.h.inc" #include "mlir/IR/PatternMatch.h" #include "mlir/Transforms/DialectConversion.h" diff --git a/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp b/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp index 20ed156d312d..467a8f59ec6e 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp +++ b/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp @@ -36,14 +36,12 @@ mlir::enzyme::MGradientUtils::MGradientUtils( ArrayRef ArgDiffeTypes_, IRMapping &originalToNewFn_, std::map &originalToNewFnOps_, DerivativeMode mode, unsigned width, bool omp) - : newFunc(newFunc_), Logic(Logic), mode(mode), oldFunc(oldFunc_), TA(TA_), - TR(TR_), omp(omp), blocksNotForAnalysis(), + : newFunc(newFunc_), Logic(Logic), mode(mode), oldFunc(oldFunc_), + invertedPointers(invertedPointers_), originalToNewFn(originalToNewFn_), + originalToNewFnOps(originalToNewFnOps_), blocksNotForAnalysis(), activityAnalyzer(std::make_unique( blocksNotForAnalysis, constantvalues_, activevals_, ReturnActivity)), - width(width), ArgDiffeTypes(ArgDiffeTypes_), - originalToNewFn(originalToNewFn_), - originalToNewFnOps(originalToNewFnOps_), - invertedPointers(invertedPointers_) { + TA(TA_), TR(TR_), omp(omp), width(width), ArgDiffeTypes(ArgDiffeTypes_) { /* for (BasicBlock &BB : *oldFunc) { diff --git a/enzyme/Enzyme/MLIR/Passes/PrintAliasAnalysis.cpp b/enzyme/Enzyme/MLIR/Passes/PrintAliasAnalysis.cpp index 99a41f804593..c2680ae6db72 100644 --- a/enzyme/Enzyme/MLIR/Passes/PrintAliasAnalysis.cpp +++ b/enzyme/Enzyme/MLIR/Passes/PrintAliasAnalysis.cpp @@ -91,7 +91,7 @@ struct PrintAliasAnalysisPass continue; // TODO(zinenko): this has been overriding the argument... // Use an array attr instead (will break syntactic tests). - state->getAliasClassesObject().foreachClass( + (void)state->getAliasClassesObject().foreachClass( [&](DistinctAttr aliasClass, enzyme::AliasClassSet::State state) { if (state == enzyme::AliasClassSet::State::Undefined) funcOp.setArgAttr( diff --git a/enzyme/Enzyme/TypeAnalysis/BaseType.h b/enzyme/Enzyme/TypeAnalysis/BaseType.h index 9ba5bd2bec5c..f73948a703ad 100644 --- a/enzyme/Enzyme/TypeAnalysis/BaseType.h +++ b/enzyme/Enzyme/TypeAnalysis/BaseType.h @@ -25,8 +25,6 @@ #ifndef ENZYME_TYPE_ANALYSIS_BASE_TYPE_H #define ENZYME_TYPE_ANALYSIS_BASE_TYPE_H 1 -#include "llvm/ADT/StringRef.h" -#include "llvm/Support/ErrorHandling.h" #include /// Categories of potential types @@ -57,11 +55,11 @@ static inline std::string to_string(BaseType t) { case BaseType::Unknown: return "Unknown"; } - llvm_unreachable("unknown inttype"); + assert(0 && "unknown inttype"); } /// Convert string to BaseType -static inline BaseType parseBaseType(llvm::StringRef str) { +template static inline BaseType parseBaseType(T str) { if (str == "Integer") return BaseType::Integer; if (str == "Float") @@ -72,6 +70,6 @@ static inline BaseType parseBaseType(llvm::StringRef str) { return BaseType::Anything; if (str == "Unknown") return BaseType::Unknown; - llvm_unreachable("Unknown BaseType string"); + assert(0 && "Unknown BaseType string"); } #endif diff --git a/enzyme/test/Integration/ForwardMode/loops.c b/enzyme/test/Integration/ForwardMode/loops.c index 612839248cc2..33bef07a2d16 100644 --- a/enzyme/test/Integration/ForwardMode/loops.c +++ b/enzyme/test/Integration/ForwardMode/loops.c @@ -7,10 +7,6 @@ // RUN: %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - -#include -#include -#include - #include "../test_utils.h" double __enzyme_fwddiff(void*, ...); diff --git a/enzyme/test/Integration/ForwardMode/loopsdouble.c b/enzyme/test/Integration/ForwardMode/loopsdouble.c index 8c34740c80d5..4faa0cad74f3 100644 --- a/enzyme/test/Integration/ForwardMode/loopsdouble.c +++ b/enzyme/test/Integration/ForwardMode/loopsdouble.c @@ -7,10 +7,6 @@ // RUN: %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - -#include -#include -#include - #include "../test_utils.h" double __enzyme_fwddiff(void*, ...); diff --git a/enzyme/test/Integration/ForwardMode/loopstriple.c b/enzyme/test/Integration/ForwardMode/loopstriple.c index 060e74d009f7..84e98ab1694e 100644 --- a/enzyme/test/Integration/ForwardMode/loopstriple.c +++ b/enzyme/test/Integration/ForwardMode/loopstriple.c @@ -7,10 +7,6 @@ // RUN: %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - -#include -#include -#include - #include "../test_utils.h" double __enzyme_fwddiff(void*, ...); diff --git a/enzyme/test/Integration/ForwardMode/rwrloop.c b/enzyme/test/Integration/ForwardMode/rwrloop.c index dc71b0eead20..32e548029f6d 100644 --- a/enzyme/test/Integration/ForwardMode/rwrloop.c +++ b/enzyme/test/Integration/ForwardMode/rwrloop.c @@ -7,10 +7,6 @@ // RUN: %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme --enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme --enzyme-inline=1 -S | %lli - -#include -#include -#include - #include "../test_utils.h" double __enzyme_fwddiff(void*, ...); diff --git a/enzyme/test/Integration/ForwardMode/sumtil.c b/enzyme/test/Integration/ForwardMode/sumtil.c index 9e9286f97157..5d6369df4909 100644 --- a/enzyme/test/Integration/ForwardMode/sumtil.c +++ b/enzyme/test/Integration/ForwardMode/sumtil.c @@ -7,10 +7,6 @@ // RUN: %clang -std=c11 -ffast-math -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -ffast-math -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - -#include -#include -#include - #include "../test_utils.h" extern double __enzyme_fwddiff(void*, double*, double*, int); diff --git a/enzyme/test/Integration/ForwardMode/sumtil2.c b/enzyme/test/Integration/ForwardMode/sumtil2.c index 32d289703e69..18428ca371d6 100644 --- a/enzyme/test/Integration/ForwardMode/sumtil2.c +++ b/enzyme/test/Integration/ForwardMode/sumtil2.c @@ -7,10 +7,6 @@ // RUN: %clang -std=c11 -ffast-math -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -ffast-math -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - -#include -#include -#include - #include "../test_utils.h" extern double __enzyme_fwddiff(void*, double*, double*, int); diff --git a/enzyme/test/Integration/ForwardModeVector/binops.c b/enzyme/test/Integration/ForwardModeVector/binops.c index 27c786a86980..a3b66f27ea31 100644 --- a/enzyme/test/Integration/ForwardModeVector/binops.c +++ b/enzyme/test/Integration/ForwardModeVector/binops.c @@ -7,10 +7,7 @@ // RUN: %clang -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - // RUN: %clang -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - -#include -#include -#include -#include +#include "../test_utils.h" /* #ifdef __cplusplus @@ -28,18 +25,6 @@ threshold) { if (fabs(f1-f2) > threshold) return false; return true; #endif */ -#define APPROX_EQ(LHS, RHS, THRES) \ - { \ - if (__builtin_fabs(LHS - RHS) > THRES) { \ - fprintf(stderr, \ - "Assertion Failed: fabs( [%s = %g] - [%s = %g] ) > %g at %s:%d " \ - "(%s)\n", \ - #LHS, LHS, #RHS, RHS, THRES, __FILE__, __LINE__, \ - __PRETTY_FUNCTION__); \ - abort(); \ - } \ - }; - typedef struct { double dx, dy; diff --git a/enzyme/test/Integration/ReverseMode/allocatedtape_err.c b/enzyme/test/Integration/ReverseMode/allocatedtape_err.c index 1236a86388e7..27bd89f36e40 100644 --- a/enzyme/test/Integration/ReverseMode/allocatedtape_err.c +++ b/enzyme/test/Integration/ReverseMode/allocatedtape_err.c @@ -7,8 +7,9 @@ // RUN: if [ %llvmver -ge 12 ]; then %clang -std=c11 -g -O2 %s -S -emit-llvm -o - %newLoadClangEnzyme -Xclang -verify; fi // RUN: if [ %llvmver -ge 12 ]; then %clang -std=c11 -g -O3 %s -S -emit-llvm -o - %newLoadClangEnzyme -Xclang -verify; fi -#include -#include "../test_utils.h" +extern int enzyme_allocated; +extern int enzyme_tape; +double sin(double); void __enzyme_reverse(void*, ...); diff --git a/enzyme/test/Integration/ReverseMode/boundissue.c b/enzyme/test/Integration/ReverseMode/boundissue.c index f297ebbdda97..4d41f2bf482a 100644 --- a/enzyme/test/Integration/ReverseMode/boundissue.c +++ b/enzyme/test/Integration/ReverseMode/boundissue.c @@ -7,10 +7,6 @@ // RUN: %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme --enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme --enzyme-inline=1 -S | %lli - -#include -#include -#include - #include "../test_utils.h" void __enzyme_autodiff(void*, ...); diff --git a/enzyme/test/Integration/ReverseMode/cachefwd.c b/enzyme/test/Integration/ReverseMode/cachefwd.c index ab56e3dd0853..6e1b058a7004 100644 --- a/enzyme/test/Integration/ReverseMode/cachefwd.c +++ b/enzyme/test/Integration/ReverseMode/cachefwd.c @@ -7,10 +7,6 @@ // RUN: %clang -std=c11 -ffast-math -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -ffast-math -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - -#include -#include -#include - #include "../test_utils.h" extern void __enzyme_autodiff(void*, double*, double*, int); diff --git a/enzyme/test/Integration/ReverseMode/customcombined.c b/enzyme/test/Integration/ReverseMode/customcombined.c index 5ce2c1df2b07..cbd86d0452ef 100644 --- a/enzyme/test/Integration/ReverseMode/customcombined.c +++ b/enzyme/test/Integration/ReverseMode/customcombined.c @@ -32,7 +32,7 @@ void* augment_square_(const double* src, const double *d_src, double* dest, doub // intentionally incorrect for debugging *dest = 7.0; *d_dest = 11.0; - return NULL; + return (void*)0; } int gradient = 0; diff --git a/enzyme/test/Integration/ReverseMode/dbginfo.c b/enzyme/test/Integration/ReverseMode/dbginfo.c index 06767e624849..31d9bafe4439 100644 --- a/enzyme/test/Integration/ReverseMode/dbginfo.c +++ b/enzyme/test/Integration/ReverseMode/dbginfo.c @@ -7,8 +7,6 @@ // RUN: %clang -std=c11 -ffast-math -O2 %s -S -emit-llvm -o - -g | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -ffast-math -O3 %s -S -emit-llvm -o - -g | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - -//#include - #include "../test_utils.h" #define __builtin_autodiff __enzyme_autodiff diff --git a/enzyme/test/Integration/ReverseMode/differential_pointer_return.c b/enzyme/test/Integration/ReverseMode/differential_pointer_return.c index b05a7264ae3c..42daecbaa7ed 100644 --- a/enzyme/test/Integration/ReverseMode/differential_pointer_return.c +++ b/enzyme/test/Integration/ReverseMode/differential_pointer_return.c @@ -7,11 +7,6 @@ // RUN: %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - -#include -#include -#include -#include - #include "../test_utils.h" #define __builtin_autodiff __enzyme_autodiff diff --git a/enzyme/test/Integration/ReverseMode/forrealloc.c b/enzyme/test/Integration/ReverseMode/forrealloc.c index 924f7f04a9f1..5daf309c6c6a 100644 --- a/enzyme/test/Integration/ReverseMode/forrealloc.c +++ b/enzyme/test/Integration/ReverseMode/forrealloc.c @@ -7,14 +7,8 @@ // RUN: if [ %llvmver -ge 12 ]; then %clang -std=c11 -O2 %loadClangEnzyme %s -S -emit-llvm -o - -mllvm -enzyme-inline=1 -mllvm -enzyme-loose-types | %lli - ; fi // RUN: if [ %llvmver -ge 12 ]; then %clang -std=c11 -O3 %loadClangEnzyme %s -S -emit-llvm -o - -mllvm -enzyme-inline=1 -mllvm -enzyme-loose-types | %lli - ; fi -#include -#include -#include -#include - #include "../test_utils.h" - float __enzyme_autodiff(void*, float, int); float foo(float inp, int n) { diff --git a/enzyme/test/Integration/ReverseMode/frexp.c b/enzyme/test/Integration/ReverseMode/frexp.c index 4b917e7e282f..08858d51f796 100644 --- a/enzyme/test/Integration/ReverseMode/frexp.c +++ b/enzyme/test/Integration/ReverseMode/frexp.c @@ -7,12 +7,10 @@ // RUN: %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme --enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme --enzyme-inline=1 -S | %lli - -#include -#include -#include - #include "../test_utils.h" +extern double frexp ( double num, int* exp ); + double f(double x) { int exp; return frexp(x, &exp); diff --git a/enzyme/test/Integration/ReverseMode/fwdsolve.c b/enzyme/test/Integration/ReverseMode/fwdsolve.c index c7a6ad63040e..5a7ab4cc67e8 100644 --- a/enzyme/test/Integration/ReverseMode/fwdsolve.c +++ b/enzyme/test/Integration/ReverseMode/fwdsolve.c @@ -7,11 +7,6 @@ // RUN: %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - -#include -#include -#include -#include - #include "../test_utils.h" @@ -25,10 +20,10 @@ void forward_sub(int N, double* __restrict__ L, double * __restrict__ b, double b must be a vector of the same leading dimension as L """ */ - for (size_t i=0; i 1) - for (size_t j=0; j #include "../test_utils.h" typedef struct { diff --git a/enzyme/test/Integration/ReverseMode/headerremat.c b/enzyme/test/Integration/ReverseMode/headerremat.c index b3397ad8bead..a324a519557f 100644 --- a/enzyme/test/Integration/ReverseMode/headerremat.c +++ b/enzyme/test/Integration/ReverseMode/headerremat.c @@ -7,15 +7,8 @@ // RUN: %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme --enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme --enzyme-inline=1 -S | %lli - -#include -#include -#include - #include "../test_utils.h" - -#include - __attribute__((noinline)) int evaluate_integrand(const int nr, const int dtheta) diff --git a/enzyme/test/Integration/ReverseMode/insertsort_sum.c b/enzyme/test/Integration/ReverseMode/insertsort_sum.c index 7e7d6b0a3ae4..7b42f7613890 100644 --- a/enzyme/test/Integration/ReverseMode/insertsort_sum.c +++ b/enzyme/test/Integration/ReverseMode/insertsort_sum.c @@ -6,10 +6,6 @@ // RUN: %clang -std=c11 -O1 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - -#include -#include -#include -#include #include "../test_utils.h" diff --git a/enzyme/test/Integration/ReverseMode/insertsort_sum_alt.c b/enzyme/test/Integration/ReverseMode/insertsort_sum_alt.c index 3a643492d30f..321438af8b9a 100644 --- a/enzyme/test/Integration/ReverseMode/insertsort_sum_alt.c +++ b/enzyme/test/Integration/ReverseMode/insertsort_sum_alt.c @@ -7,11 +7,6 @@ // RUN: %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - -#include -#include -#include -#include - #include "../test_utils.h" #define __builtin_autodiff __enzyme_autodiff diff --git a/enzyme/test/Integration/ReverseMode/insertsort_sum_min.c b/enzyme/test/Integration/ReverseMode/insertsort_sum_min.c index b62150353581..f6bfd1a4c5c0 100644 --- a/enzyme/test/Integration/ReverseMode/insertsort_sum_min.c +++ b/enzyme/test/Integration/ReverseMode/insertsort_sum_min.c @@ -8,10 +8,6 @@ // RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - #include "../test_utils.h" -#include -#include -#include -#include #define __builtin_autodiff __enzyme_autodiff diff --git a/enzyme/test/Integration/ReverseMode/loops.c b/enzyme/test/Integration/ReverseMode/loops.c index 3a2794963eb5..57746125be68 100644 --- a/enzyme/test/Integration/ReverseMode/loops.c +++ b/enzyme/test/Integration/ReverseMode/loops.c @@ -7,10 +7,6 @@ // RUN: %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - -#include -#include -#include - #include "../test_utils.h" #define __builtin_autodiff __enzyme_autodiff diff --git a/enzyme/test/Integration/ReverseMode/loopsdouble.c b/enzyme/test/Integration/ReverseMode/loopsdouble.c index 8108a9813c63..c9ffc8b6bde6 100644 --- a/enzyme/test/Integration/ReverseMode/loopsdouble.c +++ b/enzyme/test/Integration/ReverseMode/loopsdouble.c @@ -7,10 +7,6 @@ // RUN: %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - -#include -#include -#include - #include "../test_utils.h" #define __builtin_autodiff __enzyme_autodiff diff --git a/enzyme/test/Integration/ReverseMode/loopstriple.c b/enzyme/test/Integration/ReverseMode/loopstriple.c index 490502145810..a1aa76444893 100644 --- a/enzyme/test/Integration/ReverseMode/loopstriple.c +++ b/enzyme/test/Integration/ReverseMode/loopstriple.c @@ -7,10 +7,6 @@ // RUN: %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - -#include -#include -#include - #include "../test_utils.h" #define __builtin_autodiff __enzyme_autodiff diff --git a/enzyme/test/Integration/ReverseMode/manydiv.c b/enzyme/test/Integration/ReverseMode/manydiv.c index a0da1f6499b3..e3937c6bb956 100644 --- a/enzyme/test/Integration/ReverseMode/manydiv.c +++ b/enzyme/test/Integration/ReverseMode/manydiv.c @@ -7,10 +7,6 @@ // RUN: %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme --enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme --enzyme-inline=1 -S | %lli - -#include -#include -#include - #include "../test_utils.h" double __enzyme_autodiff(void*, ...); diff --git a/enzyme/test/Integration/ReverseMode/manymax.c b/enzyme/test/Integration/ReverseMode/manymax.c index 90b61041ec1e..4cf31bf1cfa7 100644 --- a/enzyme/test/Integration/ReverseMode/manymax.c +++ b/enzyme/test/Integration/ReverseMode/manymax.c @@ -7,10 +7,6 @@ // RUN: %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme --enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme --enzyme-inline=1 -S | %lli - -#include -#include -#include - #include "../test_utils.h" double __enzyme_autodiff(void*, ...); diff --git a/enzyme/test/Integration/ReverseMode/metamalloc.c b/enzyme/test/Integration/ReverseMode/metamalloc.c index 96a848fd9698..3765a0c803fd 100644 --- a/enzyme/test/Integration/ReverseMode/metamalloc.c +++ b/enzyme/test/Integration/ReverseMode/metamalloc.c @@ -7,20 +7,16 @@ // RUN: %clang -std=c11 -Xclang -new-struct-path-tbaa -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme --enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -Xclang -new-struct-path-tbaa -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme --enzyme-inline=1 -S | %lli - -#include -#include -#include - #include "../test_utils.h" double __enzyme_autodiff(void*, ...); struct { int count; -void* (*allocfn)(size_t); +void* (*allocfn)(unsigned long); } tup = {0, malloc}; __attribute__((noinline)) -void* metamalloc(size_t size) { +void* metamalloc(unsigned long size) { void* ret = tup.allocfn(size); //if (ret != 0) // tup.count++; @@ -38,7 +34,7 @@ double alldiv(double x) { } -static void* (*sallocfn)(size_t) = malloc; +static void* (*sallocfn)(unsigned long) = malloc; __attribute__((noinline)) void* smetamalloc(int size) { return sallocfn(size); diff --git a/enzyme/test/Integration/ReverseMode/metarwr.c b/enzyme/test/Integration/ReverseMode/metarwr.c index 4efcd475f68f..3e3310c634bd 100644 --- a/enzyme/test/Integration/ReverseMode/metarwr.c +++ b/enzyme/test/Integration/ReverseMode/metarwr.c @@ -7,10 +7,6 @@ // RUN: %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme --enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme --enzyme-inline=1 -S | %lli - -#include -#include -#include - #include "../test_utils.h" void __enzyme_autodiff(void*, ...); @@ -19,7 +15,7 @@ void call(double* __restrict__ a, long** data) { long* segment = data[0]; long size = segment[1] - segment[0]; printf("seg[1]=%d seg[0]=%d\n", segment[1], segment[0]); - for (size_t i=0; i -#include -#include -#include - #include "../test_utils.h" #define __builtin_autodiff __enzyme_autodiff diff --git a/enzyme/test/Integration/ReverseMode/mixedstruct1-simple.c b/enzyme/test/Integration/ReverseMode/mixedstruct1-simple.c index 5dea68e793da..40444dd8306f 100644 --- a/enzyme/test/Integration/ReverseMode/mixedstruct1-simple.c +++ b/enzyme/test/Integration/ReverseMode/mixedstruct1-simple.c @@ -5,12 +5,7 @@ // RUN: %clang -std=c11 %O0TBAA %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -O1 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - -// RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - - -#include -#include -#include -#include +// RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - q #include "../test_utils.h" diff --git a/enzyme/test/Integration/ReverseMode/mixedstruct1-simplefda.c b/enzyme/test/Integration/ReverseMode/mixedstruct1-simplefda.c index e2b3f9788fca..22e6fc4dedc7 100644 --- a/enzyme/test/Integration/ReverseMode/mixedstruct1-simplefda.c +++ b/enzyme/test/Integration/ReverseMode/mixedstruct1-simplefda.c @@ -7,11 +7,6 @@ // RUN: %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - -#include -#include -#include -#include - #include "../test_utils.h" #define __builtin_autodiff __enzyme_autodiff diff --git a/enzyme/test/Integration/ReverseMode/mixedstruct1-simpleps.c b/enzyme/test/Integration/ReverseMode/mixedstruct1-simpleps.c index b4857429b6fe..7de76117e765 100644 --- a/enzyme/test/Integration/ReverseMode/mixedstruct1-simpleps.c +++ b/enzyme/test/Integration/ReverseMode/mixedstruct1-simpleps.c @@ -7,11 +7,6 @@ // RUN: %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - -#include -#include -#include -#include - #include "../test_utils.h" #define __builtin_autodiff __enzyme_autodiff diff --git a/enzyme/test/Integration/ReverseMode/mixedstruct1-simpler.c b/enzyme/test/Integration/ReverseMode/mixedstruct1-simpler.c index 61b81d0a57bf..6c602b32da3a 100644 --- a/enzyme/test/Integration/ReverseMode/mixedstruct1-simpler.c +++ b/enzyme/test/Integration/ReverseMode/mixedstruct1-simpler.c @@ -7,11 +7,6 @@ // RUN: %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - -#include -#include -#include -#include - #include "../test_utils.h" #define __builtin_autodiff __enzyme_autodiff diff --git a/enzyme/test/Integration/ReverseMode/mixedstruct1-simplest.c b/enzyme/test/Integration/ReverseMode/mixedstruct1-simplest.c index bcf3ea958485..5417fce6b8ea 100644 --- a/enzyme/test/Integration/ReverseMode/mixedstruct1-simplest.c +++ b/enzyme/test/Integration/ReverseMode/mixedstruct1-simplest.c @@ -7,11 +7,6 @@ // RUN: %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - -#include -#include -#include -#include - #include "../test_utils.h" #define __builtin_autodiff __enzyme_autodiff diff --git a/enzyme/test/Integration/ReverseMode/mixedstruct1-sp.c b/enzyme/test/Integration/ReverseMode/mixedstruct1-sp.c index eb783098cf02..d64e9a70da9a 100644 --- a/enzyme/test/Integration/ReverseMode/mixedstruct1-sp.c +++ b/enzyme/test/Integration/ReverseMode/mixedstruct1-sp.c @@ -7,11 +7,6 @@ // RUN: %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - -#include -#include -#include -#include - #include "../test_utils.h" #define __builtin_autodiff __enzyme_autodiff diff --git a/enzyme/test/Integration/ReverseMode/mixedstruct1.c b/enzyme/test/Integration/ReverseMode/mixedstruct1.c index 3ab79b584ef9..e1218dbc0541 100644 --- a/enzyme/test/Integration/ReverseMode/mixedstruct1.c +++ b/enzyme/test/Integration/ReverseMode/mixedstruct1.c @@ -7,11 +7,6 @@ // RUN: %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - -#include -#include -#include -#include - #include "../test_utils.h" #define __builtin_autodiff __enzyme_autodiff diff --git a/enzyme/test/Integration/ReverseMode/multivecmaxC.c b/enzyme/test/Integration/ReverseMode/multivecmaxC.c index c39559adb2de..460036659ffd 100644 --- a/enzyme/test/Integration/ReverseMode/multivecmaxC.c +++ b/enzyme/test/Integration/ReverseMode/multivecmaxC.c @@ -10,9 +10,6 @@ // RUN: %clang++ -ffast-math -O2 -fno-vectorize -fno-slp-vectorize -fno-unroll-loops -fno-exceptions %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - // RUN: %clang++ -ffast-math -O3 -fno-vectorize -fno-slp-vectorize -fno-unroll-loops -fno-exceptions %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - -#include -#include - #include "../test_utils.h" extern void __enzyme_autodiff(void*, double*, double*, int); @@ -21,7 +18,7 @@ extern void __enzyme_autodiff(void*, double*, double*, int); }*/ double reduce_max(double* vec, int size) { - double ret = -INFINITY; + double ret = -10000000; double *maxes = (double*)malloc(sizeof(double)*size); int count = 0; for (int i = 0; i < size; i++) { diff --git a/enzyme/test/Integration/ReverseMode/posix_memalign.c b/enzyme/test/Integration/ReverseMode/posix_memalign.c index 48aab315b95a..b50a405dd70c 100644 --- a/enzyme/test/Integration/ReverseMode/posix_memalign.c +++ b/enzyme/test/Integration/ReverseMode/posix_memalign.c @@ -7,15 +7,9 @@ // RUN: %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - -#include -#include -#include -#include -#include - #include "../test_utils.h" -int posix_memalign(void **memptr, size_t alignment, size_t size); +int posix_memalign(void **memptr, unsigned long alignment, unsigned long size); float __enzyme_autodiff(void*, float, int); diff --git a/enzyme/test/Integration/ReverseMode/posix_memalignfor.c b/enzyme/test/Integration/ReverseMode/posix_memalignfor.c index 0a01ef095142..7336444421a8 100644 --- a/enzyme/test/Integration/ReverseMode/posix_memalignfor.c +++ b/enzyme/test/Integration/ReverseMode/posix_memalignfor.c @@ -7,15 +7,9 @@ // RUN: %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - -#include -#include -#include -#include -#include - #include "../test_utils.h" -int posix_memalign(void **memptr, size_t alignment, size_t size); +int posix_memalign(void **memptr, unsigned long alignment, unsigned long size); float __enzyme_autodiff(void*, float, int); diff --git a/enzyme/test/Integration/ReverseMode/readwriteread.c b/enzyme/test/Integration/ReverseMode/readwriteread.c index ecfdce54d27e..adb5afca594a 100644 --- a/enzyme/test/Integration/ReverseMode/readwriteread.c +++ b/enzyme/test/Integration/ReverseMode/readwriteread.c @@ -7,11 +7,6 @@ // RUN: %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - -#include -#include -#include -#include - #include "../test_utils.h" #define __builtin_autodiff __enzyme_autodiff diff --git a/enzyme/test/Integration/ReverseMode/recurse.c b/enzyme/test/Integration/ReverseMode/recurse.c index a53041e77188..ca06c745411f 100644 --- a/enzyme/test/Integration/ReverseMode/recurse.c +++ b/enzyme/test/Integration/ReverseMode/recurse.c @@ -9,10 +9,6 @@ // RUN: %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - -#include -#include -#include - #include "../test_utils.h" #define __builtin_autodiff __enzyme_autodiff diff --git a/enzyme/test/Integration/ReverseMode/remat.c b/enzyme/test/Integration/ReverseMode/remat.c index b228bcc13c99..c9b771b17d1b 100644 --- a/enzyme/test/Integration/ReverseMode/remat.c +++ b/enzyme/test/Integration/ReverseMode/remat.c @@ -7,10 +7,6 @@ // RUN: %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme --enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme --enzyme-inline=1 -S | %lli - -// test.c -#include -#include - #include "../test_utils.h" extern void __enzyme_autodiff(void*, ...); diff --git a/enzyme/test/Integration/ReverseMode/rematSimple.c b/enzyme/test/Integration/ReverseMode/rematSimple.c index 7b95f9cb9506..788e0c258c8e 100644 --- a/enzyme/test/Integration/ReverseMode/rematSimple.c +++ b/enzyme/test/Integration/ReverseMode/rematSimple.c @@ -3,10 +3,6 @@ // RUN: if [ %llvmver -ge 12 ]; then %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -S | %lli - ; fi // RUN: if [ %llvmver -ge 12 ]; then %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -S | %lli - ; fi -// test.c -#include -#include - #include "../test_utils.h" extern void __enzyme_autodiff(void*, ...); diff --git a/enzyme/test/Integration/ReverseMode/rwrloop.c b/enzyme/test/Integration/ReverseMode/rwrloop.c index cdf9e3774553..74b9acb7b897 100644 --- a/enzyme/test/Integration/ReverseMode/rwrloop.c +++ b/enzyme/test/Integration/ReverseMode/rwrloop.c @@ -7,11 +7,6 @@ // RUN: %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme --enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme --enzyme-inline=1 -S | %lli - -#include -#include -#include -#include - #include "../test_utils.h" double __enzyme_autodiff(void*, ...); diff --git a/enzyme/test/Integration/ReverseMode/rwrmeta.c b/enzyme/test/Integration/ReverseMode/rwrmeta.c index 34a15d5c93d4..985ccdd44f1a 100644 --- a/enzyme/test/Integration/ReverseMode/rwrmeta.c +++ b/enzyme/test/Integration/ReverseMode/rwrmeta.c @@ -7,10 +7,6 @@ // RUN: %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme --enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme --enzyme-inline=1 -S | %lli - -#include -#include -#include - #include "../test_utils.h" double __enzyme_autodiff(void*, ...); diff --git a/enzyme/test/Integration/ReverseMode/smallrealloc.c b/enzyme/test/Integration/ReverseMode/smallrealloc.c index 51e29ccdfce5..1244a2f12cf1 100644 --- a/enzyme/test/Integration/ReverseMode/smallrealloc.c +++ b/enzyme/test/Integration/ReverseMode/smallrealloc.c @@ -7,11 +7,6 @@ // RUN: %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - -#include -#include -#include -#include - #include "../test_utils.h" diff --git a/enzyme/test/Integration/ReverseMode/subdoublestore.c b/enzyme/test/Integration/ReverseMode/subdoublestore.c index f411d98f3e7f..130f79c99979 100644 --- a/enzyme/test/Integration/ReverseMode/subdoublestore.c +++ b/enzyme/test/Integration/ReverseMode/subdoublestore.c @@ -7,12 +7,6 @@ // RUN: %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - -#include -#include -#include -#include -#include - #include "../test_utils.h" #define __builtin_autodiff __enzyme_autodiff diff --git a/enzyme/test/Integration/ReverseMode/sumtil.c b/enzyme/test/Integration/ReverseMode/sumtil.c index 0a2b0502c2bc..f2a34b228afe 100644 --- a/enzyme/test/Integration/ReverseMode/sumtil.c +++ b/enzyme/test/Integration/ReverseMode/sumtil.c @@ -7,10 +7,6 @@ // RUN: %clang -std=c11 -ffast-math -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -ffast-math -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - -#include -#include -#include - #include "../test_utils.h" extern void __enzyme_autodiff(void*, double*, double*, int); diff --git a/enzyme/test/Integration/ReverseMode/sumtil2.c b/enzyme/test/Integration/ReverseMode/sumtil2.c index aac316c7c4ea..85cea6c94095 100644 --- a/enzyme/test/Integration/ReverseMode/sumtil2.c +++ b/enzyme/test/Integration/ReverseMode/sumtil2.c @@ -7,10 +7,6 @@ // RUN: %clang -std=c11 -ffast-math -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -ffast-math -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - -#include -#include -#include - #include "../test_utils.h" extern void __enzyme_autodiff(void*, double*, double*, int); diff --git a/enzyme/test/Integration/ReverseMode/taylorlog.c b/enzyme/test/Integration/ReverseMode/taylorlog.c index 649dd4fff243..fdecb8aac7af 100644 --- a/enzyme/test/Integration/ReverseMode/taylorlog.c +++ b/enzyme/test/Integration/ReverseMode/taylorlog.c @@ -7,8 +7,6 @@ // RUN: %clang -std=c11 -ffast-math -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -ffast-math -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - -//#include - #include "../test_utils.h" #define __builtin_autodiff __enzyme_autodiff diff --git a/enzyme/test/Integration/blas_inline.h b/enzyme/test/Integration/blas_inline.h index 8ac0c16f51b3..b1b988190a5e 100644 --- a/enzyme/test/Integration/blas_inline.h +++ b/enzyme/test/Integration/blas_inline.h @@ -1,5 +1,8 @@ #include #include +#include +#include +#include typedef int32_t integer; typedef double doublereal; diff --git a/enzyme/test/Integration/test_utils.h b/enzyme/test/Integration/test_utils.h index afcf87d4471b..3226b61874ef 100644 --- a/enzyme/test/Integration/test_utils.h +++ b/enzyme/test/Integration/test_utils.h @@ -1,7 +1,22 @@ -#include -#include -#include -#include + + +#ifdef __cplusplus +extern "C" { +#endif +struct _IO_FILE; +extern struct _IO_FILE* stderr; +extern int fprintf(struct _IO_FILE *, const char*, ...); +extern int fflush(struct _IO_FILE *stream); +extern int printf(const char*, ...); +extern void abort(); +extern void free(void *); +extern void* malloc(unsigned long); +extern void *realloc( void *ptr, unsigned long new_size ); +extern void* memcpy( void* dest, const void* src, unsigned long count ); +extern void* memset( void* dest, int, unsigned long count ); +#ifdef __cplusplus +} +#endif extern #ifdef __cplusplus diff --git a/enzyme/test/MLIR/ForwardMode/inactive.mlir b/enzyme/test/MLIR/ForwardMode/inactive.mlir index 6ccb8d7033ee..10b7fcd61b27 100644 --- a/enzyme/test/MLIR/ForwardMode/inactive.mlir +++ b/enzyme/test/MLIR/ForwardMode/inactive.mlir @@ -1,11 +1,11 @@ -// RUN: %eopt --enzyme %s | FileCheck %s +// RUN: %eopt --enzyme %s -allow-unregistered-dialect | FileCheck %s module { func.func @inactive(%x : f64) -> f64 { - // We don't have an interface implementation for "func", + // We don't have an interface implementation for "foo", // but we can see it's inactive from its lack of operands // and results. - func.func private @foo() + "test.foo"() : () -> () return %x : f64 } func.func @diff(%x : f64, %dx : f64) -> f64 { @@ -17,4 +17,4 @@ module { // Just check that we didn't trigger the error on there not being an interface // implementation. // CHECK-LABEL: func private @fwddiffeinactive -// CHECK: func private @foo +// CHECK: "test.foo"() From dc5eaa56b9fbb64aad911ffed27e7a59e88a9b32 Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 6 Feb 2024 13:12:23 +0000 Subject: [PATCH 010/106] Logabsgamma support (#1669) * logabsgamma * Fix type analysis * fixed * Add logabsgammaf --- enzyme/Enzyme/InstructionDerivatives.td | 30 +++++++++ enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp | 62 ++++++++++++++----- enzyme/test/Enzyme/ForwardMode/logabsgamma.ll | 28 +++++++++ enzyme/test/Enzyme/ReverseMode/logabsgamma.ll | 19 ++++-- enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp | 42 +++++++++++++ 5 files changed, 159 insertions(+), 22 deletions(-) create mode 100644 enzyme/test/Enzyme/ForwardMode/logabsgamma.ll diff --git a/enzyme/Enzyme/InstructionDerivatives.td b/enzyme/Enzyme/InstructionDerivatives.td index 53fad8160993..00be8fe77e55 100644 --- a/enzyme/Enzyme/InstructionDerivatives.td +++ b/enzyme/Enzyme/InstructionDerivatives.td @@ -186,6 +186,12 @@ class PrependArgTypesFunc pretys_> { list pretys = pretys_; } +// Set return to arg[0] +// Same argument types +class ArgAsRetTypesFunc { + string name = name_; +} + // Specify that a given argument is inactive, aka not differentiable // By default this argument tells Enzyme that it must always be inactive // from the function semantics. @@ -461,6 +467,30 @@ def : CallPattern<(Op $x), [ReadNone, NoUnwind] >; +def : CallPattern<(Op $x), + ["logabsgamma"], + [ + ( + ArrayRet (FMul (Call<(ArgAsRetTypesFunc<"digamma">), [ReadNone,NoUnwind]> $x), (DiffeRet) ), + (InactiveArg) + ) + ], + (ForwardFromSummedReverse), + [ReadNone, NoUnwind] + >; + +def : CallPattern<(Op $x), + ["logabsgammaf"], + [ + ( + ArrayRet (FMul (Call<(ArgAsRetTypesFunc<"digammaf">), [ReadNone,NoUnwind]> $x), (DiffeRet) ), + (InactiveArg) + ) + ], + (ForwardFromSummedReverse), + [ReadNone, NoUnwind] + >; + def : CallPattern<(Op $x), ["sinpi", "sinpif", "sinpil", "cospi", "cospif", "cospil"], [ diff --git a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp index f2a3dc56cdf8..7a75daef2501 100644 --- a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp +++ b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp @@ -167,6 +167,7 @@ const llvm::StringMap LIBM_FUNCTIONS = { {"yn", Intrinsic::not_intrinsic}, {"tgamma", Intrinsic::not_intrinsic}, {"lgamma", Intrinsic::not_intrinsic}, + {"logabsgamma", Intrinsic::not_intrinsic}, {"ceil", Intrinsic::ceil}, {"__nv_ceil", Intrinsic::ceil}, {"floor", Intrinsic::floor}, @@ -5305,23 +5306,52 @@ void TypeAnalyzer::visitCallBase(CallBase &call) { } else if (T->isVoidTy()) { } else if (auto ST = dyn_cast(T)) { assert(ST->getNumElements() >= 1); - for (size_t i = 1; i < ST->getNumElements(); ++i) { - assert(ST->getTypeAtIndex((unsigned)0) == ST->getTypeAtIndex(i)); - } - if (ST->getTypeAtIndex((unsigned)0)->isFloatingPointTy()) - updateAnalysis( - &call, - TypeTree(ConcreteType( - ST->getTypeAtIndex((unsigned)0)->getScalarType())) - .Only(-1, &call), - &call); - else if (ST->getTypeAtIndex((unsigned)0)->isIntegerTy()) { - updateAnalysis(&call, TypeTree(BaseType::Integer).Only(-1, &call), - &call); - } else { - llvm::errs() << *T << " - " << call << "\n"; - llvm_unreachable("Unknown type for libm"); + TypeTree TT; + auto &DL = call.getParent()->getParent()->getParent()->getDataLayout(); + for (size_t i = 0; i < ST->getNumElements(); ++i) { + auto T = ST->getTypeAtIndex(i); + ConcreteType CT(BaseType::Unknown); + + Value *vec[2] = { + ConstantInt::get(Type::getInt64Ty(call.getContext()), 0), + ConstantInt::get(Type::getInt32Ty(call.getContext()), i)}; + auto ud = UndefValue::get(PointerType::getUnqual(ST)); + auto g2 = GetElementPtrInst::Create(ST, ud, vec); + APInt ai(DL.getIndexSizeInBits(0), 0); + g2->accumulateConstantOffset(DL, ai); + delete g2; + size_t Offset = ai.getZExtValue(); + + size_t nextOffset; + if (i + 1 == ST->getNumElements()) + nextOffset = (DL.getTypeSizeInBits(ST) + 7) / 8; + else { + Value *vec[2] = { + ConstantInt::get(Type::getInt64Ty(call.getContext()), 0), + ConstantInt::get(Type::getInt32Ty(call.getContext()), i + 1)}; + auto ud = UndefValue::get(PointerType::getUnqual(ST)); + auto g2 = GetElementPtrInst::Create(ST, ud, vec); + APInt ai(DL.getIndexSizeInBits(0), 0); + g2->accumulateConstantOffset(DL, ai); + delete g2; + nextOffset = ai.getZExtValue(); + } + + if (T->isFloatingPointTy()) { + CT = T; + } else if (T->isIntegerTy()) { + CT = BaseType::Integer; + } + if (CT != BaseType::Unknown) { + TypeTree mid = TypeTree(CT).Only(-1, &call); + TT |= mid.ShiftIndices(DL, /*init offset*/ 0, + /*maxSize*/ nextOffset - Offset, + /*addOffset*/ Offset); + } } + auto Size = (DL.getTypeSizeInBits(ST) + 7) / 8; + TT.CanonicalizeInPlace(Size, DL); + updateAnalysis(&call, TT, &call); } else if (auto AT = dyn_cast(T)) { assert(AT->getNumElements() >= 1); if (AT->getElementType()->isFloatingPointTy()) diff --git a/enzyme/test/Enzyme/ForwardMode/logabsgamma.ll b/enzyme/test/Enzyme/ForwardMode/logabsgamma.ll new file mode 100644 index 000000000000..280d20db590e --- /dev/null +++ b/enzyme/test/Enzyme/ForwardMode/logabsgamma.ll @@ -0,0 +1,28 @@ +; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme-preopt=false -enzyme -S | FileCheck %s; fi +; RUN: %opt < %s %newLoadEnzyme -enzyme-preopt=false -passes="enzyme" -S | FileCheck %s + +; Function Attrs: nounwind readnone uwtable +define { double, i64 } @tester(double %x) { +entry: + %a = call { double, i64 } @logabsgamma(double %x) + ret { double, i64 } %a +} + +define { double, i64 } @test_derivative(double %x, double %dx) { +entry: + %0 = tail call { double, i64 } (...) @__enzyme_fwddiff({ double, i64 } (double)* nonnull @tester, double %x, double %dx) + ret { double, i64 } %0 +} + +declare { double, i64 } @logabsgamma(double) + +; Function Attrs: nounwind +declare { double, i64 } @__enzyme_fwddiff(...) + +; CHECK: define internal { double, i64 } @fwddiffetester(double %x, double %"x'") +; CHECK-NEXT: entry: +; CHECK-NEXT: %0 = call fast double @digamma(double %x) +; CHECK-NEXT: %1 = fmul fast double %0, %"x'" +; CHECK-NEXT: %2 = insertvalue { double, i64 } undef, double %1, 0 +; CHECK-NEXT: ret { double, i64 } %2 +; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ReverseMode/logabsgamma.ll b/enzyme/test/Enzyme/ReverseMode/logabsgamma.ll index 02e0f40cd5bc..7682abf13dd9 100644 --- a/enzyme/test/Enzyme/ReverseMode/logabsgamma.ll +++ b/enzyme/test/Enzyme/ReverseMode/logabsgamma.ll @@ -1,8 +1,6 @@ ; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme-preopt=false -enzyme -mem2reg -instsimplify -simplifycfg -S | FileCheck %s; fi ; RUN: %opt < %s %newLoadEnzyme -enzyme-preopt=false -passes="enzyme,function(mem2reg,instsimplify,%simplifycfg)" -S | FileCheck %s -; XFAIL: * - ; Function Attrs: nounwind readnone uwtable define double @tester(double %x) { entry: @@ -24,8 +22,17 @@ declare double @__enzyme_autodiff(double (double)*, ...) ; CHECK: define internal { double } @diffetester(double %x, double %differeturn) ; CHECK-NEXT: entry: -; CHECK-NEXT: %0 = call fast double @cosh(double %x) -; CHECK-NEXT: %1 = fmul fast double %differeturn, %0 -; CHECK-NEXT: %2 = insertvalue { double } undef, double %1, 0 -; CHECK-NEXT: ret { double } %2 +; CHECK-NEXT: %"a'de" = alloca { double, i64 }, align 8 +; CHECK-NEXT: store { double, i64 } zeroinitializer, { double, i64 }* %"a'de", align 8 +; CHECK-NEXT: %0 = getelementptr inbounds { double, i64 }, { double, i64 }* %"a'de", i32 0, i32 0 +; CHECK-NEXT: %1 = load double, double* %0, align 8 +; CHECK-NEXT: %2 = fadd fast double %1, %differeturn +; CHECK-NEXT: store double %2, double* %0, align 8 +; CHECK-NEXT: %3 = load { double, i64 }, { double, i64 }* %"a'de", align 8 +; CHECK-NEXT: store { double, i64 } zeroinitializer, { double, i64 }* %"a'de", align 8 +; CHECK-NEXT: %4 = call fast double @digamma(double %x) +; CHECK-NEXT: %5 = extractvalue { double, i64 } %3, 0 +; CHECK-NEXT: %6 = fmul fast double %4, %5 +; CHECK-NEXT: %7 = insertvalue { double } undef, double %6, 0 +; CHECK-NEXT: ret { double } %7 ; CHECK-NEXT: } diff --git a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp index c30f8870063e..bf6a00d07cd2 100644 --- a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp +++ b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp @@ -119,6 +119,21 @@ void getFunction(const Twine &curIndent, raw_ostream &os, StringRef callval, << ")->getCallingConv();\n"; return; } + if (opName == "ArgAsRetTypesFunc" || + Def->isSubClassOf("ArgAsRetTypesFunc")) { + os << curIndent << "auto " << FT << "_old = cast(&" << origName + << ")->getFunctionType();\n"; + os << curIndent << "auto " << FT << " = FunctionType::get(" << FT + << "_old->params()[0], " << FT << "_old->params(), " << FT + << "_old->isVarArg());\n"; + os << curIndent << "auto " << callval + << " = gutils->oldFunc->getParent()->getOrInsertFunction("; + os << Def->getValueInit("name")->getAsString(); + os << ", " << FT << ", called->getAttributes()).getCallee();\n"; + os << curIndent << "auto " << cconv << " = cast(&" << origName + << ")->getCallingConv();\n"; + return; + } } assert(0 && "Unhandled function"); } @@ -954,6 +969,13 @@ void handleUse( foundDiffRet = true; return; } + if (opName == "InactiveArgSpec" || Def->isSubClassOf("InactiveArgSpec")) { + return; + } + if (!Def->isSubClassOf("Operation")) { + errs() << *resultTree << "\n"; + errs() << opName << " " << *Def << "\n"; + } assert(Def->isSubClassOf("Operation")); bool usesPrimal = Def->getValueAsBit("usesPrimal"); bool usesShadow = Def->getValueAsBit("usesShadow"); @@ -1527,6 +1549,9 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os, } return; } + if (Def->isSubClassOf("InactiveArgSpec")) { + return; + } os << curIndent << INDENT << "{\n"; if (intrinsic == MLIRDerivatives) os << curIndent << INDENT << INDENT << "mlir::Value itmp = "; @@ -1764,6 +1789,9 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os, } return; } + if (Def->isSubClassOf("InactiveArgSpec")) { + return; + } const char *curIndent = " "; os << curIndent << "{\n"; if (intrinsic == MLIRDerivatives) @@ -1859,10 +1887,24 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os, } if (intrinsic != MLIRDerivatives) { + os << " auto found = gutils->invertedPointers.find(&(" << origName + << "));\n"; + os << " if (found != gutils->invertedPointers.end()) {\n"; + os << " PHINode* PN = cast(&*found->second);\n"; + os << " gutils->invertedPointers.erase(found);\n"; + os << " gutils->erase(PN);\n"; + os << " }\n"; os << " break;\n"; os << " }\n"; os << " case DerivativeMode::ReverseModePrimal:{\n"; + os << " auto found = gutils->invertedPointers.find(&(" << origName + << "));\n"; + os << " if (found != gutils->invertedPointers.end()) {\n"; + os << " PHINode* PN = cast(&*found->second);\n"; + os << " gutils->invertedPointers.erase(found);\n"; + os << " gutils->erase(PN);\n"; + os << " }\n"; // TODO os << " break;\n"; os << " }\n"; From 3c60e4317595d32b87353019b48f62a32657cd1e Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 6 Feb 2024 16:53:58 +0000 Subject: [PATCH 011/106] Fix symbol link collision (#1671) --- enzyme/Enzyme/ActivityAnalysis.cpp | 14 +++++++------- enzyme/Enzyme/MLIR/Analysis/ActivityAnalysis.cpp | 14 +++++++------- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/enzyme/Enzyme/ActivityAnalysis.cpp b/enzyme/Enzyme/ActivityAnalysis.cpp index fa4dd06bb68c..7b251e154f94 100644 --- a/enzyme/Enzyme/ActivityAnalysis.cpp +++ b/enzyme/Enzyme/ActivityAnalysis.cpp @@ -104,7 +104,7 @@ cl::opt EnzymeEnableRecursiveHypotheses( #include // clang-format off -const char *KnownInactiveFunctionsStartingWith[] = { +static const char *KnownInactiveFunctionsStartingWith[] = { "f90io", "$ss5print", "_ZTv0_n24_NSoD", //"1Ev, 0Ev @@ -113,11 +113,11 @@ const char *KnownInactiveFunctionsStartingWith[] = { "_ZNSaIcEC1Ev", }; -const char *KnownInactiveFunctionsContains[] = { +static const char *KnownInactiveFunctionsContains[] = { "__enzyme_float", "__enzyme_double", "__enzyme_integer", "__enzyme_pointer"}; -const StringSet<> InactiveGlobals = { +static const StringSet<> InactiveGlobals = { "small_typeof", "ompi_request_null", "ompi_mpi_double", @@ -171,7 +171,7 @@ const llvm::StringMap MPIInactiveCommAllocators = { // Instructions which themselves are inactive // the returned value, however, may still be active -const StringSet<> KnownInactiveFunctionInsts = { +static const StringSet<> KnownInactiveFunctionInsts = { "__dynamic_cast", "_ZSt18_Rb_tree_decrementPKSt18_Rb_tree_node_base", "_ZSt18_Rb_tree_incrementPKSt18_Rb_tree_node_base", @@ -180,7 +180,7 @@ const StringSet<> KnownInactiveFunctionInsts = { "jl_ptr_to_array", "jl_ptr_to_array_1d"}; -const StringSet<> KnownInactiveFunctions = { +static const StringSet<> KnownInactiveFunctions = { "mpfr_greater_p", "__nv_isnand", "__nv_isnanf", @@ -293,7 +293,7 @@ const StringSet<> KnownInactiveFunctions = { "floorl" }; -const std::set KnownInactiveIntrinsics = { +static const std::set KnownInactiveIntrinsics = { #if LLVM_VERSION_MAJOR >= 12 Intrinsic::experimental_noalias_scope_decl, #endif @@ -342,7 +342,7 @@ const std::set KnownInactiveIntrinsics = { Intrinsic::is_constant, Intrinsic::memset}; -const char *DemangledKnownInactiveFunctionsStartingWith[] = { +static const char *DemangledKnownInactiveFunctionsStartingWith[] = { // TODO this returns allocated memory and thus can be an active value // "std::allocator", "std::chrono::_V2::steady_clock::now", diff --git a/enzyme/Enzyme/MLIR/Analysis/ActivityAnalysis.cpp b/enzyme/Enzyme/MLIR/Analysis/ActivityAnalysis.cpp index 5f3cd22db72d..349213ac2fbf 100644 --- a/enzyme/Enzyme/MLIR/Analysis/ActivityAnalysis.cpp +++ b/enzyme/Enzyme/MLIR/Analysis/ActivityAnalysis.cpp @@ -20,7 +20,7 @@ #include "Interfaces/AutoDiffOpInterface.h" -const char *KnownInactiveFunctionsStartingWith[] = { +static const char *KnownInactiveFunctionsStartingWith[] = { "f90io", "$ss5print", "_ZTv0_n24_NSoD", //"1Ev, 0Ev @@ -29,11 +29,11 @@ const char *KnownInactiveFunctionsStartingWith[] = { "_ZNSaIcEC1Ev", }; -const char *KnownInactiveFunctionsContains[] = { +static const char *KnownInactiveFunctionsContains[] = { "__enzyme_float", "__enzyme_double", "__enzyme_integer", "__enzyme_pointer"}; -const std::set InactiveGlobals = { +static const std::set InactiveGlobals = { "ompi_request_null", "ompi_mpi_double", "ompi_mpi_comm_world", "stderr", "stdout", "stdin", "_ZSt3cin", "_ZSt4cout", "_ZSt5wcout", "_ZSt4cerr", "_ZTVNSt7__cxx1115basic_stringbufIcSt11char_traitsIcESaIcEEE", @@ -57,7 +57,7 @@ const std::set InactiveGlobals = { "_ZTVN10__cxxabiv117__class_type_infoE", "_ZTVN10__cxxabiv121__vmi_class_type_infoE"}; -const std::map MPIInactiveCommAllocators = { +static const std::map MPIInactiveCommAllocators = { {"MPI_Graph_create", 5}, {"MPI_Comm_split", 2}, {"MPI_Intercomm_create", 6}, @@ -75,7 +75,7 @@ const std::map MPIInactiveCommAllocators = { // Instructions which themselves are inactive // the returned value, however, may still be active -const std::set KnownInactiveFunctionInsts = { +static const std::set KnownInactiveFunctionInsts = { "__dynamic_cast", "_ZSt18_Rb_tree_decrementPKSt18_Rb_tree_node_base", "_ZSt18_Rb_tree_incrementPKSt18_Rb_tree_node_base", @@ -84,7 +84,7 @@ const std::set KnownInactiveFunctionInsts = { "jl_ptr_to_array", "jl_ptr_to_array_1d"}; -const std::set KnownInactiveFunctions = { +static const std::set KnownInactiveFunctions = { "abort", "time", "memcmp", @@ -165,7 +165,7 @@ const std::set KnownInactiveFunctions = { "logbl", }; -const char *DemangledKnownInactiveFunctionsStartingWith[] = { +static const char *DemangledKnownInactiveFunctionsStartingWith[] = { // TODO this returns allocated memory and thus can be an active value // "std::allocator", "std::string", From b8c1936802aa8ec75115f37b60f50b4482fa5748 Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 6 Feb 2024 17:51:02 +0000 Subject: [PATCH 012/106] Fix integration tests (#1670) --- enzyme/test/Integration/ReverseMode/cmplx.cpp | 4 ---- enzyme/test/Integration/ReverseMode/eigentensor.cpp | 1 + enzyme/test/Integration/test_utils.h | 8 ++++---- 3 files changed, 5 insertions(+), 8 deletions(-) diff --git a/enzyme/test/Integration/ReverseMode/cmplx.cpp b/enzyme/test/Integration/ReverseMode/cmplx.cpp index 73559beb94fb..bec8aa0d831b 100644 --- a/enzyme/test/Integration/ReverseMode/cmplx.cpp +++ b/enzyme/test/Integration/ReverseMode/cmplx.cpp @@ -11,10 +11,6 @@ #include "../test_utils.h" -#include -#include - -#include #include // std::complex, std::abs, std::arg void __enzyme_autodiff(...); diff --git a/enzyme/test/Integration/ReverseMode/eigentensor.cpp b/enzyme/test/Integration/ReverseMode/eigentensor.cpp index 28feb236f4db..6e2da6d65818 100644 --- a/enzyme/test/Integration/ReverseMode/eigentensor.cpp +++ b/enzyme/test/Integration/ReverseMode/eigentensor.cpp @@ -16,6 +16,7 @@ #include "../test_utils.h" +#include void memcpy(float* __restrict dst, float* __restrict src, size_t count) { for(size_t i=0; i +#include +#include +#else struct _IO_FILE; extern struct _IO_FILE* stderr; extern int fprintf(struct _IO_FILE *, const char*, ...); @@ -14,8 +16,6 @@ extern void* malloc(unsigned long); extern void *realloc( void *ptr, unsigned long new_size ); extern void* memcpy( void* dest, const void* src, unsigned long count ); extern void* memset( void* dest, int, unsigned long count ); -#ifdef __cplusplus -} #endif extern From 66a3bb6b394b8824e94913898368d0bb34550d75 Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 6 Feb 2024 18:12:07 +0000 Subject: [PATCH 013/106] Packaging fixes for julia on 16 (#1672) --- .packaging/build_tarballs.jl | 4 +- enzyme/BCLoad/CMakeLists.txt | 6 +-- enzyme/Enzyme/AdjointGenerator.h | 47 ++++----------------- enzyme/Enzyme/CallDerivatives.cpp | 6 +-- enzyme/Enzyme/DiffeGradientUtils.cpp | 25 +++-------- enzyme/Enzyme/Enzyme.cpp | 22 ++-------- enzyme/Enzyme/EnzymeLogic.cpp | 7 +-- enzyme/Enzyme/GradientUtils.cpp | 13 ++---- enzyme/Enzyme/LibraryFuncs.h | 6 +-- enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp | 11 +---- 10 files changed, 30 insertions(+), 117 deletions(-) diff --git a/.packaging/build_tarballs.jl b/.packaging/build_tarballs.jl index 67d214bf2c86..009b48e756e4 100644 --- a/.packaging/build_tarballs.jl +++ b/.packaging/build_tarballs.jl @@ -28,7 +28,7 @@ platforms = expand_cxxstring_abis(supported_platforms(; experimental=true)) script = raw""" cd Enzyme -if [[ "${bb_full_target}" == x86_64-apple-darwin*llvm_version+15.asserts* ]]; then +if [[ "${bb_full_target}" == x86_64-apple-darwin*llvm_version+15.asserts* ]] || [[ "${bb_full_target}" == x86_64-apple-darwin*llvm_version+16.asserts* ]] || [[ "${bb_full_target}" == x86_64-apple-darwin*llvm_version+17.asserts* ]]; then # LLVM 15 requires macOS SDK 10.14. pushd $WORKSPACE/srcdir/MacOSX10.*.sdk rm -rf /opt/${target}/${target}/sys-root/System @@ -117,7 +117,7 @@ for llvm_version in llvm_versions, llvm_assertions in (false, true) for platform in platforms augmented_platform = deepcopy(platform) augmented_platform[LLVM.platform_name] = LLVM.platform(llvm_version, llvm_assertions) - gcc_version = version > v"15" ? v"10" : v"8" + gcc_version = llvm_version > v"15" ? v"10" : v"8" should_build_platform(triplet(augmented_platform)) || continue push!(builds, (; dependencies, products, diff --git a/enzyme/BCLoad/CMakeLists.txt b/enzyme/BCLoad/CMakeLists.txt index f2b04238fb67..6516a46dfbd5 100644 --- a/enzyme/BCLoad/CMakeLists.txt +++ b/enzyme/BCLoad/CMakeLists.txt @@ -4,10 +4,10 @@ set(CMAKE_CXX_STANDARD_REQUIRED ON) set(BC_LOAD_FLAGS "" CACHE STRING "") set(BC_LOAD_HEADER "" CACHE STRING "") -if (${LLVM_VERSION_MAJOR} LESS 15) - set(BC_LOAD_FLAGS2 "${BC_LOAD_FLAGS}") -else() +if (${LLVM_VERSION_MAJOR} EQUAL 16) set(BC_LOAD_FLAGS2 "${BC_LOAD_FLAGS} -Xclang -no-opaque-pointers") +else() + set(BC_LOAD_FLAGS2 "${BC_LOAD_FLAGS}") endif() if (APPLE) diff --git a/enzyme/Enzyme/AdjointGenerator.h b/enzyme/Enzyme/AdjointGenerator.h index 80d6c851e2a2..debac4dd496f 100644 --- a/enzyme/Enzyme/AdjointGenerator.h +++ b/enzyme/Enzyme/AdjointGenerator.h @@ -4859,18 +4859,9 @@ class AdjointGenerator } } Value *tape = nullptr; -#if LLVM_VERSION_MAJOR >= 16 - if (tapeIdx.has_value()) -#else - if (tapeIdx.hasValue()) -#endif - { + if (tapeIdx) { -#if LLVM_VERSION_MAJOR >= 16 - auto idx = tapeIdx.value(); -#else - auto idx = tapeIdx.getValue(); -#endif + auto idx = *tapeIdx; FunctionType *FT = subdata->fn->getFunctionType(); tape = BuilderZ.CreatePHI( @@ -5416,17 +5407,8 @@ class AdjointGenerator if (!augmentcall->getType()->isVoidTy()) augmentcall->setName(call.getName() + "_augmented"); -#if LLVM_VERSION_MAJOR >= 16 - if (tapeIdx.has_value()) -#else - if (tapeIdx.hasValue()) -#endif - { -#if LLVM_VERSION_MAJOR >= 16 - auto tval = tapeIdx.value(); -#else - auto tval = tapeIdx.getValue(); -#endif + if (tapeIdx) { + auto tval = *tapeIdx; tape = (tval == -1) ? augmentcall : BuilderZ.CreateExtractValue( augmentcall, {(unsigned)tval}, "subcache"); @@ -5445,11 +5427,7 @@ class AdjointGenerator Value *dcall = nullptr; assert(returnIdx); assert(augmentcall); -#if LLVM_VERSION_MAJOR >= 16 - auto rval = returnIdx.value(); -#else - auto rval = returnIdx.getValue(); -#endif + auto rval = *returnIdx; dcall = (rval < 0) ? augmentcall : BuilderZ.CreateExtractValue(augmentcall, {(unsigned)rval}); @@ -5520,13 +5498,8 @@ class AdjointGenerator // assert(!tape); // assert(subdata); if (!tape) { -#if LLVM_VERSION_MAJOR >= 16 - assert(tapeIdx.has_value()); - auto tval = tapeIdx.value(); -#else - assert(tapeIdx.hasValue()); - auto tval = tapeIdx.getValue(); -#endif + assert(tapeIdx); + auto tval = *tapeIdx; tape = BuilderZ.CreatePHI( (tapeIdx == -1) ? FT->getReturnType() : cast(FT->getReturnType()) @@ -5585,11 +5558,7 @@ class AdjointGenerator if (Mode == DerivativeMode::ReverseModeCombined || Mode == DerivativeMode::ReverseModePrimal) { -#if LLVM_VERSION_MAJOR >= 16 - auto drval = differetIdx.value(); -#else - auto drval = differetIdx.getValue(); -#endif + auto drval = *differetIdx; newip = (drval < 0) ? augmentcall : BuilderZ.CreateExtractValue(augmentcall, diff --git a/enzyme/Enzyme/CallDerivatives.cpp b/enzyme/Enzyme/CallDerivatives.cpp index efd802db406b..02719a9a4660 100644 --- a/enzyme/Enzyme/CallDerivatives.cpp +++ b/enzyme/Enzyme/CallDerivatives.cpp @@ -2254,11 +2254,7 @@ bool AdjointGenerator::handleKnownCallDerivatives( } if (auto blas = extractBLAS(funcName)) { -#if LLVM_VERSION_MAJOR >= 16 - if (handleBLAS(call, called, blas.value(), overwritten_args)) -#else - if (handleBLAS(call, called, blas.getValue(), overwritten_args)) -#endif + if (handleBLAS(call, called, *blas, overwritten_args)) return true; } diff --git a/enzyme/Enzyme/DiffeGradientUtils.cpp b/enzyme/Enzyme/DiffeGradientUtils.cpp index 68a7e02d048d..f3ca78dacea1 100644 --- a/enzyme/Enzyme/DiffeGradientUtils.cpp +++ b/enzyme/Enzyme/DiffeGradientUtils.cpp @@ -966,14 +966,8 @@ void DiffeGradientUtils::addToInvertedPtrDiffe(Instruction *orig, if (alignv) { if (start != 0) { // todo make better alignment calculation -#if LLVM_VERSION_MAJOR >= 16 - assert(alignv.value().value() != 0); - if (start % alignv.value().value() != 0) -#else - assert(alignv.getValue().value() != 0); - if (start % alignv.getValue().value() != 0) -#endif - { + assert((*alignv).value() != 0); + if (start % (*alignv).value() != 0) { alignv = Align(1); } } @@ -1007,13 +1001,8 @@ void DiffeGradientUtils::addToInvertedPtrDiffe(Instruction *orig, if (alignv) { if (start != 0) { // todo make better alignment calculation -#if LLVM_VERSION_MAJOR >= 16 - assert(alignv.value().value() != 0); - if (start % alignv.value().value() != 0) { -#else - assert(alignv.getValue().value() != 0); - if (start % alignv.getValue().value() != 0) { -#endif + assert((*alignv).value() != 0); + if (start % (*alignv).value() != 0) { alignv = Align(1); } } @@ -1093,11 +1082,7 @@ void DiffeGradientUtils::addToInvertedPtrDiffe(Instruction *orig, st->setDebugLoc(getNewFromOriginal(orig->getDebugLoc())); if (align) { -#if LLVM_VERSION_MAJOR >= 16 - auto alignv = align ? align.value().value() : 0; -#else - auto alignv = align ? align.getValue().value() : 0; -#endif + auto alignv = align ? (*align).value() : 0; if (alignv != 0) { if (start != 0) { // todo make better alignment calculation diff --git a/enzyme/Enzyme/Enzyme.cpp b/enzyme/Enzyme/Enzyme.cpp index e402c1656248..c9b73e598222 100644 --- a/enzyme/Enzyme/Enzyme.cpp +++ b/enzyme/Enzyme/Enzyme.cpp @@ -1887,15 +1887,9 @@ class EnzymeBase { #endif } -#if LLVM_VERSION_MAJOR >= 16 return HandleAutoDiff(CI, CI->getCallingConv(), ret, retElemType, args, - byVal, constants, fn, mode, options.value(), sizeOnly, + byVal, constants, fn, mode, *options, sizeOnly, calls); -#else - return HandleAutoDiff(CI, CI->getCallingConv(), ret, retElemType, args, - byVal, constants, fn, mode, options.getValue(), - sizeOnly, calls); -#endif } bool HandleProbProg(CallInst *CI, ProbProgMode mode, @@ -2025,17 +2019,9 @@ class EnzymeBase { #endif } -#if LLVM_VERSION_MAJOR >= 16 - bool status = - HandleAutoDiff(CI, CI->getCallingConv(), ret, retElemType, dargs, byVal, - constants, newFunc, DerivativeMode::ReverseModeCombined, - opt.value(), false, calls); -#else - bool status = - HandleAutoDiff(CI, CI->getCallingConv(), ret, retElemType, dargs, byVal, - constants, newFunc, DerivativeMode::ReverseModeCombined, - opt.getValue(), false, calls); -#endif + bool status = HandleAutoDiff( + CI, CI->getCallingConv(), ret, retElemType, dargs, byVal, constants, + newFunc, DerivativeMode::ReverseModeCombined, *opt, false, calls); delete interface; diff --git a/enzyme/Enzyme/EnzymeLogic.cpp b/enzyme/Enzyme/EnzymeLogic.cpp index 9ad444e4bd3a..580b59f3c557 100644 --- a/enzyme/Enzyme/EnzymeLogic.cpp +++ b/enzyme/Enzyme/EnzymeLogic.cpp @@ -4282,13 +4282,8 @@ Function *EnzymeLogic::CreatePrimalAndGradient( } auto store = entryBuilder.CreateStore( Constant::getNullValue(g.getValueType()), &g); -#if LLVM_VERSION_MAJOR >= 16 - if (g.getAlign()) - store->setAlignment(g.getAlign().value()); -#else if (g.getAlign()) - store->setAlignment(g.getAlign().getValue()); -#endif + store->setAlignment(*g.getAlign()); } } if (sharedBlock) { diff --git a/enzyme/Enzyme/GradientUtils.cpp b/enzyme/Enzyme/GradientUtils.cpp index 6adfc1bd6b78..0b6947887f6f 100644 --- a/enzyme/Enzyme/GradientUtils.cpp +++ b/enzyme/Enzyme/GradientUtils.cpp @@ -534,17 +534,10 @@ DebugLoc GradientUtils::getNewFromOriginal(const DebugLoc L) const { return L; assert(originalToNewFn.hasMD()); auto opt = originalToNewFn.getMappedMD(L.getAsMDNode()); -#if LLVM_VERSION_MAJOR >= 16 - if (!opt.has_value()) - return L; - assert(opt.has_value()); - return DebugLoc(cast(opt.value())); -#else - if (!opt.hasValue()) + if (!opt) return L; - assert(opt.hasValue()); - return DebugLoc(cast(*opt.getPointer())); -#endif + assert(opt); + return DebugLoc(cast(*opt)); } Value *GradientUtils::getNewFromOriginal(const Value *originst) const { diff --git a/enzyme/Enzyme/LibraryFuncs.h b/enzyme/Enzyme/LibraryFuncs.h index 5fd3992cfad5..1a8ebce38eee 100644 --- a/enzyme/Enzyme/LibraryFuncs.h +++ b/enzyme/Enzyme/LibraryFuncs.h @@ -209,11 +209,7 @@ static inline void zeroKnownAllocation(llvm::IRBuilder<> &bb, } if (funcName == "enzyme_allocator") { auto index = getAllocationIndexFromCall(orig); -#if LLVM_VERSION_MAJOR >= 16 - allocSize = argValues[index.value()]; -#else - allocSize = argValues[index.getValue()]; -#endif + allocSize = argValues[*index]; } Value *dst_arg = toZero; diff --git a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp index 7a75daef2501..a5db987b5a74 100644 --- a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp +++ b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp @@ -4304,17 +4304,10 @@ void TypeAnalyzer::visitCallBase(CallBase &call) { StringRef funcName = getFuncNameFromCall(&call); auto blasMetaData = extractBLAS(funcName); -#if LLVM_VERSION_MAJOR >= 16 - if (blasMetaData.has_value()) { - BlasInfo blas = blasMetaData.value(); + if (blasMetaData) { + BlasInfo blas = *blasMetaData; #include "BlasTA.inc" } -#else - if (blasMetaData.hasValue()) { - BlasInfo blas = blasMetaData.getValue(); -#include "BlasTA.inc" - } -#endif // When compiling Enzyme against standard LLVM, and not Intel's // modified version of LLVM, the intrinsic `llvm.intel.subscript` is From ed28cb68ccf47b5ff2594421ad62f878be562b03 Mon Sep 17 00:00:00 2001 From: William Moses Date: Wed, 7 Feb 2024 17:32:21 +0000 Subject: [PATCH 014/106] [MLIR] make wrapping pass (#1673) --- enzyme/Enzyme/MLIR/Passes/CMakeLists.txt | 1 + enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp | 107 +++++++++++++++++++ enzyme/Enzyme/MLIR/Passes/Passes.h | 6 ++ enzyme/Enzyme/MLIR/Passes/Passes.td | 58 ++++++++++ enzyme/test/MLIR/ForwardMode/wrap.mlir | 16 +++ 5 files changed, 188 insertions(+) create mode 100644 enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp create mode 100644 enzyme/test/MLIR/ForwardMode/wrap.mlir diff --git a/enzyme/Enzyme/MLIR/Passes/CMakeLists.txt b/enzyme/Enzyme/MLIR/Passes/CMakeLists.txt index 03c91683adff..8a18dc33d195 100644 --- a/enzyme/Enzyme/MLIR/Passes/CMakeLists.txt +++ b/enzyme/Enzyme/MLIR/Passes/CMakeLists.txt @@ -5,6 +5,7 @@ add_mlir_doc(Passes EnzymePasses ./ -gen-pass-doc) add_mlir_dialect_library(MLIREnzymeTransforms EnzymeMLIRPass.cpp + EnzymeWrapPass.cpp PrintActivityAnalysis.cpp PrintAliasAnalysis.cpp EnzymeToMemRef.cpp diff --git a/enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp b/enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp new file mode 100644 index 000000000000..235a061c6088 --- /dev/null +++ b/enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp @@ -0,0 +1,107 @@ +//===- EnzymeWrapPass.cpp - Replace calls with their derivatives ------------ // +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a pass to create wrapper functions which differentiate +// ops. +//===----------------------------------------------------------------------===// + +#include "Dialect/Ops.h" +#include "Interfaces/GradientUtils.h" +#include "Interfaces/GradientUtilsReverse.h" +#include "PassDetails.h" +#include "Passes/Passes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/Interfaces/ControlFlowInterfaces.h" +#include "mlir/Interfaces/FunctionInterfaces.h" + +#include "mlir/Dialect/Func/IR/FuncOps.h" + +#define DEBUG_TYPE "enzyme" + +using namespace mlir; +using namespace mlir::enzyme; +using namespace enzyme; + +namespace { +struct DifferentiateWrapperPass + : public DifferentiateWrapperPassBase { + + void runOnOperation() override { + MEnzymeLogic Logic; + SymbolTableCollection symbolTable; + symbolTable.getSymbolTable(getOperation()); + + Operation *symbolOp = nullptr; + if (infn != "") + symbolOp = symbolTable.lookupSymbolIn( + getOperation(), StringAttr::get(getOperation()->getContext(), infn)); + else { + for (auto &op : getOperation()->getRegion(0).front()) { + auto fn = dyn_cast(symbolOp); + if (!fn) + continue; + assert(symbolOp == nullptr); + symbolOp = &op; + } + } + auto fn = cast(symbolOp); + SmallVector split; + StringRef(argTys.getValue().data(), argTys.getValue().size()) + .split(split, ','); + std::vector constants; + for (auto &str : split) { + if (str == "enzyme_dup") + constants.push_back(DIFFE_TYPE::DUP_ARG); + else if (str == "enzyme_const") + constants.push_back(DIFFE_TYPE::CONSTANT); + else if (str == "enzyme_dupnoneed") + constants.push_back(DIFFE_TYPE::DUP_NONEED); + else if (str == "enzyme_out") + constants.push_back(DIFFE_TYPE::OUT_DIFF); + else + assert(0 && " unknown constant"); + } + + DIFFE_TYPE retType = retTy.getValue(); + MTypeAnalysis TA; + auto type_args = TA.getAnalyzedTypeInfo(fn); + + bool freeMemory = true; + size_t width = 1; + + std::vector volatile_args; + for (auto &a : fn.getFunctionBody().getArguments()) { + (void)a; + volatile_args.push_back(!(mode == DerivativeMode::ReverseModeCombined)); + } + + FunctionOpInterface newFunc = Logic.CreateForwardDiff( + fn, retType, constants, TA, + /*should return*/ false, mode, freeMemory, width, + /*addedType*/ nullptr, type_args, volatile_args, + /*augmented*/ nullptr); + if (outfn == "") { + fn->erase(); + } else { + SymbolTable::setSymbolName(cast(newFunc), + (std::string)outfn); + } + } +}; + +} // end anonymous namespace + +namespace mlir { +namespace enzyme { +std::unique_ptr createDifferentiateWrapperPass() { + new DifferentiateWrapperPass(); + return std::make_unique(); +} +} // namespace enzyme +} // namespace mlir diff --git a/enzyme/Enzyme/MLIR/Passes/Passes.h b/enzyme/Enzyme/MLIR/Passes/Passes.h index d5e821bac643..7e64f7718988 100644 --- a/enzyme/Enzyme/MLIR/Passes/Passes.h +++ b/enzyme/Enzyme/MLIR/Passes/Passes.h @@ -8,9 +8,13 @@ #ifndef ENZYME_PASSES_H #define ENZYME_PASSES_H +#include "../../Utils.h" #include "mlir/Conversion/LLVMCommon/LoweringOptions.h" #include "mlir/Pass/Pass.h" #include + +#include "Enzyme/MLIR/Dialect/Dialect.h" + namespace mlir { class PatternRewriter; class RewritePatternSet; @@ -18,6 +22,8 @@ class DominanceInfo; namespace enzyme { std::unique_ptr createDifferentiatePass(); +std::unique_ptr createDifferentiateWrapperPass(); + std::unique_ptr createPrintActivityAnalysisPass(); std::unique_ptr createPrintAliasAnalysisPass(); diff --git a/enzyme/Enzyme/MLIR/Passes/Passes.td b/enzyme/Enzyme/MLIR/Passes/Passes.td index 1289ddca9bf4..432b0938f763 100644 --- a/enzyme/Enzyme/MLIR/Passes/Passes.td +++ b/enzyme/Enzyme/MLIR/Passes/Passes.td @@ -19,6 +19,64 @@ def DifferentiatePass : Pass<"enzyme"> { let constructor = "mlir::enzyme::createDifferentiatePass()"; } +def DifferentiateWrapperPass : Pass<"enzyme-wrap"> { + let summary = "Add wrapper function to be differentiated"; + let dependentDialects = [ + "cf::ControlFlowDialect", + "enzyme::EnzymeDialect" + ]; + let constructor = "mlir::enzyme::createDifferentiateWrapperPass()"; + let options = [ + Option< + /*C++ variable name=*/"infn", + /*CLI argument=*/"infn", + /*type=*/"std::string", + /*default=*/"", + /*description=*/"Input function to differentiate" + >, + Option< + /*C++ variable name=*/"outfn", + /*CLI argument=*/"outfn", + /*type=*/"std::string", + /*default=*/"", + /*description=*/"Output function to differentiate" + >, + Option< + /*C++ variable name=*/"mode", + /*CLI argument=*/"mode", + /*type=*/"DerivativeMode", + /*default=*/"DerivativeMode::ForwardMode", + /*description=*/"mode to differentiate", +[{::llvm::cl::values( + clEnumValN(DerivativeMode::ForwardMode, "ForwardMode", "ForwardMode (default)"), + clEnumValN(DerivativeMode::ReverseModeCombined, "ReverseModeCombined", "Combined ReverseMode"), + clEnumValN(DerivativeMode::ReverseModePrimal, "ReverseModePrimal", "Forward Pass of ReverseMode"), + clEnumValN(DerivativeMode::ReverseModeGradient, "ReverseModeGradient", "Backward Pass of ReverseMode") + )}] + >, + Option< + /*C++ variable name=*/"retTy", + /*CLI argument=*/"retTy", + /*type=*/"DIFFE_TYPE", + /*default=*/"DIFFE_TYPE::DUP_ARG", + /*description=*/"activity of the return", +[{::llvm::cl::values( + clEnumValN(DIFFE_TYPE::DUP_ARG, "enzyme_dup", "Duplicated (default)"), + clEnumValN(DIFFE_TYPE::OUT_DIFF, "enzyme_out", "Active"), + clEnumValN(DIFFE_TYPE::CONSTANT, "enzyme_const", "Constant"), + clEnumValN(DIFFE_TYPE::DUP_NONEED, "enzyme_dupnoneed", "Duplicated noneed") + )}] + >, + Option< + /*C++ variable name=*/"argTys", + /*CLI argument=*/"argTys", + /*type=*/"std::string", + /*default=*/"", + /*description=*/"The activity of the arguments" + >, + ]; +} + def PrintActivityAnalysisPass : Pass<"print-activity-analysis"> { let summary = "Print the results of activity analysis"; let constructor = "mlir::enzyme::createPrintActivityAnalysisPass()"; diff --git a/enzyme/test/MLIR/ForwardMode/wrap.mlir b/enzyme/test/MLIR/ForwardMode/wrap.mlir new file mode 100644 index 000000000000..7cb0eb82bf4b --- /dev/null +++ b/enzyme/test/MLIR/ForwardMode/wrap.mlir @@ -0,0 +1,16 @@ +// RUN: %eopt --enzyme-wrap="infn=square outfn=dsq retTy=enzyme_dup argTys=enzyme_dup, mode=ForwardMode" %s | FileCheck %s + +module { + func.func @square(%x : f64) -> f64{ + %y = arith.mulf %x, %x : f64 + return %y : f64 + } +} + +// CHECK: func.func private @dsq(%[[arg0:.+]]: f64, %[[arg1:.+]]: f64) -> f64 { +// CHECK-NEXT: %[[i0:.+]] = arith.mulf %[[arg1]], %[[arg0]] : f64 +// CHECK-NEXT: %[[i1:.+]] = arith.mulf %[[arg1]], %[[arg0]] : f64 +// CHECK-NEXT: %[[i2:.+]] = arith.addf %[[i0]], %[[i1]] : f64 +// CHECK-NEXT: %[[i3:.+]] = arith.mulf %[[arg0]], %[[arg0]] : f64 +// CHECK-NEXT: return %[[i2]] : f64 +// CHECK-NEXT: } From 3dd455b5c5009fda6782e0307b4c10f6af741227 Mon Sep 17 00:00:00 2001 From: William Moses Date: Wed, 7 Feb 2024 19:31:04 +0000 Subject: [PATCH 015/106] Fix tablegen for ll16 without optional support (#1674) --- enzyme/tools/enzyme-tblgen/blasDeclUpdater.h | 15 ++++----------- 1 file changed, 4 insertions(+), 11 deletions(-) diff --git a/enzyme/tools/enzyme-tblgen/blasDeclUpdater.h b/enzyme/tools/enzyme-tblgen/blasDeclUpdater.h index aef73e479f2f..aebf6d7aec0d 100644 --- a/enzyme/tools/enzyme-tblgen/blasDeclUpdater.h +++ b/enzyme/tools/enzyme-tblgen/blasDeclUpdater.h @@ -195,17 +195,10 @@ void emitBlasDeclUpdater(const RecordKeeper &RK, raw_ostream &os) { os << " auto name = getFuncName(&F);\n"; os << " auto changed = false;\n"; os << " auto blasMetaData = extractBLAS(name);\n"; - os << " #if LLVM_VERSION_MAJOR >= 16\n"; - os << " if (F.empty() && blasMetaData.has_value()) {\n"; - os << " attributeBLAS(blasMetaData.value(), &F);\n"; - os << " changed = true;\n"; - os << " }\n"; - os << " #else\n"; - os << " if (F.empty() && blasMetaData.hasValue()) {\n"; - os << " attributeBLAS(blasMetaData.getValue(), &F);\n"; - os << " changed = true;\n"; - os << " }\n"; - os << " #endif\n"; + os << " if (F.empty() && blasMetaData) {\n"; + os << " attributeBLAS(*blasMetaData, &F);\n"; + os << " changed = true;\n"; + os << " }\n"; { const auto &patterns = RK.getAllDerivedDefinitions("CallPattern"); for (Record *pattern : patterns) { From e243d05c8a2f938c89b5e917a2a37f64aa3b8bfe Mon Sep 17 00:00:00 2001 From: William Moses Date: Wed, 7 Feb 2024 21:28:38 +0000 Subject: [PATCH 016/106] Restore mlir build (#1675) --- enzyme/Enzyme/MLIR/Passes/Passes.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/enzyme/Enzyme/MLIR/Passes/Passes.h b/enzyme/Enzyme/MLIR/Passes/Passes.h index 7e64f7718988..25362d01294a 100644 --- a/enzyme/Enzyme/MLIR/Passes/Passes.h +++ b/enzyme/Enzyme/MLIR/Passes/Passes.h @@ -13,7 +13,7 @@ #include "mlir/Pass/Pass.h" #include -#include "Enzyme/MLIR/Dialect/Dialect.h" +#include "Dialect/Dialect.h" namespace mlir { class PatternRewriter; From ea07b775faeca2d63ab71df74d8269c6ef9f4ef6 Mon Sep 17 00:00:00 2001 From: William Moses Date: Thu, 8 Feb 2024 10:11:52 -0500 Subject: [PATCH 017/106] More macos on llvm16 fix (#1676) --- enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp | 6 +----- enzyme/tools/enzyme-tblgen/blasDiffUseUpdater.h | 12 ++---------- 2 files changed, 3 insertions(+), 15 deletions(-) diff --git a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp index a5db987b5a74..b9e670326fd4 100644 --- a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp +++ b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp @@ -4937,11 +4937,7 @@ void TypeAnalyzer::visitCallBase(CallBase &call) { } if (auto opidx = getAllocationIndexFromCall(&call)) { auto ptr = TypeTree(BaseType::Pointer); -#if LLVM_VERSION_MAJOR >= 15 - unsigned index = (size_t)opidx.value(); -#else - unsigned index = (size_t)opidx.getValue(); -#endif + unsigned index = (size_t)*opidx; if (auto CI = dyn_cast(call.getOperand(index))) { auto &DL = call.getParent()->getParent()->getParent()->getDataLayout(); auto LoadSize = CI->getZExtValue(); diff --git a/enzyme/tools/enzyme-tblgen/blasDiffUseUpdater.h b/enzyme/tools/enzyme-tblgen/blasDiffUseUpdater.h index b3d64f2c7ab9..51e883e3b804 100644 --- a/enzyme/tools/enzyme-tblgen/blasDiffUseUpdater.h +++ b/enzyme/tools/enzyme-tblgen/blasDiffUseUpdater.h @@ -182,11 +182,7 @@ void emitBlasDiffUse(const RecordKeeper &RK, llvm::raw_ostream &os) { emitDiffUse(RK, os, CallDerivatives); os << " auto blasMetaData = extractBLAS(funcName);\n"; - os << " #if LLVM_VERSION_MAJOR >= 16\n"; - os << " if (blasMetaData.has_value())\n"; - os << " #else\n"; - os << " if (blasMetaData.hasValue())\n"; - os << " #endif\n"; + os << " if (blasMetaData)\n"; os << " {\n"; os << " auto Mode = gutils->mode;\n"; os << " const bool cacheMode = (Mode != DerivativeMode::ForwardMode);\n"; @@ -198,11 +194,7 @@ void emitBlasDiffUse(const RecordKeeper &RK, llvm::raw_ostream &os) { os << " assert(found != gutils->overwritten_args_map_ptr->end());\n"; os << " overwritten_args_ptr = &found->second;\n"; os << " }\n"; - os << " #if LLVM_VERSION_MAJOR >= 16\n"; - os << " BlasInfo blas = blasMetaData.value();\n"; - os << " #else\n"; - os << " BlasInfo blas = blasMetaData.getValue();\n"; - os << " #endif\n"; + os << " BlasInfo blas = *blasMetaData;\n"; for (auto &&newPattern : newBlasPatterns) { emit_BLASDiffUse(newPattern, os); } From 21053d4122e90cbac3a121286302864627050067 Mon Sep 17 00:00:00 2001 From: William Moses Date: Thu, 8 Feb 2024 10:40:35 -0500 Subject: [PATCH 018/106] bcload 15 or 16 fix (#1677) --- enzyme/BCLoad/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/enzyme/BCLoad/CMakeLists.txt b/enzyme/BCLoad/CMakeLists.txt index 6516a46dfbd5..f88f515ec77f 100644 --- a/enzyme/BCLoad/CMakeLists.txt +++ b/enzyme/BCLoad/CMakeLists.txt @@ -4,7 +4,7 @@ set(CMAKE_CXX_STANDARD_REQUIRED ON) set(BC_LOAD_FLAGS "" CACHE STRING "") set(BC_LOAD_HEADER "" CACHE STRING "") -if (${LLVM_VERSION_MAJOR} EQUAL 16) +if (${LLVM_VERSION_MAJOR} EQUAL 15 OR ${LLVM_VERSION_MAJOR} EQUAL 16) set(BC_LOAD_FLAGS2 "${BC_LOAD_FLAGS} -Xclang -no-opaque-pointers") else() set(BC_LOAD_FLAGS2 "${BC_LOAD_FLAGS}") From bd36baefd93d1a42f31f0ae0f8d4df1165188289 Mon Sep 17 00:00:00 2001 From: "Ivan R. Ivanov" Date: Fri, 9 Feb 2024 08:25:39 +0900 Subject: [PATCH 019/106] [Truncate] Handle libm calls (#1636) * Handle instrinsics * Add integration test * Add common handler for math intrinsics * Format EnzymeLogic.cpp --- enzyme/Enzyme/EnzymeLogic.cpp | 54 +++++---- enzyme/test/Integration/CMakeLists.txt | 1 + .../test/Integration/Truncate/CMakeLists.txt | 9 ++ enzyme/test/Integration/Truncate/simple.cpp | 111 ++++++++++++++++++ 4 files changed, 152 insertions(+), 23 deletions(-) create mode 100644 enzyme/test/Integration/Truncate/CMakeLists.txt create mode 100644 enzyme/test/Integration/Truncate/simple.cpp diff --git a/enzyme/Enzyme/EnzymeLogic.cpp b/enzyme/Enzyme/EnzymeLogic.cpp index 580b59f3c557..dc5e511b73fa 100644 --- a/enzyme/Enzyme/EnzymeLogic.cpp +++ b/enzyme/Enzyme/EnzymeLogic.cpp @@ -29,6 +29,7 @@ //===----------------------------------------------------------------------===// #include "ActivityAnalysis.h" #include "AdjointGenerator.h" +#include "llvm/IR/GlobalValue.h" #include "llvm/IR/Intrinsics.h" #if LLVM_VERSION_MAJOR >= 16 @@ -5075,18 +5076,17 @@ class TruncateGenerator : public llvm::InstVisitor { return; } void visitFenceInst(llvm::FenceInst &FI) { return; } - void visitIntrinsicInst(llvm::IntrinsicInst &II) { - SmallVector orig_ops(II.arg_size()); - for (unsigned i = 0; i < II.arg_size(); ++i) - orig_ops[i] = II.getOperand(i); - if (handleAdjointForIntrinsic(II.getIntrinsicID(), II, orig_ops)) - return; + + bool handleIntrinsic(llvm::CallInst &CI, Intrinsic::ID ID) { + SmallVector orig_ops(CI.arg_size()); + for (unsigned i = 0; i < CI.arg_size(); ++i) + orig_ops[i] = CI.getOperand(i); bool hasFromType = false; - auto newI = cast(getNewFromOriginal(&II)); + auto newI = cast(getNewFromOriginal(&CI)); IRBuilder<> B(newI); - SmallVector new_ops(II.arg_size()); - for (unsigned i = 0; i < II.arg_size(); ++i) { + SmallVector new_ops(CI.arg_size()); + for (unsigned i = 0; i < CI.arg_size(); ++i) { if (orig_ops[i]->getType() == getFromType()) { new_ops[i] = truncate(B, getNewFromOriginal(orig_ops[i])); hasFromType = true; @@ -5094,27 +5094,29 @@ class TruncateGenerator : public llvm::InstVisitor { new_ops[i] = getNewFromOriginal(orig_ops[i]); } } - Type *retTy = II.getType(); - if (II.getType() == getFromType()) { + Type *retTy = CI.getType(); + if (CI.getType() == getFromType()) { hasFromType = true; retTy = getToType(); } if (!hasFromType) - return; + return false; // TODO check that the intrinsic is overloaded CallInst *intr; - Value *nres = intr = createIntrinsicCall(B, II.getIntrinsicID(), retTy, - new_ops, &II, II.getName()); - if (II.getType() == getFromType()) + Value *nres = intr = + createIntrinsicCall(B, ID, retTy, new_ops, &CI, CI.getName()); + if (CI.getType() == getFromType()) nres = expand(B, nres); intr->copyIRFlags(newI); newI->replaceAllUsesWith(nres); newI->eraseFromParent(); - - return; + return true; + } + void visitIntrinsicInst(llvm::IntrinsicInst &II) { + handleIntrinsic(II, II.getIntrinsicID()); } void visitReturnInst(llvm::ReturnInst &I) { return; } @@ -5201,18 +5203,24 @@ class TruncateGenerator : public llvm::InstVisitor { return v; } // Return - void visitCallInst(llvm::CallInst &call) { + void visitCallInst(llvm::CallInst &CI) { + Intrinsic::ID ID; + StringRef funcName = getFuncNameFromCall(const_cast(&CI)); + if (isMemFreeLibMFunction(funcName, &ID)) + if (handleIntrinsic(CI, ID)) + return; + using namespace llvm; - CallInst *const newCall = cast(getNewFromOriginal(&call)); + CallInst *const newCall = cast(getNewFromOriginal(&CI)); IRBuilder<> BuilderZ(newCall); - if (auto called = call.getCalledFunction()) - if (handleKnownCalls(call, called, getFuncNameFromCall(&call), newCall)) + if (auto called = CI.getCalledFunction()) + if (handleKnownCalls(CI, called, getFuncNameFromCall(&CI), newCall)) return; - RequestContext ctx(&call, &BuilderZ); - auto val = GetShadow(ctx, getNewFromOriginal(call.getCalledOperand())); + RequestContext ctx(&CI, &BuilderZ); + auto val = GetShadow(ctx, getNewFromOriginal(CI.getCalledOperand())); newCall->setCalledOperand(val); return; } diff --git a/enzyme/test/Integration/CMakeLists.txt b/enzyme/test/Integration/CMakeLists.txt index 7a14214b46cd..98171f188d86 100644 --- a/enzyme/test/Integration/CMakeLists.txt +++ b/enzyme/test/Integration/CMakeLists.txt @@ -3,6 +3,7 @@ add_subdirectory(ForwardModeVector) add_subdirectory(ReverseMode) add_subdirectory(BatchMode) add_subdirectory(Sparse) +add_subdirectory(Truncate) # Run regression and unit tests add_lit_testsuite(check-enzyme-integration "Running enzyme integration tests" diff --git a/enzyme/test/Integration/Truncate/CMakeLists.txt b/enzyme/test/Integration/Truncate/CMakeLists.txt new file mode 100644 index 000000000000..bfd5e99064ac --- /dev/null +++ b/enzyme/test/Integration/Truncate/CMakeLists.txt @@ -0,0 +1,9 @@ +# Run regression and unit tests +add_lit_testsuite(check-enzyme-integration-truncate "Running enzyme batch mode integration tests" + ${CMAKE_CURRENT_BINARY_DIR} + DEPENDS ${ENZYME_TEST_DEPS} + ARGS -v +) + +set_target_properties(check-enzyme-integration-truncate PROPERTIES FOLDER "Tests") + diff --git a/enzyme/test/Integration/Truncate/simple.cpp b/enzyme/test/Integration/Truncate/simple.cpp new file mode 100644 index 000000000000..749d8fded7c1 --- /dev/null +++ b/enzyme/test/Integration/Truncate/simple.cpp @@ -0,0 +1,111 @@ +// COM: %clang -O0 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -S | %lli - +// RUN: %clang -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -S | %lli - +// COM: %clang -O2 -ffast-math %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -S | %lli - +// COM: %clang -O1 -g %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -S | %lli - + +#include + +#include "../test_utils.h" + +#define N 10 + +double simple_add(double a, double b) { + return a + b; +} +double intrinsics(double a, double b) { + return sqrt(a) * pow(b, 2); +} +// TODO +double constt(double a, double b) { + return 2; +} +double compute(double *A, double *B, double *C, int n) { + for (int i = 0; i < n; i++) { + C[i] = A[i] * 2; + } + return C[0]; +} + +typedef double (*fty)(double *, double *, double *, int); + +typedef double (*fty2)(double, double); + +extern fty __enzyme_truncate_func_2(...); +extern fty2 __enzyme_truncate_func(...); +extern double __enzyme_truncate_value(...); +extern double __enzyme_expand_value(...); + +#define FROM 64 +#define TO 32 + +#define TEST(F) do { + + +int main() { + + { + double a = 1; + APPROX_EQ( + __enzyme_expand_value( + __enzyme_truncate_value(a, FROM, TO) , FROM, TO), + a, 1e-10); + } + + { + double a = 2; + double b = 3; + double truth = simple_add(a, b); + a = __enzyme_truncate_value(a, FROM, TO); + b = __enzyme_truncate_value(b, FROM, TO); + double trunc = __enzyme_expand_value(__enzyme_truncate_func(simple_add, FROM, TO)(a, b), FROM, TO); + APPROX_EQ(trunc, truth, 1e-5); + } + { + double a = 2; + double b = 3; + double truth = intrinsics(a, b); + a = __enzyme_truncate_value(a, FROM, TO); + b = __enzyme_truncate_value(b, FROM, TO); + double trunc = __enzyme_expand_value(__enzyme_truncate_func(intrinsics, FROM, TO)(a, b), FROM, TO); + APPROX_EQ(trunc, truth, 1e-5); + } + // { + // double a = 2; + // double b = 3; + // double truth = intrinsics(a, b); + // a = __enzyme_truncate_value(a, FROM, TO); + // b = __enzyme_truncate_value(b, FROM, TO); + // double trunc = __enzyme_expand_value(__enzyme_truncate_func(constt, FROM, TO)(a, b), FROM, TO); + // APPROX_EQ(trunc, truth, 1e-5); + // } + + // double A[N]; + // double B[N]; + // double C[N]; + // double D[N]; + + + // for (int i = 0; i < N; i++) { + // A[i] = 1 + i % 5; + // B[i] = 1 + i % 3; + // } + + // compute(A, B, D, N); + + // for (int i = 0; i < N; i++) { + // A[i] = __enzyme_truncate_value(A[i], 64, 32); + // B[i] = __enzyme_truncate_value(B[i], 64, 32); + // } + + // __enzyme_truncate_func_2(compute, 64, 32)(A, B, C, N); + + // for (int i = 0; i < N; i++) { + // C[i] = __enzyme_expand_value(C[i], 64, 32); + // } + + // for (int i = 0; i < N; i++) { + // printf("%d\n", i); + // APPROX_EQ(D[i], C[i], 1e-5); + // } + +} From 534f79fa18134e5f1882de4fe4e1782b17952250 Mon Sep 17 00:00:00 2001 From: "Ivan R. Ivanov" Date: Fri, 9 Feb 2024 15:46:32 +0900 Subject: [PATCH 020/106] Add a float op only truncation mode (#1651) * add TODO FP trunc ext conversion insts * WIP value trunc mode * WIP * Fixes * Remove alloc from op trunc * Support vector types * test * Add constructor for FloatRepr * Update CMakeLists.txt * Update simple.cpp --- enzyme/Enzyme/Enzyme.cpp | 56 ++++-- enzyme/Enzyme/EnzymeLogic.cpp | 184 +++++++++++------- enzyme/Enzyme/EnzymeLogic.h | 76 +++++++- enzyme/test/Enzyme/Truncate/cmp.ll | 23 ++- enzyme/test/Enzyme/Truncate/intrinsic.ll | 29 ++- enzyme/test/Enzyme/Truncate/select.ll | 22 ++- enzyme/test/Enzyme/Truncate/simple.ll | 31 ++- enzyme/test/Enzyme/Truncate/value.ll | 8 +- .../test/Integration/Truncate/CMakeLists.txt | 3 +- enzyme/test/Integration/Truncate/simple.cpp | 90 +++++---- 10 files changed, 371 insertions(+), 151 deletions(-) diff --git a/enzyme/Enzyme/Enzyme.cpp b/enzyme/Enzyme/Enzyme.cpp index c9b73e598222..99c429efa08e 100644 --- a/enzyme/Enzyme/Enzyme.cpp +++ b/enzyme/Enzyme/Enzyme.cpp @@ -1314,7 +1314,20 @@ class EnzymeBase { return type_args; } - bool HandleTruncateFunc(CallInst *CI) { + static FloatRepresentation getDefaultFloatRepr(unsigned width) { + switch (width) { + case 16: + return FloatRepresentation(5, 10); + case 32: + return FloatRepresentation(8, 23); + case 64: + return FloatRepresentation(11, 52); + default: + llvm_unreachable("Invalid float width"); + } + }; + + bool HandleTruncateFunc(CallInst *CI, TruncateMode mode) { IRBuilder<> Builder(CI); Function *F = parseFunctionParameter(CI); if (!F) @@ -1331,8 +1344,9 @@ class EnzymeBase { assert(Cto); RequestContext context(CI, &Builder); llvm::Value *res = Logic.CreateTruncateFunc( - context, F, (unsigned)Cfrom->getValue().getZExtValue(), - (unsigned)Cto->getValue().getZExtValue()); + context, F, + getDefaultFloatRepr((unsigned)Cfrom->getValue().getZExtValue()), + getDefaultFloatRepr((unsigned)Cto->getValue().getZExtValue()), mode); if (!res) return false; res = Builder.CreatePointerCast(res, CI->getType()); @@ -1356,8 +1370,10 @@ class EnzymeBase { auto Addr = CI->getArgOperand(0); RequestContext context(CI, &Builder); bool res = Logic.CreateTruncateValue( - context, Addr, (unsigned)Cfrom->getValue().getZExtValue(), - (unsigned)Cto->getValue().getZExtValue(), isTruncate); + context, Addr, + getDefaultFloatRepr((unsigned)Cfrom->getValue().getZExtValue()), + getDefaultFloatRepr((unsigned)Cto->getValue().getZExtValue()), + isTruncate); if (!res) return false; return true; @@ -2096,7 +2112,8 @@ class EnzymeBase { MapVector toVirtual; MapVector toSize; SmallVector toBatch; - SmallVector toTruncateFunc; + SmallVector toTruncateFuncMem; + SmallVector toTruncateFuncOp; SmallVector toTruncateValue; SmallVector toExpandValue; MapVector toProbProg; @@ -2408,7 +2425,8 @@ class EnzymeBase { bool virtualCall = false; bool sizeOnly = false; bool batch = false; - bool truncateFunc = false; + bool truncateFuncOp = false; + bool truncateFuncMem = false; bool truncateValue = false; bool expandValue = false; bool probProg = false; @@ -2440,13 +2458,16 @@ class EnzymeBase { } else if (Fn->getName().contains("__enzyme_batch")) { enableEnzyme = true; batch = true; - } else if (Fn->getName().contains("__enzyme_truncate_func")) { + } else if (Fn->getName().contains("__enzyme_truncate_mem_func")) { enableEnzyme = true; - truncateFunc = true; - } else if (Fn->getName().contains("__enzyme_truncate_value")) { + truncateFuncMem = true; + } else if (Fn->getName().contains("__enzyme_truncate_op_func")) { + enableEnzyme = true; + truncateFuncOp = true; + } else if (Fn->getName().contains("__enzyme_truncate_mem_value")) { enableEnzyme = true; truncateValue = true; - } else if (Fn->getName().contains("__enzyme_expand_value")) { + } else if (Fn->getName().contains("__enzyme_expand_mem_value")) { enableEnzyme = true; expandValue = true; } else if (Fn->getName().contains("__enzyme_likelihood")) { @@ -2506,8 +2527,10 @@ class EnzymeBase { toSize[CI] = derivativeMode; else if (batch) toBatch.push_back(CI); - else if (truncateFunc) - toTruncateFunc.push_back(CI); + else if (truncateFuncOp) + toTruncateFuncOp.push_back(CI); + else if (truncateFuncMem) + toTruncateFuncMem.push_back(CI); else if (truncateValue) toTruncateValue.push_back(CI); else if (expandValue) @@ -2605,8 +2628,11 @@ class EnzymeBase { for (auto call : toBatch) { HandleBatch(call); } - for (auto call : toTruncateFunc) { - HandleTruncateFunc(call); + for (auto call : toTruncateFuncMem) { + HandleTruncateFunc(call, TruncMem); + } + for (auto call : toTruncateFuncOp) { + HandleTruncateFunc(call, TruncOp); } for (auto call : toTruncateValue) { HandleTruncateValue(call, true); diff --git a/enzyme/Enzyme/EnzymeLogic.cpp b/enzyme/Enzyme/EnzymeLogic.cpp index dc5e511b73fa..468c4e5ccba2 100644 --- a/enzyme/Enzyme/EnzymeLogic.cpp +++ b/enzyme/Enzyme/EnzymeLogic.cpp @@ -27,8 +27,10 @@ // primal pass. // //===----------------------------------------------------------------------===// +#include "EnzymeLogic.h" #include "ActivityAnalysis.h" #include "AdjointGenerator.h" +#include "EnzymeLogic.h" #include "llvm/IR/GlobalValue.h" #include "llvm/IR/Intrinsics.h" @@ -4809,23 +4811,31 @@ Function *EnzymeLogic::CreateForwardDiff( return nf; } -static Type *getTypeForWidth(LLVMContext &ctx, unsigned width) { - switch (width) { - default: - return llvm::Type::getIntNTy(ctx, width); - case 64: - return llvm::Type::getDoubleTy(ctx); - case 32: - return llvm::Type::getFloatTy(ctx); - case 16: - return llvm::Type::getHalfTy(ctx); - } +static Value *floatValTruncate(IRBuilderBase &B, Value *v, Value *tmpBlock, + FloatRepresentation from, + FloatRepresentation to) { + Type *toTy = to.getType(B.getContext()); + if (auto vty = dyn_cast(v->getType())) + toTy = VectorType::get(toTy, vty->getElementCount()); + return B.CreateFPTrunc(v, toTy, "enzyme_trunc"); +} + +static Value *floatValExpand(IRBuilderBase &B, Value *v, Value *tmpBlock, + FloatRepresentation from, FloatRepresentation to) { + Type *fromTy = from.getBuiltinType(B.getContext()); + if (auto vty = dyn_cast(v->getType())) + fromTy = VectorType::get(fromTy, vty->getElementCount()); + return B.CreateFPExt(v, fromTy, "enzyme_exp"); } -static Value *floatTruncate(IRBuilderBase &B, Value *v, Value *tmpBlock, - unsigned fromwidth, unsigned towidth) { - Type *fromTy = getTypeForWidth(B.getContext(), fromwidth); - Type *toTy = getTypeForWidth(B.getContext(), towidth); +static Value *floatMemTruncate(IRBuilderBase &B, Value *v, Value *tmpBlock, + FloatRepresentation from, + FloatRepresentation to) { + if (isa(v->getType())) + report_fatal_error("vector operations not allowed in mem trunc mode"); + + Type *fromTy = from.getBuiltinType(B.getContext()); + Type *toTy = to.getType(B.getContext()); if (!tmpBlock) tmpBlock = B.CreateAlloca(fromTy); B.CreateStore( @@ -4834,13 +4844,16 @@ static Value *floatTruncate(IRBuilderBase &B, Value *v, Value *tmpBlock, toTy, B.CreatePointerCast(tmpBlock, PointerType::getUnqual(toTy))); } -static Value *floatExpand(IRBuilderBase &B, Value *v, Value *tmpBlock, - unsigned fromwidth, unsigned towidth) { - Type *fromTy = getTypeForWidth(B.getContext(), fromwidth); +static Value *floatMemExpand(IRBuilderBase &B, Value *v, Value *tmpBlock, + FloatRepresentation from, FloatRepresentation to) { + if (isa(v->getType())) + report_fatal_error("vector operations not allowed in mem trunc mode"); + + Type *fromTy = from.getBuiltinType(B.getContext()); if (!tmpBlock) tmpBlock = B.CreateAlloca(fromTy); - auto c0 = - Constant::getNullValue(llvm::Type::getIntNTy(B.getContext(), fromwidth)); + auto c0 = Constant::getNullValue( + llvm::Type::getIntNTy(B.getContext(), from.getTypeWidth())); B.CreateStore( c0, B.CreatePointerCast(tmpBlock, PointerType::getUnqual(c0->getType()))); B.CreateStore( @@ -4852,27 +4865,38 @@ static Value *floatExpand(IRBuilderBase &B, Value *v, Value *tmpBlock, class TruncateGenerator : public llvm::InstVisitor { private: ValueToValueMapTy &originalToNewFn; - unsigned fromwidth; - unsigned towidth; + FloatRepresentation from; + FloatRepresentation to; Type *fromType; Type *toType; Function *oldFunc; Function *newFunc; AllocaInst *tmpBlock; + TruncateMode mode; EnzymeLogic &Logic; public: - TruncateGenerator(ValueToValueMapTy &originalToNewFn, unsigned fromwidth, - unsigned towidth, Function *oldFunc, Function *newFunc, + TruncateGenerator(ValueToValueMapTy &originalToNewFn, + FloatRepresentation from, FloatRepresentation to, + Function *oldFunc, Function *newFunc, TruncateMode mode, EnzymeLogic &Logic) - : originalToNewFn(originalToNewFn), fromwidth(fromwidth), - towidth(towidth), oldFunc(oldFunc), newFunc(newFunc), Logic(Logic) { + : originalToNewFn(originalToNewFn), from(from), to(to), oldFunc(oldFunc), + newFunc(newFunc), mode(mode), Logic(Logic) { IRBuilder<> B(&newFunc->getEntryBlock().front()); - fromType = getTypeForWidth(B.getContext(), fromwidth); - toType = getTypeForWidth(B.getContext(), towidth); + fromType = from.getBuiltinType(B.getContext()); + toType = to.getType(B.getContext()); - tmpBlock = B.CreateAlloca(fromType); + if (mode == TruncMem) + tmpBlock = B.CreateAlloca(fromType); + else + tmpBlock = nullptr; + } + + void checkHandled(llvm::Instruction &inst) { + // if (all_of(inst.getOperandList(), + // [&](Use *use) { return use->get()->getType() == fromType; })) + // todo(inst); } void visitInstruction(llvm::Instruction &inst) { @@ -4887,7 +4911,7 @@ class TruncateGenerator : public llvm::InstVisitor { break; } - todo(inst); + checkHandled(inst); } Type *getFromType() { return fromType; } @@ -4895,11 +4919,25 @@ class TruncateGenerator : public llvm::InstVisitor { Type *getToType() { return toType; } Value *truncate(IRBuilder<> &B, Value *v) { - return floatTruncate(B, v, tmpBlock, fromwidth, towidth); + switch (mode) { + case TruncMem: + return floatMemTruncate(B, v, tmpBlock, from, to); + case TruncOp: + return floatValTruncate(B, v, tmpBlock, from, to); + default: + llvm_unreachable("Unknown trunc mode"); + } } Value *expand(IRBuilder<> &B, Value *v) { - return floatExpand(B, v, tmpBlock, fromwidth, towidth); + switch (mode) { + case TruncMem: + return floatMemExpand(B, v, tmpBlock, from, to); + case TruncOp: + return floatValExpand(B, v, tmpBlock, from, to); + default: + llvm_unreachable("Unknown trunc mode"); + } } void todo(llvm::Instruction &I) { @@ -4967,17 +5005,25 @@ class TruncateGenerator : public llvm::InstVisitor { return; } void visitSelectInst(llvm::SelectInst &SI) { - auto newI = getNewFromOriginal(&SI); - IRBuilder<> B(newI); - auto newT = truncate(B, getNewFromOriginal(SI.getTrueValue())); - auto newF = truncate(B, getNewFromOriginal(SI.getFalseValue())); - auto nres = cast( - B.CreateSelect(getNewFromOriginal(SI.getCondition()), newT, newF)); - nres->takeName(newI); - nres->copyIRFlags(newI); - newI->replaceAllUsesWith(expand(B, nres)); - newI->eraseFromParent(); - return; + switch (mode) { + case TruncMem: { + auto newI = getNewFromOriginal(&SI); + IRBuilder<> B(newI); + auto newT = truncate(B, getNewFromOriginal(SI.getTrueValue())); + auto newF = truncate(B, getNewFromOriginal(SI.getFalseValue())); + auto nres = cast( + B.CreateSelect(getNewFromOriginal(SI.getCondition()), newT, newF)); + nres->takeName(newI); + nres->copyIRFlags(newI); + newI->replaceAllUsesWith(expand(B, nres)); + newI->eraseFromParent(); + return; + } + case TruncOp: + return; + default: + llvm_unreachable(""); + } } void visitExtractElementInst(llvm::ExtractElementInst &EEI) { return; } void visitInsertElementInst(llvm::InsertElementInst &EEI) { return; } @@ -5005,7 +5051,7 @@ class TruncateGenerator : public llvm::InstVisitor { return; } - if (towidth == 32 || towidth == 16 || towidth == 64) { + if (to.getBuiltinType(BO.getContext())) { auto newI = getNewFromOriginal(&BO); IRBuilder<> B(newI); auto newLHS = truncate(B, getNewFromOriginal(BO.getOperand(0))); @@ -5197,7 +5243,7 @@ class TruncateGenerator : public llvm::InstVisitor { Value *GetShadow(RequestContext &ctx, Value *v) { if (auto F = dyn_cast(v)) - return Logic.CreateTruncateFunc(ctx, F, fromwidth, towidth); + return Logic.CreateTruncateFunc(ctx, F, from, to, mode); llvm::errs() << " unknown get truncated func: " << *v << "\n"; llvm_unreachable("unknown get truncated func"); return v; @@ -5224,19 +5270,26 @@ class TruncateGenerator : public llvm::InstVisitor { newCall->setCalledOperand(val); return; } + void visitFPTruncInst(FPTruncInst &I) { return; } + void visitFPExtInst(FPExtInst &I) { return; } + void visitFPToUIInst(FPToUIInst &I) { return; } + void visitFPToSIInst(FPToSIInst &I) { return; } + void visitUIToFPInst(UIToFPInst &I) { return; } + void visitSIToFPInst(SIToFPInst &I) { return; } }; bool EnzymeLogic::CreateTruncateValue(RequestContext context, Value *v, - unsigned fromwidth, unsigned towidth, - bool isTruncate) { + FloatRepresentation from, + FloatRepresentation to, bool isTruncate) { assert(context.req && context.ip); - if (fromwidth == towidth) { + if (from == to) { + context.req->replaceAllUsesWith(context.req->getOperand(0)); context.req->eraseFromParent(); return true; } - if (fromwidth < towidth) { + if (from < to) { std::string s; llvm::raw_string_ostream ss(s); ss << "Cannot truncate into a large width\n"; @@ -5250,16 +5303,15 @@ bool EnzymeLogic::CreateTruncateValue(RequestContext context, Value *v, } IRBuilderBase &B = *context.ip; - Type *fromTy = getTypeForWidth(B.getContext(), fromwidth); - Type *toTy = getTypeForWidth(B.getContext(), towidth); + Type *fromTy = from.getBuiltinType(B.getContext()); + Type *toTy = to.getType(B.getContext()); Value *converted = nullptr; if (isTruncate) - converted = - floatExpand(B, B.CreateFPTrunc(v, toTy), nullptr, fromwidth, towidth); + converted = floatMemExpand(B, B.CreateFPTrunc(v, toTy), nullptr, from, to); else converted = - B.CreateFPExt(floatTruncate(B, v, nullptr, fromwidth, towidth), fromTy); + B.CreateFPExt(floatMemTruncate(B, v, nullptr, from, to), fromTy); assert(converted); context.req->replaceAllUsesWith(converted); @@ -5270,12 +5322,13 @@ bool EnzymeLogic::CreateTruncateValue(RequestContext context, Value *v, llvm::Function *EnzymeLogic::CreateTruncateFunc(RequestContext context, llvm::Function *totrunc, - unsigned fromwidth, - unsigned towidth) { - if (fromwidth == towidth) + FloatRepresentation from, + FloatRepresentation to, + TruncateMode mode) { + if (from == to) return totrunc; - TruncateCacheKey tup(totrunc, fromwidth, towidth); + TruncateCacheKey tup(totrunc, from, to, mode); if (TruncateCachedFunctions.find(tup) != TruncateCachedFunctions.end()) { return TruncateCachedFunctions.find(tup)->second; } @@ -5290,11 +5343,12 @@ llvm::Function *EnzymeLogic::CreateTruncateFunc(RequestContext context, Type *NewTy = totrunc->getReturnType(); FunctionType *FTy = FunctionType::get(NewTy, params, totrunc->isVarArg()); - Function *NewF = - Function::Create(FTy, totrunc->getLinkage(), - "trunc_" + std::to_string(fromwidth) + "_" + - std::to_string(towidth) + totrunc->getName(), - totrunc->getParent()); + std::string truncName = std::string("__enzyme_done_truncate_") + + (mode == TruncMem ? "mem" : "op") + "_func_" + + from.to_string() + "_" + to.to_string() + "_" + + totrunc->getName().str(); + Function *NewF = Function::Create(FTy, totrunc->getLinkage(), truncName, + totrunc->getParent()); NewF->setLinkage(Function::LinkageTypes::InternalLinkage); @@ -5327,7 +5381,7 @@ llvm::Function *EnzymeLogic::CreateTruncateFunc(RequestContext context, llvm_unreachable("attempting to truncate function without definition"); } - if (fromwidth < towidth) { + if (from < to) { std::string s; llvm::raw_string_ostream ss(s); ss << "Cannot truncate into a large width\n"; @@ -5375,7 +5429,7 @@ llvm::Function *EnzymeLogic::CreateTruncateFunc(RequestContext context, NewF->setLinkage(Function::LinkageTypes::InternalLinkage); - TruncateGenerator handle(originalToNewFn, fromwidth, towidth, totrunc, NewF, + TruncateGenerator handle(originalToNewFn, from, to, totrunc, NewF, mode, *this); for (auto &BB : *totrunc) for (auto &I : BB) diff --git a/enzyme/Enzyme/EnzymeLogic.h b/enzyme/Enzyme/EnzymeLogic.h index 2f1ac9fde496..6f311fcc85da 100644 --- a/enzyme/Enzyme/EnzymeLogic.h +++ b/enzyme/Enzyme/EnzymeLogic.h @@ -253,6 +253,74 @@ struct RequestContext { : req(req), ip(ip) {} }; +[[maybe_unused]] static llvm::Type * +getTypeForWidth(llvm::LLVMContext &ctx, unsigned width, bool builtinFloat) { + switch (width) { + default: + if (builtinFloat) + llvm::report_fatal_error("Invalid float width requested"); + else + llvm::report_fatal_error( + "Truncation to non builtin float width unsupported"); + case 64: + return llvm::Type::getDoubleTy(ctx); + case 32: + return llvm::Type::getFloatTy(ctx); + case 16: + return llvm::Type::getHalfTy(ctx); + } +} + +enum TruncateMode { TruncMem, TruncOp }; + +struct FloatRepresentation { + // |_|__________|_________________| + // ^ ^ ^ + // sign bit exponent significand + // + // value = (sign) * significand * 2 ^ exponent + unsigned exponentWidth; + unsigned significandWidth; + + FloatRepresentation(unsigned e, unsigned s) + : exponentWidth(e), significandWidth(s) {} + + unsigned getTypeWidth() const { return 1 + exponentWidth + significandWidth; } + + bool canBeBuiltin() const { + unsigned w = getTypeWidth(); + return (w == 16 && significandWidth == 10) || + (w == 32 && significandWidth == 23) || + (w == 64 && significandWidth == 52); + } + + llvm::Type *getBuiltinType(llvm::LLVMContext &ctx) const { + if (!canBeBuiltin()) + return nullptr; + return getTypeForWidth(ctx, getTypeWidth(), /*builtinFloat=*/true); + } + + llvm::Type *getType(llvm::LLVMContext &ctx) const { + llvm::Type *builtinType = getBuiltinType(ctx); + if (builtinType) + return builtinType; + llvm_unreachable("TODO MPFR"); + } + + bool operator==(const FloatRepresentation &other) const { + return other.exponentWidth == exponentWidth && + other.significandWidth == significandWidth; + } + bool operator<(const FloatRepresentation &other) const { + return std::tuple(exponentWidth, significandWidth) < + std::tuple(other.exponentWidth, other.significandWidth); + } + std::string to_string() const { + return std::to_string(getTypeWidth()) + "_" + + std::to_string(significandWidth); + } +}; + class EnzymeLogic { public: PreProcessCache PPC; @@ -511,13 +579,15 @@ class EnzymeLogic { llvm::ArrayRef arg_types, BATCH_TYPE ret_type); - using TruncateCacheKey = std::tuple; + using TruncateCacheKey = std::tuple; std::map TruncateCachedFunctions; llvm::Function *CreateTruncateFunc(RequestContext context, llvm::Function *tobatch, - unsigned fromwidth, unsigned towidth); + FloatRepresentation from, + FloatRepresentation to, TruncateMode mode); bool CreateTruncateValue(RequestContext context, llvm::Value *addr, - unsigned fromwidth, unsigned towidth, + FloatRepresentation from, FloatRepresentation to, bool isTruncate); /// Create a traced version of a function diff --git a/enzyme/test/Enzyme/Truncate/cmp.ll b/enzyme/test/Enzyme/Truncate/cmp.ll index e9c61ebd773e..c96efa70660a 100644 --- a/enzyme/test/Enzyme/Truncate/cmp.ll +++ b/enzyme/test/Enzyme/Truncate/cmp.ll @@ -6,22 +6,28 @@ define i1 @f(double %x, double %y) { ret i1 %res } -declare i1 (double, double)* @__enzyme_truncate_func(...) +declare i1 (double, double)* @__enzyme_truncate_mem_func(...) +declare i1 (double, double)* @__enzyme_truncate_op_func(...) define i1 @tester(double %x, double %y) { entry: - %ptr = call i1 (double, double)* (...) @__enzyme_truncate_func(i1 (double, double)* @f, i64 64, i64 32) + %ptr = call i1 (double, double)* (...) @__enzyme_truncate_mem_func(i1 (double, double)* @f, i64 64, i64 32) + %res = call i1 %ptr(double %x, double %y) + ret i1 %res +} +define i1 @tester_op(double %x, double %y) { +entry: + %ptr = call i1 (double, double)* (...) @__enzyme_truncate_op_func(i1 (double, double)* @f, i64 64, i64 32) %res = call i1 %ptr(double %x, double %y) ret i1 %res } ; CHECK: define i1 @tester(double %x, double %y) { ; CHECK-NEXT: entry: -; CHECK-NEXT: %res = call i1 @trunc_64_32f(double %x, double %y) +; CHECK-NEXT: %res = call i1 @__enzyme_done_truncate_mem_func_64_52_32_23_f(double %x, double %y) ; CHECK-NEXT: ret i1 %res -; CHECK-NEXT: } -; CHECK: define internal i1 @trunc_64_32f(double %x, double %y) { +; CHECK: define internal i1 @__enzyme_done_truncate_mem_func_64_52_32_23_f(double %x, double %y) { ; CHECK-DAG: %1 = alloca double, align 8 ; CHECK-DAG: store double %x, double* %1, align 8 ; CHECK-DAG: %2 = bitcast double* %1 to float* @@ -31,4 +37,9 @@ entry: ; CHECK-DAG: %5 = load float, float* %4, align 4 ; CHECK-DAG: %res = fcmp olt float %3, %5 ; CHECK-DAG: ret i1 %res -; CHECK-NEXT:} + +; CHECK: define internal i1 @__enzyme_done_truncate_op_func_64_52_32_23_f(double %x, double %y) { +; CHECK-DAG: %enzyme_trunc = fptrunc double %x to float +; CHECK-DAG: %enzyme_trunc1 = fptrunc double %y to float +; CHECK-DAG: %res = fcmp olt float %enzyme_trunc, %enzyme_trunc1 +; CHECK-DAG: ret i1 %res diff --git a/enzyme/test/Enzyme/Truncate/intrinsic.ll b/enzyme/test/Enzyme/Truncate/intrinsic.ll index da4457492ce2..99568539c3f3 100644 --- a/enzyme/test/Enzyme/Truncate/intrinsic.ll +++ b/enzyme/test/Enzyme/Truncate/intrinsic.ll @@ -13,16 +13,23 @@ define double @f(double %x, double %y) { ret double %res } -declare double (double, double)* @__enzyme_truncate_func(...) +declare double (double, double)* @__enzyme_truncate_mem_func(...) +declare double (double, double)* @__enzyme_truncate_op_func(...) define double @tester(double %x, double %y) { entry: - %ptr = call double (double, double)* (...) @__enzyme_truncate_func(double (double, double)* @f, i64 64, i64 32) + %ptr = call double (double, double)* (...) @__enzyme_truncate_mem_func(double (double, double)* @f, i64 64, i64 32) + %res = call double %ptr(double %x, double %y) + ret double %res +} +define double @tester2(double %x, double %y) { +entry: + %ptr = call double (double, double)* (...) @__enzyme_truncate_op_func(double (double, double)* @f, i64 64, i64 32) %res = call double %ptr(double %x, double %y) ret double %res } -; CHECK: define internal double @trunc_64_32f(double %x, double %y) { +; CHECK: define internal double @__enzyme_done_truncate_mem_func_64_52_32_23_f(double %x, double %y) { ; CHECK-NEXT: %1 = alloca double, align 8 ; CHECK-NEXT: store double %x, double* %1, align 8 ; CHECK-NEXT: %2 = bitcast double* %1 to float* @@ -59,4 +66,18 @@ entry: ; CHECK-NEXT: %20 = load double, double* %1, align 8 ; CHECK-NEXT: call void @llvm.nvvm.barrier0() ; CHECK-NEXT: ret double %20 -; CHECK-NEXT: } + +; CHECK: define internal double @__enzyme_done_truncate_op_func_64_52_32_23_f(double %x, double %y) { +; CHECK-DAG: %enzyme_trunc = fptrunc double %x to float +; CHECK-DAG: %enzyme_trunc1 = fptrunc double %y to float +; CHECK-DAG: %res12 = call float @llvm.pow.f32(float %enzyme_trunc, float %enzyme_trunc1) +; CHECK-DAG: %enzyme_exp = fpext float %res12 to double +; CHECK-DAG: %enzyme_trunc3 = fptrunc double %x to float +; CHECK-DAG: %res24 = call float @llvm.powi.f32.i16(float %enzyme_trunc3, i16 2) +; CHECK-DAG: %enzyme_exp5 = fpext float %res24 to double +; CHECK-DAG: %enzyme_trunc6 = fptrunc double %enzyme_exp to float +; CHECK-DAG: %enzyme_trunc7 = fptrunc double %enzyme_exp5 to float +; CHECK-DAG: %res = fadd float %enzyme_trunc6, %enzyme_trunc7 +; CHECK-DAG: %enzyme_exp8 = fpext float %res to double +; CHECK-DAG: call void @llvm.nvvm.barrier0() +; CHECK-DAG: ret double %enzyme_exp8 diff --git a/enzyme/test/Enzyme/Truncate/select.ll b/enzyme/test/Enzyme/Truncate/select.ll index 58b4a58ef91b..365d21ab5913 100644 --- a/enzyme/test/Enzyme/Truncate/select.ll +++ b/enzyme/test/Enzyme/Truncate/select.ll @@ -6,22 +6,29 @@ define double @f(double %x, double %y, i1 %cond) { ret double %res } -declare double (double, double, i1)* @__enzyme_truncate_func(...) +declare double (double, double, i1)* @__enzyme_truncate_mem_func(...) +declare double (double, double, i1)* @__enzyme_truncate_op_func(...) define double @tester(double %x, double %y, i1 %cond) { entry: - %ptr = call double (double, double, i1)* (...) @__enzyme_truncate_func(double (double, double, i1)* @f, i64 64, i64 32) + %ptr = call double (double, double, i1)* (...) @__enzyme_truncate_mem_func(double (double, double, i1)* @f, i64 64, i64 32) + %res = call double %ptr(double %x, double %y, i1 %cond) + ret double %res +} + +define double @tester2(double %x, double %y, i1 %cond) { +entry: + %ptr = call double (double, double, i1)* (...) @__enzyme_truncate_op_func(double (double, double, i1)* @f, i64 64, i64 32) %res = call double %ptr(double %x, double %y, i1 %cond) ret double %res } ; CHECK: define double @tester(double %x, double %y, i1 %cond) { ; CHECK-NEXT: entry: -; CHECK-NEXT: %res = call double @trunc_64_32f(double %x, double %y, i1 %cond) +; CHECK-NEXT: %res = call double @__enzyme_done_truncate_mem_func_64_52_32_23_f(double %x, double %y, i1 %cond) ; CHECK-NEXT: ret double %res -; CHECK-NEXT: } -; CHECK: define internal double @trunc_64_32f(double %x, double %y, i1 %cond) { +; CHECK: define internal double @__enzyme_done_truncate_mem_func_64_52_32_23_f(double %x, double %y, i1 %cond) { ; CHECK-DAG: %1 = alloca double, align 8 ; CHECK-DAG: store double %x, double* %1, align 8 ; CHECK-DAG: %2 = bitcast double* %1 to float* @@ -36,4 +43,7 @@ entry: ; CHECK-DAG: store float %res, float* %7, align 4 ; CHECK-DAG: %8 = load double, double* %1, align 8 ; CHECK-DAG: ret double %8 -; CHECK-NEXT: } + +; CHECK: define internal double @__enzyme_done_truncate_op_func_64_52_32_23_f(double %x, double %y, i1 %cond) { +; CHECK-DAG: %res = select i1 %cond, double %x, double %y +; CHECK-DAG: ret double %res diff --git a/enzyme/test/Enzyme/Truncate/simple.ll b/enzyme/test/Enzyme/Truncate/simple.ll index 0f346a26f0d2..19d6cf1f3a23 100644 --- a/enzyme/test/Enzyme/Truncate/simple.ll +++ b/enzyme/test/Enzyme/Truncate/simple.ll @@ -8,22 +8,34 @@ define void @f(double* %x) { ret void } -declare void (double*)* @__enzyme_truncate_func(...) +declare void (double*)* @__enzyme_truncate_mem_func(...) +declare void (double*)* @__enzyme_truncate_op_func(...) define void @tester(double* %data) { entry: - %ptr = call void (double*)* (...) @__enzyme_truncate_func(void (double*)* @f, i64 64, i64 32) + %ptr = call void (double*)* (...) @__enzyme_truncate_mem_func(void (double*)* @f, i64 64, i64 32) + call void %ptr(double* %data) + ret void +} + +define void @tester2(double* %data) { +entry: + %ptr = call void (double*)* (...) @__enzyme_truncate_op_func(void (double*)* @f, i64 64, i64 32) call void %ptr(double* %data) ret void } ; CHECK: define void @tester(double* %data) ; CHECK-NEXT: entry: -; CHECK-NEXT: call void @trunc_64_32f(double* %data) +; CHECK-NEXT: call void @__enzyme_done_truncate_mem_func_64_52_32_23_f(double* %data) ; CHECK-NEXT: ret void -; CHECK-NEXT: } -; CHECK: define internal void @trunc_64_32f(double* %x) +; CHECK: define void @tester2(double* %data) { +; CHECK-NEXT: entry: +; CHECK-NEXT: call void @__enzyme_done_truncate_op_func_64_52_32_23_f(double* %data) +; CHECK-NEXT: ret void + +; CHECK: define internal void @__enzyme_done_truncate_mem_func_64_52_32_23_f(double* %x) ; CHECK-DAG: %1 = alloca double, align 8 ; CHECK-DAG: %y = load double, double* %x, align 8 ; CHECK-DAG: store double %y, double* %1, align 8 @@ -40,3 +52,12 @@ entry: ; CHECK-DAG: %8 = load double, double* %1, align 8 ; CHECK-DAG: store double %8, double* %x, align 8 ; CHECK-DAG: ret void + +; CHECK: define internal void @__enzyme_done_truncate_op_func_64_52_32_23_f(double* %x) { +; CHECK-DAG: %y = load double, double* %x, align 8 +; CHECK-DAG: %enzyme_trunc = fptrunc double %y to float +; CHECK-DAG: %enzyme_trunc1 = fptrunc double %y to float +; CHECK-DAG: %m = fmul float %enzyme_trunc, %enzyme_trunc1 +; CHECK-DAG: %enzyme_exp = fpext float %m to double +; CHECK-DAG: store double %enzyme_exp, double* %x, align 8 +; CHECK-DAG: ret void diff --git a/enzyme/test/Enzyme/Truncate/value.ll b/enzyme/test/Enzyme/Truncate/value.ll index 51f00401078d..9f87d00d2173 100644 --- a/enzyme/test/Enzyme/Truncate/value.ll +++ b/enzyme/test/Enzyme/Truncate/value.ll @@ -1,18 +1,18 @@ ; RUN: if [ %llvmver -gt 12 ]; then if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -S | FileCheck %s; fi; fi ; RUN: if [ %llvmver -gt 12 ]; then %opt < %s %newLoadEnzyme -passes="enzyme" -S | FileCheck %s; fi -declare double @__enzyme_truncate_value(double, i64, i64) -declare double @__enzyme_expand_value(double, i64, i64) +declare double @__enzyme_truncate_mem_value(double, i64, i64) +declare double @__enzyme_expand_mem_value(double, i64, i64) define double @expand_tester(double %a, double * %c) { entry: - %b = call double @__enzyme_expand_value(double %a, i64 64, i64 32) + %b = call double @__enzyme_expand_mem_value(double %a, i64 64, i64 32) ret double %b } define double @truncate_tester(double %a) { entry: - %b = call double @__enzyme_truncate_value(double %a, i64 64, i64 32) + %b = call double @__enzyme_truncate_mem_value(double %a, i64 64, i64 32) ret double %b } diff --git a/enzyme/test/Integration/Truncate/CMakeLists.txt b/enzyme/test/Integration/Truncate/CMakeLists.txt index bfd5e99064ac..65187869f1ae 100644 --- a/enzyme/test/Integration/Truncate/CMakeLists.txt +++ b/enzyme/test/Integration/Truncate/CMakeLists.txt @@ -1,5 +1,4 @@ -# Run regression and unit tests -add_lit_testsuite(check-enzyme-integration-truncate "Running enzyme batch mode integration tests" +add_lit_testsuite(check-enzyme-integration-truncate "Running enzyme fp truncation integration tests" ${CMAKE_CURRENT_BINARY_DIR} DEPENDS ${ENZYME_TEST_DEPS} ARGS -v diff --git a/enzyme/test/Integration/Truncate/simple.cpp b/enzyme/test/Integration/Truncate/simple.cpp index 749d8fded7c1..792366d687e0 100644 --- a/enzyme/test/Integration/Truncate/simple.cpp +++ b/enzyme/test/Integration/Truncate/simple.cpp @@ -1,7 +1,7 @@ -// COM: %clang -O0 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -S | %lli - -// RUN: %clang -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -S | %lli - -// COM: %clang -O2 -ffast-math %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -S | %lli - -// COM: %clang -O1 -g %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -S | %lli - +// RUN: %clang -DTRUNC_OP -O0 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -S | %lli - +// RUN: %clang -DTRUNC_MEM -DTRUNC_OP -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -S | %lli - +// RUN: %clang -DTRUNC_OP -O2 -ffast-math %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -S | %lli - +// RUN: %clang -O1 -g %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -S | %lli - #include @@ -21,7 +21,8 @@ double constt(double a, double b) { } double compute(double *A, double *B, double *C, int n) { for (int i = 0; i < n; i++) { - C[i] = A[i] * 2; + C[i] = A[i] * 2; + // C[i] A=[i] * 2 + B[i] * sqrt(A[i]) ; } return C[0]; } @@ -30,10 +31,12 @@ typedef double (*fty)(double *, double *, double *, int); typedef double (*fty2)(double, double); -extern fty __enzyme_truncate_func_2(...); -extern fty2 __enzyme_truncate_func(...); -extern double __enzyme_truncate_value(...); -extern double __enzyme_expand_value(...); +extern fty __enzyme_truncate_mem_func_2(...); +extern fty2 __enzyme_truncate_mem_func(...); +extern fty __enzyme_truncate_op_func_2(...); +extern fty2 __enzyme_truncate_op_func(...); +extern double __enzyme_truncate_mem_value(...); +extern double __enzyme_expand_mem_value(...); #define FROM 64 #define TO 32 @@ -43,11 +46,12 @@ extern double __enzyme_expand_value(...); int main() { + #ifdef TRUNC_MEM { double a = 1; APPROX_EQ( - __enzyme_expand_value( - __enzyme_truncate_value(a, FROM, TO) , FROM, TO), + __enzyme_expand_mem_value( + __enzyme_truncate_mem_value(a, FROM, TO) , FROM, TO), a, 1e-10); } @@ -55,57 +59,61 @@ int main() { double a = 2; double b = 3; double truth = simple_add(a, b); - a = __enzyme_truncate_value(a, FROM, TO); - b = __enzyme_truncate_value(b, FROM, TO); - double trunc = __enzyme_expand_value(__enzyme_truncate_func(simple_add, FROM, TO)(a, b), FROM, TO); + a = __enzyme_truncate_mem_value(a, FROM, TO); + b = __enzyme_truncate_mem_value(b, FROM, TO); + double trunc = __enzyme_expand_mem_value(__enzyme_truncate_mem_func(simple_add, FROM, TO)(a, b), FROM, TO); APPROX_EQ(trunc, truth, 1e-5); } { double a = 2; double b = 3; double truth = intrinsics(a, b); - a = __enzyme_truncate_value(a, FROM, TO); - b = __enzyme_truncate_value(b, FROM, TO); - double trunc = __enzyme_expand_value(__enzyme_truncate_func(intrinsics, FROM, TO)(a, b), FROM, TO); + a = __enzyme_truncate_mem_value(a, FROM, TO); + b = __enzyme_truncate_mem_value(b, FROM, TO); + double trunc = __enzyme_expand_mem_value(__enzyme_truncate_mem_func(intrinsics, FROM, TO)(a, b), FROM, TO); APPROX_EQ(trunc, truth, 1e-5); } + #endif // { // double a = 2; // double b = 3; // double truth = intrinsics(a, b); - // a = __enzyme_truncate_value(a, FROM, TO); - // b = __enzyme_truncate_value(b, FROM, TO); - // double trunc = __enzyme_expand_value(__enzyme_truncate_func(constt, FROM, TO)(a, b), FROM, TO); + // a = __enzyme_truncate_mem_value(a, FROM, TO); + // b = __enzyme_truncate_mem_value(b, FROM, TO); + // double trunc = __enzyme_expand_mem_value(__enzyme_truncate_mem_func(constt, FROM, TO)(a, b), FROM, TO); // APPROX_EQ(trunc, truth, 1e-5); // } - // double A[N]; - // double B[N]; - // double C[N]; - // double D[N]; + #ifdef TRUNC_OP + { + double A[N]; + double B[N]; + double C[N]; + double D[N]; - // for (int i = 0; i < N; i++) { - // A[i] = 1 + i % 5; - // B[i] = 1 + i % 3; - // } + for (int i = 0; i < N; i++) { + A[i] = 1 + i % 5; + B[i] = 1 + i % 3; + } - // compute(A, B, D, N); + compute(A, B, D, N); - // for (int i = 0; i < N; i++) { - // A[i] = __enzyme_truncate_value(A[i], 64, 32); - // B[i] = __enzyme_truncate_value(B[i], 64, 32); - // } + // for (int i = 0; i < N; i++) { + // A[i] = __enzyme_truncate_mem_value(A[i], 64, 32); + // B[i] = __enzyme_truncate_mem_value(B[i], 64, 32); + // } - // __enzyme_truncate_func_2(compute, 64, 32)(A, B, C, N); + __enzyme_truncate_op_func_2(compute, 64, 32)(A, B, C, N); - // for (int i = 0; i < N; i++) { - // C[i] = __enzyme_expand_value(C[i], 64, 32); - // } + // for (int i = 0; i < N; i++) { + // C[i] = __enzyme_expand_mem_value(C[i], 64, 32); + // } - // for (int i = 0; i < N; i++) { - // printf("%d\n", i); - // APPROX_EQ(D[i], C[i], 1e-5); - // } + for (int i = 0; i < N; i++) { + APPROX_EQ(D[i], C[i], 1e-5); + } + } + #endif } From 60fd0a1635ffde1a962d266a6d77e4fa50884754 Mon Sep 17 00:00:00 2001 From: William Moses Date: Fri, 9 Feb 2024 09:29:57 -0500 Subject: [PATCH 021/106] MDTuple better error handler (#1684) --- enzyme/Enzyme/GradientUtils.cpp | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/enzyme/Enzyme/GradientUtils.cpp b/enzyme/Enzyme/GradientUtils.cpp index 0b6947887f6f..ec15cd2e76ff 100644 --- a/enzyme/Enzyme/GradientUtils.cpp +++ b/enzyme/Enzyme/GradientUtils.cpp @@ -5435,10 +5435,21 @@ Value *GradientUtils::invertPointerM(Value *const oval, IRBuilder<> &BuilderM, if (!isa(md)) { llvm::errs() << *arg << "\n"; llvm::errs() << *md << "\n"; - assert(0 && "cannot compute with global variable that doesn't have " - "marked shadow global"); - report_fatal_error("cannot compute with global variable that doesn't " - "have marked shadow global (metadata incorrect type)"); + std::string s; + llvm::raw_string_ostream ss(s); + ss << "cannot compute with global variable that doesn't have marked " + "shadow global as mdtuple\n"; + ss << *arg << "\n"; + ss << " md: " << *md << "\n"; + if (CustomErrorHandler) { + return unwrap(CustomErrorHandler(ss.str().c_str(), wrap(arg), + ErrorType::NoShadow, this, nullptr, + wrap(&BuilderM))); + } else { + EmitFailure("InvertGlobal", BuilderM.getCurrentDebugLocation(), oldFunc, + ss.str()); + } + return UndefValue::get(getShadowType(arg->getType())); } auto md2 = cast(md); assert(md2->getNumOperands() == 1); From e3c62c80cd4932af73f3e4d876940a44c985da4c Mon Sep 17 00:00:00 2001 From: Tim Gymnich Date: Fri, 9 Feb 2024 21:34:11 +0100 Subject: [PATCH 022/106] Remove unnecessary uses of templating (#1681) * remove unnecessary uses of templating * fix build --- .devcontainer/devcontainer.json | 2 +- enzyme/Enzyme/AdjointGenerator.h | 16 +++++------ enzyme/Enzyme/CallDerivatives.cpp | 25 +++--------------- enzyme/Enzyme/EnzymeLogic.cpp | 38 +++++++++++++-------------- enzyme/Enzyme/TypeAnalysis/BaseType.h | 3 ++- 5 files changed, 33 insertions(+), 51 deletions(-) diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json index 368b7933e347..d201b73d532d 100644 --- a/.devcontainer/devcontainer.json +++ b/.devcontainer/devcontainer.json @@ -2,7 +2,7 @@ // available llvm versions: [9, 10, 11, 12, 13, 14, 15] { "name": "Enzyme", - "image": "ghcr.io/enzymead/enzyme-dev-docker/ubuntu-20-llvm-12:latest", + "image": "ghcr.io/enzymead/enzyme-dev-docker/ubuntu-22-llvm-16:latest", "mounts": [ "source=enzyme-bashhistory,target=/commandhistory,type=volume", "source=enzyme-extensions,target=/home/vscode/.vscode-server/extensions,type=volume", diff --git a/enzyme/Enzyme/AdjointGenerator.h b/enzyme/Enzyme/AdjointGenerator.h index debac4dd496f..de9bbf8e38a6 100644 --- a/enzyme/Enzyme/AdjointGenerator.h +++ b/enzyme/Enzyme/AdjointGenerator.h @@ -47,9 +47,7 @@ #define DEBUG_TYPE "enzyme" // Helper instruction visitor that generates adjoints -template -class AdjointGenerator - : public llvm::InstVisitor> { +class AdjointGenerator : public llvm::InstVisitor { private: // Type of code being generated (forward, reverse, or both) const DerivativeMode Mode; @@ -63,7 +61,7 @@ class AdjointGenerator const std::map> overwritten_args_map; const llvm::SmallPtrSetImpl *returnuses; - AugmentedReturnType augmentedReturn; + const AugmentedReturn *augmentedReturn; const std::map *replacedReturns; const llvm::SmallPtrSetImpl &unnecessaryValues; @@ -83,7 +81,7 @@ class AdjointGenerator const std::map> overwritten_args_map, const llvm::SmallPtrSetImpl *returnuses, - AugmentedReturnType augmentedReturn, + const AugmentedReturn *augmentedReturn, const std::map *replacedReturns, const llvm::SmallPtrSetImpl &unnecessaryValues, const llvm::SmallPtrSetImpl @@ -3040,8 +3038,8 @@ class AdjointGenerator } if (!vd.isKnownPastPointer()) { if (looseTypeAnalysis) { - if (auto CI = dyn_cast(MS.getOperand(0))) { #if LLVM_VERSION_MAJOR < 17 + if (auto CI = dyn_cast(MS.getOperand(0))) { if (auto PT = dyn_cast(CI->getSrcTy())) { auto ET = PT->getPointerElementType(); while (1) { @@ -3070,8 +3068,8 @@ class AdjointGenerator goto known; } } -#endif } +#endif if (auto gep = dyn_cast(MS.getOperand(0))) { if (auto AT = dyn_cast(gep->getSourceElementType())) { if (AT->getElementType()->isIntegerTy()) { @@ -3312,8 +3310,8 @@ class AdjointGenerator if (!vd.isKnownPastPointer()) { if (looseTypeAnalysis) { for (auto val : {orig_dst, orig_src}) { - if (auto CI = dyn_cast(val)) { #if LLVM_VERSION_MAJOR < 17 + if (auto CI = dyn_cast(val)) { if (auto PT = dyn_cast(CI->getSrcTy())) { auto ET = PT->getPointerElementType(); while (1) { @@ -3342,8 +3340,8 @@ class AdjointGenerator goto known; } } -#endif } +#endif if (auto gep = dyn_cast(val)) { if (auto AT = dyn_cast(gep->getSourceElementType())) { if (AT->getElementType()->isIntegerTy()) { diff --git a/enzyme/Enzyme/CallDerivatives.cpp b/enzyme/Enzyme/CallDerivatives.cpp index 02719a9a4660..c8b2cdaa4e4b 100644 --- a/enzyme/Enzyme/CallDerivatives.cpp +++ b/enzyme/Enzyme/CallDerivatives.cpp @@ -32,10 +32,8 @@ extern "C" { void (*EnzymeShadowAllocRewrite)(LLVMValueRef, void *) = nullptr; } -template -void AdjointGenerator::handleMPI(llvm::CallInst &call, - llvm::Function *called, - llvm::StringRef funcName) { +void AdjointGenerator::handleMPI(llvm::CallInst &call, llvm::Function *called, + llvm::StringRef funcName) { using namespace llvm; assert(called); @@ -2214,8 +2212,7 @@ void AdjointGenerator::handleMPI(llvm::CallInst &call, llvm_unreachable("Unhandled MPI FUNCTION"); } -template -bool AdjointGenerator::handleKnownCallDerivatives( +bool AdjointGenerator::handleKnownCallDerivatives( CallInst &call, Function *called, StringRef funcName, const std::vector &overwritten_args, CallInst *const newCall) { bool subretused = false; @@ -3231,7 +3228,7 @@ bool AdjointGenerator::handleKnownCallDerivatives( llvm_unreachable("Unknown allocation to upgrade"); Size = gutils->getNewFromOriginal(Size); - if (auto CI = dyn_cast(Size)) { + if (isa(Size)) { B.SetInsertPoint(gutils->inversionAllocs); } Type *elTy = Type::getInt8Ty(call.getContext()); @@ -4155,17 +4152,3 @@ bool AdjointGenerator::handleKnownCallDerivatives( return false; } - -template bool AdjointGenerator::handleKnownCallDerivatives( - CallInst &call, Function *called, StringRef funcName, - const std::vector &overwritten_args, CallInst *const newCall); -template bool -AdjointGenerator::handleKnownCallDerivatives( - CallInst &call, Function *called, StringRef funcName, - const std::vector &overwritten_args, CallInst *const newCall); - -template void -AdjointGenerator::handleMPI(CallInst &call, Function *called, - StringRef funcName); -template void AdjointGenerator::handleMPI( - CallInst &call, Function *called, StringRef funcName); diff --git a/enzyme/Enzyme/EnzymeLogic.cpp b/enzyme/Enzyme/EnzymeLogic.cpp index 468c4e5ccba2..4d6b0fef39d9 100644 --- a/enzyme/Enzyme/EnzymeLogic.cpp +++ b/enzyme/Enzyme/EnzymeLogic.cpp @@ -2416,12 +2416,12 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal( } } - AdjointGenerator maker( - DerivativeMode::ReverseModePrimal, gutils, constant_args, retType, - getIndex, overwritten_args_map, &returnuses, - &AugmentedCachedFunctions.find(tup)->second, nullptr, unnecessaryValues, - unnecessaryInstructions, unnecessaryStores, guaranteedUnreachable, - nullptr); + AdjointGenerator maker(DerivativeMode::ReverseModePrimal, gutils, + constant_args, retType, getIndex, overwritten_args_map, + &returnuses, + &AugmentedCachedFunctions.find(tup)->second, nullptr, + unnecessaryValues, unnecessaryInstructions, + unnecessaryStores, guaranteedUnreachable, nullptr); for (BasicBlock &oBB : *gutils->oldFunc) { auto term = oBB.getTerminator(); @@ -4174,12 +4174,12 @@ Function *EnzymeLogic::CreatePrimalAndGradient( } } - AdjointGenerator maker( - key.mode, gutils, key.constant_args, key.retType, getIndex, - overwritten_args_map, - /*returnuses*/ nullptr, augmenteddata, &replacedReturns, - unnecessaryValues, unnecessaryInstructions, unnecessaryStores, - guaranteedUnreachable, dretAlloca); + AdjointGenerator maker(key.mode, gutils, key.constant_args, key.retType, + getIndex, overwritten_args_map, + /*returnuses*/ nullptr, augmenteddata, + &replacedReturns, unnecessaryValues, + unnecessaryInstructions, unnecessaryStores, + guaranteedUnreachable, dretAlloca); for (BasicBlock &oBB : *gutils->oldFunc) { // Don't create derivatives for code that results in termination @@ -4643,7 +4643,7 @@ Function *EnzymeLogic::CreateForwardDiff( SmallPtrSet unnecessaryInstructions; SmallPtrSet unnecessaryStores; - AdjointGenerator *maker; + AdjointGenerator *maker; std::unique_ptr> can_modref_map; if (mode == DerivativeMode::ForwardModeSplit) { @@ -4683,7 +4683,7 @@ Function *EnzymeLogic::CreateForwardDiff( calculateUnusedStoresInFunction(*gutils->oldFunc, unnecessaryStores, unnecessaryInstructions, gutils, TLI); - maker = new AdjointGenerator( + maker = new AdjointGenerator( mode, gutils, constant_args, retType, getIndex, overwritten_args_map, /*returnuses*/ nullptr, augmenteddata, nullptr, unnecessaryValues, unnecessaryInstructions, unnecessaryStores, guaranteedUnreachable, @@ -4736,11 +4736,11 @@ Function *EnzymeLogic::CreateForwardDiff( calculateUnusedStoresInFunction(*gutils->oldFunc, unnecessaryStores, unnecessaryInstructions, gutils, TLI); - maker = new AdjointGenerator( - mode, gutils, constant_args, retType, nullptr, {}, - /*returnuses*/ nullptr, nullptr, nullptr, unnecessaryValues, - unnecessaryInstructions, unnecessaryStores, guaranteedUnreachable, - nullptr); + maker = + new AdjointGenerator(mode, gutils, constant_args, retType, nullptr, {}, + /*returnuses*/ nullptr, nullptr, nullptr, + unnecessaryValues, unnecessaryInstructions, + unnecessaryStores, guaranteedUnreachable, nullptr); } for (BasicBlock &oBB : *gutils->oldFunc) { diff --git a/enzyme/Enzyme/TypeAnalysis/BaseType.h b/enzyme/Enzyme/TypeAnalysis/BaseType.h index f73948a703ad..2f5275c4b6f8 100644 --- a/enzyme/Enzyme/TypeAnalysis/BaseType.h +++ b/enzyme/Enzyme/TypeAnalysis/BaseType.h @@ -25,6 +25,7 @@ #ifndef ENZYME_TYPE_ANALYSIS_BASE_TYPE_H #define ENZYME_TYPE_ANALYSIS_BASE_TYPE_H 1 +#include "llvm/ADT/StringRef.h" #include /// Categories of potential types @@ -59,7 +60,7 @@ static inline std::string to_string(BaseType t) { } /// Convert string to BaseType -template static inline BaseType parseBaseType(T str) { +static inline BaseType parseBaseType(llvm::StringRef str) { if (str == "Integer") return BaseType::Integer; if (str == "Float") From abb4ee5fbf292d2e31fac2bc97bbf77aae3c0ba3 Mon Sep 17 00:00:00 2001 From: William Moses Date: Fri, 9 Feb 2024 20:04:37 -0500 Subject: [PATCH 023/106] [MLIR] fix mlir activity arg parsing (#1678) --- enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp | 5 ++++- enzyme/test/MLIR/ForwardMode/wrap.mlir | 2 +- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp b/enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp index 235a061c6088..d96c18a68c09 100644 --- a/enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp +++ b/enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp @@ -64,8 +64,11 @@ struct DifferentiateWrapperPass constants.push_back(DIFFE_TYPE::DUP_NONEED); else if (str == "enzyme_out") constants.push_back(DIFFE_TYPE::OUT_DIFF); - else + else { + llvm::errs() << "unknown argument activity to parse, found: '" << str + << "'\n"; assert(0 && " unknown constant"); + } } DIFFE_TYPE retType = retTy.getValue(); diff --git a/enzyme/test/MLIR/ForwardMode/wrap.mlir b/enzyme/test/MLIR/ForwardMode/wrap.mlir index 7cb0eb82bf4b..7cef99c691f8 100644 --- a/enzyme/test/MLIR/ForwardMode/wrap.mlir +++ b/enzyme/test/MLIR/ForwardMode/wrap.mlir @@ -1,4 +1,4 @@ -// RUN: %eopt --enzyme-wrap="infn=square outfn=dsq retTy=enzyme_dup argTys=enzyme_dup, mode=ForwardMode" %s | FileCheck %s +// RUN: %eopt --enzyme-wrap="infn=square outfn=dsq retTy=enzyme_dup argTys=enzyme_dup mode=ForwardMode" %s | FileCheck %s module { func.func @square(%x : f64) -> f64{ From 49cc903a7c4d4f3a781a9307db4c63b68b3ddd97 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Fri, 9 Feb 2024 20:30:29 -0500 Subject: [PATCH 024/106] add a little bit more docs (#1664) * add a little bit more docs * update * adressing feedback * fmt * fmt --- enzyme/Enzyme/CApi.h | 11 +++++++---- enzyme/Enzyme/EnzymeLogic.h | 34 +++++++++++++++++++--------------- enzyme/Enzyme/Utils.h | 11 +++++++---- 3 files changed, 33 insertions(+), 23 deletions(-) diff --git a/enzyme/Enzyme/CApi.h b/enzyme/Enzyme/CApi.h index d8e04724631a..fe3760457926 100644 --- a/enzyme/Enzyme/CApi.h +++ b/enzyme/Enzyme/CApi.h @@ -113,11 +113,14 @@ struct CFnTypeInfo { }; typedef enum { - DFT_OUT_DIFF = 0, // add differential to an output struct - DFT_DUP_ARG = 1, // duplicate the argument and store differential inside - DFT_CONSTANT = 2, // no differential + DFT_OUT_DIFF = 0, // add differential to an output struct. Only for scalar + // values in ReverseMode variants. + DFT_DUP_ARG = 1, // duplicate the argument and store differential inside. + // For references, pointers, or integers in ReverseMode + // variants. For all types in ForwardMode variants. + DFT_CONSTANT = 2, // no differential. Usable everywhere. DFT_DUP_NONEED = 3 // duplicate this argument and store differential inside, - // but don't need the forward + // but don't need the forward. Same as DUP_ARG otherwise. } CDIFFE_TYPE; typedef enum { BT_SCALAR = 0, BT_VECTOR = 1 } CBATCH_TYPE; diff --git a/enzyme/Enzyme/EnzymeLogic.h b/enzyme/Enzyme/EnzymeLogic.h index 6f311fcc85da..3923f205037e 100644 --- a/enzyme/Enzyme/EnzymeLogic.h +++ b/enzyme/Enzyme/EnzymeLogic.h @@ -132,6 +132,21 @@ class AugmentedReturn { isComplete(false) {} }; +/// \p todiff is the function to differentiate +/// \p retType is the activity info of the return. +/// Only allowed to be DUP_ARG or CONSTANT. DUP_NONEED is not allowed, +/// set returnValue to false instead. +/// \p constant_args is the activity info of the arguments +/// \p returnValue is whether the primal's return should also be returned. +/// \p dretUsed is whether the shadow return value should also be returned. +/// Only allowed to be true if retType is CDIFFE_TYPE::DUP_ARG. +/// \p additionalArg is the type (or null) of an additional type in the +/// signature to hold the tape. +/// \p typeInfo is the type info information about the calling context +/// \p _overwritten_args marks whether an argument may be overwritten +/// before loads in the generated function (and thus cannot be cached). +/// \p AtomicAdd is whether to perform all adjoint +/// updates to memory in an atomic way struct ReverseCacheKey { llvm::Function *todiff; DIFFE_TYPE retType; @@ -427,9 +442,10 @@ class EnzymeLogic { /// \p returnUsed is whether the primal's return should also be returned /// \p typeInfo is the type info information about the calling context /// \p _overwritten_args marks whether an argument may be rewritten before - /// loads in the generated function (and thus cannot be cached). \p - /// forceAnonymousTape forces the tape to be an i8* rather than the true tape - /// structure \p AtomicAdd is whether to perform all adjoint updates to + /// loads in the generated function (and thus cannot be cached). + /// \p forceAnonymousTape forces the tape to be an i8* rather than the true + /// tape structure + /// \p AtomicAdd is whether to perform all adjoint updates to /// memory in an atomic way const AugmentedReturn &CreateAugmentedPrimal( RequestContext context, llvm::Function *todiff, DIFFE_TYPE retType, @@ -521,20 +537,8 @@ class EnzymeLogic { /// Create the reverse pass, or combined forward+reverse derivative function. /// \p context the instruction which requested this derivative (or null). - /// \p todiff is the function to differentiate - /// \p retType is the activity info of the return - /// \p constant_args is the activity info of the arguments - /// \p returnValue is whether the primal's return should also be returned - /// \p dretUsed is whether the shadow return value should also be returned - /// \p additionalArg is the type (or null) of an additional type in the - /// signature to hold the tape. - /// \p typeInfo is the type info information about the calling context - /// \p _overwritten_args marks whether an argument may be rewritten - /// before loads in the generated function (and thus cannot be cached). /// \p augmented is the data structure created by prior call to an /// augmented forward pass - /// \p AtomicAdd is whether to perform all adjoint - /// updates to memory in an atomic way llvm::Function *CreatePrimalAndGradient(RequestContext context, const ReverseCacheKey &&key, TypeAnalysis &TA, diff --git a/enzyme/Enzyme/Utils.h b/enzyme/Enzyme/Utils.h index 7b46fd9a2e83..f4f72f8cf52e 100644 --- a/enzyme/Enzyme/Utils.h +++ b/enzyme/Enzyme/Utils.h @@ -318,11 +318,14 @@ enum class ReturnType { /// Potential differentiable argument classifications enum class DIFFE_TYPE { - OUT_DIFF = 0, // add differential to an output struct - DUP_ARG = 1, // duplicate the argument and store differential inside - CONSTANT = 2, // no differential + OUT_DIFF = 0, // add differential to an output struct. Only for scalar values + // in ReverseMode variants. + DUP_ARG = 1, // duplicate the argument and store differential inside. + // For references, pointers, or integers in ReverseMode variants. + // For all types in ForwardMode variants. + CONSTANT = 2, // no differential. Usable everywhere. DUP_NONEED = 3 // duplicate this argument and store differential inside, but - // don't need the forward + // don't need the forward. Same as DUP_ARG otherwise. }; enum class BATCH_TYPE { From b5b7d10a6852e9535be768b70d25349339e4cf4f Mon Sep 17 00:00:00 2001 From: William Moses Date: Fri, 9 Feb 2024 20:49:45 -0500 Subject: [PATCH 025/106] Update build_tarballs.jl (#1688) --- .packaging/build_tarballs.jl | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/.packaging/build_tarballs.jl b/.packaging/build_tarballs.jl index 009b48e756e4..0b41f026ba38 100644 --- a/.packaging/build_tarballs.jl +++ b/.packaging/build_tarballs.jl @@ -28,7 +28,7 @@ platforms = expand_cxxstring_abis(supported_platforms(; experimental=true)) script = raw""" cd Enzyme -if [[ "${bb_full_target}" == x86_64-apple-darwin*llvm_version+15.asserts* ]] || [[ "${bb_full_target}" == x86_64-apple-darwin*llvm_version+16.asserts* ]] || [[ "${bb_full_target}" == x86_64-apple-darwin*llvm_version+17.asserts* ]]; then +if [[ "${bb_full_target}" == x86_64-apple-darwin*llvm_version+15* ]] || [[ "${bb_full_target}" == x86_64-apple-darwin*llvm_version+16* ]]; then # LLVM 15 requires macOS SDK 10.14. pushd $WORKSPACE/srcdir/MacOSX10.*.sdk rm -rf /opt/${target}/${target}/sys-root/System @@ -54,11 +54,11 @@ cmake -B build-native -S enzyme -GNinja "${NATIVE_CMAKE_FLAGS[@]}" # Only build blasheaders and tblgen ninja -C build-native -j ${nproc} blasheaders enzyme-tblgen - # 2. Cross-compile CMAKE_FLAGS=() CMAKE_FLAGS+=(-DENZYME_EXTERNAL_SHARED_LIB=ON) CMAKE_FLAGS+=(-DBC_LOAD_HEADER=`pwd`/build-native/BCLoad/gsl/blas_headers.h) +CMAKE_FLAGS+=(-DEnzyme_TABLEGEN=`pwd`/build-native/tools/enzyme-tblgen/enzyme-tblgen) CMAKE_FLAGS+=(-DEnzyme_TABLEGEN_EXE=`pwd`/build-native/tools/enzyme-tblgen/enzyme-tblgen) CMAKE_FLAGS+=(-DENZYME_CLANG=OFF) # RelWithDebInfo for decent performance, with debugability @@ -66,7 +66,11 @@ CMAKE_FLAGS+=(-DCMAKE_BUILD_TYPE=RelWithDebInfo) # Install things into $prefix CMAKE_FLAGS+=(-DCMAKE_INSTALL_PREFIX=${prefix}) # Explicitly use our cmake toolchain file and tell CMake we're cross-compiling -CMAKE_FLAGS+=(-DCMAKE_TOOLCHAIN_FILE=${CMAKE_TARGET_TOOLCHAIN}) +if [[ "${target}" == *mingw* && "${bb_full_target}" == *llvm_version+16* ]]; then + CMAKE_FLAGS+=(-DCMAKE_TOOLCHAIN_FILE=${CMAKE_TARGET_TOOLCHAIN%.*}_clang.cmake) +else + CMAKE_FLAGS+=(-DCMAKE_TOOLCHAIN_FILE=${CMAKE_TARGET_TOOLCHAIN}) +fi CMAKE_FLAGS+=(-DCMAKE_CROSSCOMPILING:BOOL=ON) # Tell CMake where LLVM is CMAKE_FLAGS+=(-DLLVM_DIR="${prefix}/lib/cmake/llvm") @@ -74,10 +78,18 @@ CMAKE_FLAGS+=(-DLLVM_DIR="${prefix}/lib/cmake/llvm") CMAKE_FLAGS+=(-DLLVM_LINK_LLVM_DYLIB=ON) # Build the library CMAKE_FLAGS+=(-DBUILD_SHARED_LIBS=ON) + +if [[ "${bb_full_target}" == x86_64-apple-darwin*llvm_version+15* ]] || [[ "${bb_full_target}" == x86_64-apple-darwin*llvm_version+16* ]]; then +if [[ "${target}" == x86_64-apple* ]]; then + CMAKE_FLAGS+=(-DCMAKE_OSX_DEPLOYMENT_TARGET:STRING=10.14) +fi +else if [[ "${target}" == x86_64-apple* ]]; then CMAKE_FLAGS+=(-DCMAKE_OSX_DEPLOYMENT_TARGET:STRING=10.12) fi +fi +echo ${CMAKE_FLAGS[@]} cmake -B build -S enzyme -GNinja ${CMAKE_FLAGS[@]} ninja -C build -j ${nproc} install @@ -113,7 +125,7 @@ for llvm_version in llvm_versions, llvm_assertions in (false, true) # We don't build LLVM 15 for i686-linux-musl. filter!(p -> !(arch(p) == "i686" && libc(p) == "musl"), platforms) end - + for platform in platforms augmented_platform = deepcopy(platform) augmented_platform[LLVM.platform_name] = LLVM.platform(llvm_version, llvm_assertions) From 56bb0a62f49ee47e5ceb8bdd4785dd392201c91c Mon Sep 17 00:00:00 2001 From: William Moses Date: Fri, 9 Feb 2024 22:35:27 -0500 Subject: [PATCH 026/106] Demangle and improve nofree (#1687) * Demangle and improve nofree * f --- enzyme/Enzyme/ActivityAnalysis.cpp | 2 + enzyme/Enzyme/EnzymeLogic.cpp | 170 +++++++++++++++++------------ 2 files changed, 102 insertions(+), 70 deletions(-) diff --git a/enzyme/Enzyme/ActivityAnalysis.cpp b/enzyme/Enzyme/ActivityAnalysis.cpp index 7b251e154f94..7921d3ff65c0 100644 --- a/enzyme/Enzyme/ActivityAnalysis.cpp +++ b/enzyme/Enzyme/ActivityAnalysis.cpp @@ -395,6 +395,8 @@ static const char *DemangledKnownInactiveFunctionsStartingWith[] = { // libc++ + "std::__1::locale", + "std::__1::ios_base", "std::__1::basic_string", "std::__1::__do_string_hash", "std::__1::hash", diff --git a/enzyme/Enzyme/EnzymeLogic.cpp b/enzyme/Enzyme/EnzymeLogic.cpp index 4d6b0fef39d9..d60ac934634b 100644 --- a/enzyme/Enzyme/EnzymeLogic.cpp +++ b/enzyme/Enzyme/EnzymeLogic.cpp @@ -60,6 +60,8 @@ #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/ValueTracking.h" +#include "llvm/Demangle/Demangle.h" + #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Cloning.h" @@ -5888,73 +5890,101 @@ llvm::Function *EnzymeLogic::CreateNoFree(RequestContext context, Function *F) { if (isAllocationFunction(F->getName(), TLI)) return F; + // clang-format off + StringSet<> NoFreeDemangles = { + "std::basic_ostream>& std::__ostream_insert >(std::basic_ostream >&)", + "std::basic_ostream>::put(char)", + + "std::basic_filebuf>::open(char const*, std::_Ios_Openmode)", + "std::basic_filebuf>::basic_filebuf()", + "std::basic_filebuf>::close()", + + "std::basic_ios>::clear(std::_Ios_Iostate)", + "std::__detail::_Prime_rehash_policy::_M_need_rehash(unsigned long, unsigned long, unsigned long) const", + + "std::basic_streambuf >::xsputn(char const*, long)", + + "std::basic_ios >::init(std::basic_streambuf >*)", + + "std::_Hash_bytes(void const*, unsigned long, unsigned long)", + "unsigned long std::__1::__do_string_hash(char const*, char const*)", + "std::__1::hash::operator()(char const*) const", + + "std::allocator::allocator()", + "std::allocator::~allocator()", + + + "std::__cxx11::basic_string, std::allocator>::basic_string(char const*, std::allocator const&)", + "std::__cxx11::basic_string, std::allocator>::basic_string(std::__cxx11::basic_string, std::allocator>&&)", + "std::__cxx11::basic_string, std::allocator>::_M_construct(unsigned long, char)", + "std::__cxx11::basic_string, std::allocator>::_M_append(char const*, unsigned long)", + "std::__cxx11::basic_string, std::allocator>::_M_assign(std::__cxx11::basic_string, std::allocator> const&)", + "std::__cxx11::basic_string, std::allocator>::_M_replace(unsigned long, unsigned long, char const*, unsigned long)", + "std::__cxx11::basic_string, std::allocator>::_M_replace_aux(unsigned long, unsigned long, unsigned long, char)", + "std::__cxx11::basic_string, std::allocator>::length() const", + "std::__cxx11::basic_string, std::allocator>::data() const", + "std::__cxx11::basic_string, std::allocator>::size() const", + "std::__cxx11::basic_string, std::allocator>::~basic_string()", + "std::__cxx11::basic_string, std::allocator>::compare(char const*) const", + "std::__cxx11::basic_string, std::allocator>::compare(std::__cxx11::basic_string, std::allocator> const&) const", + "std::__cxx11::basic_string, std::allocator>::reserve(unsigned long)", + + "std::__cxx11::basic_string, std::allocator>::~basic_string()", + "std::__cxx11::basic_stringbuf, std::allocator>::overflow(int)", + "std::__cxx11::basic_stringbuf, std::allocator>::pbackfail(int)", + "std::__cxx11::basic_stringbuf, std::allocator>::underflow()", + "std::__cxx11::basic_stringbuf, std::allocator>::_M_sync(char*, unsigned long, unsigned long)", + + "std::__basic_file::~__basic_file()", + + "std::basic_ostream>::flush()", + "std::basic_streambuf>::xsgetn(char*, long)", + + "std::locale::~locale()", + "std::ios_base::ios_base()", + "std::basic_ostream>& " + "std::basic_ostream " + ">::_M_insert(double)", + + // libc++ + "std::__1::basic_string, std::__1::allocator>::basic_string(std::__1::basic_string, std::__1::allocator> const&)", + "std::__1::basic_string, std::__1::allocator>::~basic_string()", + "std::__1::basic_string, std::__1::allocator>::__init(char const*, unsigned long)", + "std::__1::basic_string, std::__1::allocator>::append(char const*, unsigned long)", + "std::__1::basic_string, std::__1::allocator>::data() const", + "std::__1::basic_ostream>::sentry::sentry(std::__1::basic_ostream>&)", + "std::__1::basic_ostream>::sentry::~sentry()", + "std::__1::ios_base::__set_badbit_and_consider_rethrow()", + "char* std::__1::addressof(char&)", + "char const* std::__1::addressof(char const&)", + "std::__1::random_device::operator()()", + + "std::__1::locale::~locale()", + "std::__1::locale::use_facet(std::__1::locale::id&) const", + "std::__1::ios_base::ios_base()", + "std::__1::ios_base::getloc() const", + "std::__1::ios_base::clear(unsigned int)", + }; + // clang-format on + StringSet<> NoFrees = { - "mpfr_greater_p", - "memchr", - "_ZNSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEEC1EPKcRKS3_", - "_ZSt16__ostream_insertIcSt11char_traitsIcEERSt13basic_ostreamIT_T0_ES6_" - "PKS3_l", - "_ZNSo3putEc", - "_ZNSt7__cxx1115basic_stringbufIcSt11char_traitsIcESaIcEE7_M_syncEPcmm", - "_ZNSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEE10_M_" - "replaceEmmPKcm", - "_ZNSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEE9_M_appendEPKcm", - "_ZNSt13basic_filebufIcSt11char_traitsIcEE4openEPKcSt13_Ios_Openmode", - "_ZNSt9basic_iosIcSt11char_traitsIcEE5clearESt12_Ios_Iostate", - "_ZNSt13basic_filebufIcSt11char_traitsIcEE5closeEv", - "_ZNSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEE14_M_replace_" - "auxEmmmc", - "_ZNSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEE12_M_constructEmc", - "_ZNSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEE7reserveEm", - "time", - "strlen", - "_ZNKSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEE7compareERKS4_", - "_ZNKSt8__detail20_Prime_rehash_policy14_M_need_rehashEmmm", - "_ZNSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEEC1EOS4_", - "_ZNKSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEE6lengthEv", - "_ZNKSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEE4dataEv", - "_ZNKSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEE4sizeEv", - "_ZNKSt3__112basic_stringIcNS_11char_traitsIcEENS_9allocatorIcEEE4dataEv" - "_ZNSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEED1Ev", - "_ZNSt3__112basic_stringIcNS_11char_traitsIcEENS_9allocatorIcEEED1Ev", - "_ZNSt3__112basic_stringIcNS_11char_traitsIcEENS_9allocatorIcEEE6__" - "initEPKcm", - "_ZNSt3__112basic_stringIcNS_11char_traitsIcEENS_9allocatorIcEEEC1ERKS5_", - "_ZNSt3__112basic_stringIcNS_11char_traitsIcEENS_" - "9allocatorIcEEE6appendEPKcm", - "_ZNSt12__basic_fileIcED1Ev", - "__cxa_begin_catch", - "__cxa_end_catch", - "_ZNSo5flushEv", - "compress2", - "_ZNSt6localeD1Ev", - "_ZNSt8ios_baseC2Ev", - "_ZNSo9_M_insertIdEERSoT_", - "malloc_usable_size", - "_ZNSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEED1Ev", - "_ZNKSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEE7compareEPKc", - "_ZNSt13basic_filebufIcSt11char_traitsIcEEC1Ev", - "_ZNSt15basic_streambufIcSt11char_traitsIcEE6xsputnEPKcl", - "_ZNSt9basic_iosIcSt11char_traitsIcEE4initEPSt15basic_streambufIcS1_E", - "_ZNSt7__cxx1115basic_stringbufIcSt11char_traitsIcESaIcEE8overflowEi", - "_ZNSt7__cxx1115basic_stringbufIcSt11char_traitsIcESaIcEE9pbackfailEi", - "_ZNSt15basic_streambufIcSt11char_traitsIcEE6xsgetnEPcl", - "_ZNSt7__cxx1115basic_stringbufIcSt11char_traitsIcESaIcEE9underflowEv", - "_ZNSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEE9_M_assignERKS4_", - "_ZNSaIcED1Ev", - "_ZNSaIcEC1Ev", - "_ZSt11_Hash_bytesPKvmm", - "_ZNSt3__116__do_string_hashIPKcEEmT_S3_", - "_ZNKSt3__14hashIPKcEclES2_", - "_ZNSt3__19addressofIcEEPT_RS1_", - "_ZNSt3__19addressofIKcEEPT_RS2_", - "_ZNSt3__113random_deviceclEv", + "mpfr_greater_p", "memchr", "time", "strlen", + "__cxa_begin_catch", "__cxa_end_catch", "compress2", "malloc_usable_size", "MPI_Allreduce", }; if (startsWith(F->getName(), "_ZNSolsE") || NoFrees.count(F->getName())) return F; + std::string demangledName = llvm::demangle(F->getName().str()); + // replace all '> >' with '>>' + size_t start = 0; + while ((start = demangledName.find("> >", start)) != std::string::npos) { + demangledName.replace(start, 3, ">>"); + } + if (NoFreeDemangles.count(demangledName)) + return F; + switch (F->getIntrinsicID()) { case Intrinsic::lifetime_start: case Intrinsic::lifetime_end: @@ -5972,23 +6002,23 @@ llvm::Function *EnzymeLogic::CreateNoFree(RequestContext context, Function *F) { if (EnzymeEmptyFnInactive) { return F; } + std::string s; + llvm::raw_string_ostream ss(s); + ss << "No create nofree of empty function (" << demangledName << ") " + << F->getName() << ")\n"; + if (context.req) { + ss << " at context: " << *context.req; + } else { + ss << *F << "\n"; + } if (CustomErrorHandler) { - std::string s; - llvm::raw_string_ostream ss(s); - ss << "No create nofree of empty function " << F->getName() << "\n"; - if (context.req) { - ss << " at context: " << *context.req; - } else { - ss << *F << "\n"; - } CustomErrorHandler(ss.str().c_str(), wrap(context.req), ErrorType::NoDerivative, nullptr, wrap(F), wrap(context.ip)); return F; } if (context.req) { - EmitFailure("IllegalNoFree", context.req->getDebugLoc(), context.req, - "Cannot create nofree of empty function: ", *F); + EmitFailure("IllegalNoFree", context.req->getDebugLoc(), context.req, s); return F; } llvm::errs() << " unhandled, create no free of empty function: " << *F From f1f269a567b09db936b3eac8fe885c7dee7e3003 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sat, 10 Feb 2024 10:42:03 -0500 Subject: [PATCH 027/106] Fix cached gemv (#1689) * Fix cached gemv * fix --- enzyme/Enzyme/BlasDerivatives.td | 6 +++--- .../Enzyme/ReverseMode/blas/gemm_f_c_lacpy_runtime_act.ll | 4 ++-- enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_split_lacpy.ll | 2 +- .../ReverseMode/blas/gemm_f_c_split_transpose_lacpy.ll | 2 +- .../Enzyme/ReverseMode/blas/gemm_f_c_transpose_lacpy.ll | 4 ++-- enzyme/test/Enzyme/ReverseMode/blas/gemm_f_change_ld.ll | 2 +- enzyme/test/Integration/ReverseMode/blas.cpp | 6 +++--- enzyme/test/Integration/blasinfra.h | 2 +- enzyme/tools/enzyme-tblgen/caching.cpp | 2 +- 9 files changed, 15 insertions(+), 15 deletions(-) diff --git a/enzyme/Enzyme/BlasDerivatives.td b/enzyme/Enzyme/BlasDerivatives.td index 70682e540ef5..01894e87e453 100644 --- a/enzyme/Enzyme/BlasDerivatives.td +++ b/enzyme/Enzyme/BlasDerivatives.td @@ -192,7 +192,7 @@ def gemv : CallBlasPattern<(Op $layout, $transa, $m, $n, $alpha, $A, $lda, $x, $ ["y"], [cblas_layout, trans, len, len, fp, mld<["m", "n"]>, vinc<["transa", "n", "m"]>, fp, vinc<["transa", "m", "n"]>], [ /* alpha */ (Seq<["Ax", "is_normal", "transa", "m", "n"]> - (b<"gemv"> $layout, $transa, $m, $n, Constant<"1.0">, $A, (ld $A, Char<"N">, $lda, $m, $n), $x, Constant<"0.0">, use<"Ax">, ConstantInt<1>), + (b<"gemv"> $layout, $transa, $m, $n, Constant<"1.0">, $A, (ld $A, Char<"N">, $lda, $m, $m), $x, Constant<"0.0">, use<"Ax">, ConstantInt<1>), (b<"dot"> (Rows $transa, $m, $n), adj<"y">, use<"Ax">, ConstantInt<1>)), //if (is_normal $transa) { @@ -201,7 +201,7 @@ def gemv : CallBlasPattern<(Op $layout, $transa, $m, $n, $alpha, $A, $lda, $x, $ // call sger(m, n, alpha, x, incx, ya, incy, Aa, lda) //} /* A */ (b<"ger"> $layout, $m, $n, $alpha, (Rows $transa, (Concat adj<"y">, $x), (Concat $x, adj<"y">)), adj<"A">), - /* x */ (b<"gemv"> $layout, transpose<"transa">, $m, $n, $alpha, $A, (ld $A, Char<"N">, $lda, $m, $n), adj<"y">, Constant<"1.0">, adj<"x">), + /* x */ (b<"gemv"> $layout, transpose<"transa">, $m, $n, $alpha, $A, (ld $A, Char<"N">, $lda, $m, $m), adj<"y">, Constant<"1.0">, adj<"x">), /* beta */ (b<"dot"> (Rows $transa, $m, $n), adj<"y">, input<"y">), /* y */ (b<"scal"> (Rows $transa, $m, $n), $beta, adj<"y">) ] @@ -218,7 +218,7 @@ def ger : CallBlasPattern<(Op $layout, $m, $n, $alpha, $x, $incx, $y, $incy, $A, >; //(ld $A, $transa, $lda, $m, $k) // if (cache_A) { -// ld_A = (arg_transa == 'N') ? arg_m : arg_k; +// ld_A = (arg_transa == 'N') ? arg_k : arg_m; // } else { // ld_A = arg_lda; // } diff --git a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_lacpy_runtime_act.ll b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_lacpy_runtime_act.ll index 0918ca43f83f..8f1c4f3e566e 100644 --- a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_lacpy_runtime_act.ll +++ b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_lacpy_runtime_act.ll @@ -109,7 +109,7 @@ entry: ; CHECK-NEXT: %malloccall = tail call noalias nonnull i8* @malloc(i64 %mallocsize) ; CHECK-NEXT: %cache.A = bitcast i8* %malloccall to double* ; CHECK-NEXT: store i8 0, i8* %[[byrefgarbage]] -; CHECK-NEXT: call void @dlacpy_64_(i8* %byref.copy.garbage, i8* %[[i3]], i8* %[[i4]], i8* %A, i8* %lda_p, double* %cache.A, i8* %[[i4]]) +; CHECK-NEXT: call void @dlacpy_64_(i8* %byref.copy.garbage, i8* %[[i3]], i8* %[[i4]], i8* %A, i8* %lda_p, double* %cache.A, i8* %[[i3]]) ; CHECK-NEXT: %[[i10:.+]] = bitcast i8* %m_p to i64* ; CHECK-NEXT: %[[i11:.+]] = load i64, i64* %[[i10]] ; CHECK-NEXT: %[[i12:.+]] = bitcast i8* %n_p to i64* @@ -119,7 +119,7 @@ entry: ; CHECK-NEXT: %malloccall2 = tail call noalias nonnull i8* @malloc(i64 %mallocsize1) ; CHECK-NEXT: %cache.C = bitcast i8* %malloccall2 to double* ; CHECK-NEXT: store i8 0, i8* %byref.copy.garbage3 -; CHECK-NEXT: call void @dlacpy_64_(i8* %byref.copy.garbage3, i8* %m_p, i8* %n_p, i8* %C, i8* %ldc_p, double* %cache.C, i8* %n_p) +; CHECK-NEXT: call void @dlacpy_64_(i8* %byref.copy.garbage3, i8* %m_p, i8* %n_p, i8* %C, i8* %ldc_p, double* %cache.C, i8* %m_p) ; CHECK-NEXT: call void @dgemm_64_(i8* %transa, i8* %transb, i8* %m_p, i8* %n_p, i8* %k_p, i8* %alpha, i8* %A, i8* %lda_p, i8* %B, i8* %ldb_p, i8* %beta, i8* %C, i8* %ldc_p, i64 1, i64 1) ; CHECK-NEXT: %"ptr'ipc" = bitcast i8* %"A'" to double* ; CHECK-NEXT: %ptr = bitcast i8* %A to double* diff --git a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_split_lacpy.ll b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_split_lacpy.ll index f58a4b2e209a..64d862b240ff 100644 --- a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_split_lacpy.ll +++ b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_split_lacpy.ll @@ -117,7 +117,7 @@ entry: ; CHECK-NEXT: %cache.A = bitcast i8* %malloccall10 to double* ; CHECK-NEXT: store double* %cache.A, double** %0 ; CHECK-NEXT: store i8 0, i8* %byref.copy.garbage -; CHECK-NEXT: call void @dlacpy_64_(i8* %byref.copy.garbage, i8* %20, i8* %21, i8* %A, i8* %lda_p, double* %cache.A, i8* %21) +; CHECK-NEXT: call void @dlacpy_64_(i8* %byref.copy.garbage, i8* %20, i8* %21, i8* %A, i8* %lda_p, double* %cache.A, i8* %20) ; CHECK-NEXT: call void @dgemm_64_(i8* %malloccall, i8* %malloccall1, i8* %m_p, i8* %n_p, i8* %k_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %C, i8* %ldc_p, i64 1, i64 1) ; CHECK-NEXT: %[[ret:.+]] = load double*, double** %0 ; CHECK-NEXT: ret double* %[[ret]] diff --git a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_split_transpose_lacpy.ll b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_split_transpose_lacpy.ll index 2a7bbed40c24..f7164d4926e8 100644 --- a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_split_transpose_lacpy.ll +++ b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_split_transpose_lacpy.ll @@ -117,7 +117,7 @@ entry: ; CHECK-NEXT: %cache.A = bitcast i8* %malloccall10 to double* ; CHECK-NEXT: store double* %cache.A, double** %0 ; CHECK-NEXT: store i8 0, i8* %byref.copy.garbage -; CHECK-NEXT: call void @dlacpy_64_(i8* %byref.copy.garbage, i8* %20, i8* %21, i8* %A, i8* %lda_p, double* %cache.A, i8* %21) +; CHECK-NEXT: call void @dlacpy_64_(i8* %byref.copy.garbage, i8* %20, i8* %21, i8* %A, i8* %lda_p, double* %cache.A, i8* %20) ; CHECK-NEXT: call void @dgemm_64_(i8* %malloccall, i8* %malloccall1, i8* %m_p, i8* %n_p, i8* %k_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %C, i8* %ldc_p, i64 1, i64 1) ; CHECK-NEXT: %[[ret:.+]] = load double*, double** %0 ; CHECK-NEXT: ret double* %[[ret]] diff --git a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_transpose_lacpy.ll b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_transpose_lacpy.ll index 964b53b1b925..d0f9ed397b0d 100644 --- a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_transpose_lacpy.ll +++ b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_transpose_lacpy.ll @@ -105,7 +105,7 @@ entry: ; CHECK-NEXT: %malloccall = tail call noalias nonnull i8* @malloc(i64 %mallocsize) ; CHECK-NEXT: %cache.A = bitcast i8* %malloccall to double* ; CHECK-NEXT: store i8 0, i8* %[[byrefgarbage]] -; CHECK-NEXT: call void @dlacpy_64_(i8* %[[byrefgarbage]], i8* %[[i3]], i8* %[[i4]], i8* %A, i8* %lda_p, double* %cache.A, i8* %[[i4]]) +; CHECK-NEXT: call void @dlacpy_64_(i8* %[[byrefgarbage]], i8* %[[i3]], i8* %[[i4]], i8* %A, i8* %lda_p, double* %cache.A, i8* %[[i3]]) ; CHECK-NEXT: %loaded.trans1 = load i8, i8* %transb ; CHECK-DAG: %[[i10:.+]] = icmp eq i8 %loaded.trans1, 78 ; CHECK-DAG: %[[i11:.+]] = icmp eq i8 %loaded.trans1, 110 @@ -121,7 +121,7 @@ entry: ; CHECK-NEXT: %[[malloccall2:.+]] = tail call noalias nonnull i8* @malloc(i64 %[[mallocsize1]]) ; CHECK-NEXT: %cache.B = bitcast i8* %[[malloccall2]] to double* ; CHECK-NEXT: store i8 0, i8* %byref.copy.garbage4 -; CHECK-NEXT: call void @dlacpy_64_(i8* %[[byrefgarbage2]], i8* %[[i13]], i8* %[[i14]], i8* %B, i8* %ldb_p, double* %cache.B, i8* %[[i14]]) +; CHECK-NEXT: call void @dlacpy_64_(i8* %[[byrefgarbage2]], i8* %[[i13]], i8* %[[i14]], i8* %B, i8* %ldb_p, double* %cache.B, i8* %[[i13]]) ; CHECK-NEXT: call void @dgemm_64_(i8* %transa, i8* %transb, i8* %m_p, i8* %n_p, i8* %k_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %C, i8* %ldc_p, i64 1, i64 1) ; CHECK-NEXT: %"ptr'ipc" = bitcast i8* %"A'" to double* ; CHECK-NEXT: %ptr = bitcast i8* %A to double* diff --git a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_change_ld.ll b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_change_ld.ll index a4c00d61bbc4..09189c1fe8d1 100644 --- a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_change_ld.ll +++ b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_change_ld.ll @@ -103,7 +103,7 @@ entry: ; CHECK-NEXT: %malloccall = tail call noalias nonnull i8* @malloc(i64 %mallocsize) ; CHECK-NEXT: %cache.B = bitcast i8* %malloccall to double* ; CHECK-NEXT: store i8 0, i8* %byref.copy.garbage -; CHECK-NEXT: call void @dlacpy_64_(i8* %byref.copy.garbage, i8* %[[z3]], i8* %[[z4]], i8* %B, i8* %ldb_p, double* %cache.B, i8* %[[z4]]) +; CHECK-NEXT: call void @dlacpy_64_(i8* %byref.copy.garbage, i8* %[[z3]], i8* %[[z4]], i8* %B, i8* %ldb_p, double* %cache.B, i8* %[[z3]]) ; CHECK-NEXT: call void @dgemm_64_(i8* %transa, i8* %transb, i8* %m_p, i8* %n_p, i8* %k_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %C, i8* %ldc_p, i64 1, i64 1) ; CHECK-NEXT: %ptr = bitcast i8* %B to double* ; CHECK-NEXT: store double 0.000000e+00, double* %ptr, align 8 diff --git a/enzyme/test/Integration/ReverseMode/blas.cpp b/enzyme/test/Integration/ReverseMode/blas.cpp index 8c929ed054d0..1c01178b874f 100644 --- a/enzyme/test/Integration/ReverseMode/blas.cpp +++ b/enzyme/test/Integration/ReverseMode/blas.cpp @@ -226,8 +226,8 @@ static void gemvTests() { assert(foundCalls.size() > 2); auto A_cache = (double*)foundCalls[0].pout_arg1; - cblas_dlacpy(layout, '\0', M, N, A, lda, A_cache, N); - inputs[4] = BlasInfo(A_cache, layout, M, N, N); + cblas_dlacpy(layout, '\0', M, N, A, lda, A_cache, M); + inputs[4] = BlasInfo(A_cache, layout, M, N, M); auto B_cache = (double*)foundCalls[1].pout_arg1; cblas_dcopy(trans ? M : N, B, incB, B_cache, 1); inputs[5] = BlasInfo(B_cache, trans ? M : N, 1); @@ -244,7 +244,7 @@ static void gemvTests() { lda); // dB = alpha * trans(A) * dC + dB - cblas_dgemv(layout, (char)transpose(transA), M, N, alpha, A_cache, N, dC, incC, 1.0, dB, incB); + cblas_dgemv(layout, (char)transpose(transA), M, N, alpha, A_cache, M, dC, incC, 1.0, dB, incB); // dY = beta * dY cblas_dscal(trans ? N : M, beta, dC, incC); diff --git a/enzyme/test/Integration/blasinfra.h b/enzyme/test/Integration/blasinfra.h index 07d540a9b326..cc9d2cc01dd9 100644 --- a/enzyme/test/Integration/blasinfra.h +++ b/enzyme/test/Integration/blasinfra.h @@ -1,5 +1,5 @@ -#include +#include #include #include #include diff --git a/enzyme/tools/enzyme-tblgen/caching.cpp b/enzyme/tools/enzyme-tblgen/caching.cpp index cdb7b3a92910..7a98d944b8ab 100644 --- a/enzyme/tools/enzyme-tblgen/caching.cpp +++ b/enzyme/tools/enzyme-tblgen/caching.cpp @@ -293,7 +293,7 @@ os << " if (byRef) valueTypes[" << len_pos << "] = ValueType::Primal;\n"; os << " if (EnzymeLapackCopy) {\n" << " Value *uplo = llvm::ConstantInt::get(charTy, 0);\n" // garbage data, just should not match U or L << " uplo = to_blas_callconv(BuilderZ, uplo, byRef, cublas, nullptr, allocationBuilder, \"copy.garbage\");\n" -<< " SmallVector args = {uplo, M, N, arg_" << matName << ", arg_" << ldName << ", malins, N};\n" +<< " SmallVector args = {uplo, M, N, arg_" << matName << ", arg_" << ldName << ", malins, M};\n" << " if (!byRef) {\n" << " args.insert(args.begin(), arg_layout); valueTypes.insert(valueTypes.begin(), ValueType::Primal); }\n" << " callMemcpyStridedLapack(BuilderZ, *gutils->oldFunc->getParent(), blas, args, gutils->getInvertedBundles(&call, valueTypes, BuilderZ, /*lookup*/false));\n" From c199e36cf3cb7db48a425fdfb2b351013b0db4af Mon Sep 17 00:00:00 2001 From: William Moses Date: Sat, 10 Feb 2024 15:30:36 -0500 Subject: [PATCH 028/106] Turn mask assertion error into nice compile error (#1690) --- enzyme/Enzyme/DiffeGradientUtils.cpp | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/enzyme/Enzyme/DiffeGradientUtils.cpp b/enzyme/Enzyme/DiffeGradientUtils.cpp index f3ca78dacea1..43a14051c4fd 100644 --- a/enzyme/Enzyme/DiffeGradientUtils.cpp +++ b/enzyme/Enzyme/DiffeGradientUtils.cpp @@ -928,11 +928,20 @@ void DiffeGradientUtils::addToInvertedPtrDiffe(Instruction *orig, applyChainRule(PointerType::get(addingType, 1), BuilderM, rule, ptr); } - assert(!mask); if (mask) { - llvm::errs() << "unhandled masked atomic fadd on llvm version " << *ptr - << " " << *dif << " mask: " << *mask << "\n"; - llvm_unreachable("unhandled masked atomic fadd"); + std::string s; + llvm::raw_string_ostream ss(s); + ss << "Unimplemented masked atomic fadd for ptr:" << *ptr + << " dif:" << *dif << " mask: " << *mask << " orig: " << *orig << "\n"; + if (CustomErrorHandler) { + CustomErrorHandler(ss.str().c_str(), wrap(orig), + ErrorType::NoDerivative, this, nullptr, + wrap(&BuilderM)); + return; + } else { + EmitFailure("NoDerivative", orig->getDebugLoc(), orig, ss.str()); + return; + } } /* From 90d057f8bc5412bdbd7693266de8db289073685e Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 11 Feb 2024 16:37:34 -0500 Subject: [PATCH 029/106] [ActivityAnalysis] simplify and cleanup (#1694) --- enzyme/Enzyme/ActivityAnalysis.cpp | 376 +++++++++++------------------ 1 file changed, 139 insertions(+), 237 deletions(-) diff --git a/enzyme/Enzyme/ActivityAnalysis.cpp b/enzyme/Enzyme/ActivityAnalysis.cpp index 7921d3ff65c0..e6b5a565c112 100644 --- a/enzyme/Enzyme/ActivityAnalysis.cpp +++ b/enzyme/Enzyme/ActivityAnalysis.cpp @@ -104,19 +104,6 @@ cl::opt EnzymeEnableRecursiveHypotheses( #include // clang-format off -static const char *KnownInactiveFunctionsStartingWith[] = { - "f90io", - "$ss5print", - "_ZTv0_n24_NSoD", //"1Ev, 0Ev - "_ZNSt16allocator_traitsISaIdEE10deallocate", - "_ZNSaIcED1Ev", - "_ZNSaIcEC1Ev", -}; - -static const char *KnownInactiveFunctionsContains[] = { - "__enzyme_float", "__enzyme_double", "__enzyme_integer", - "__enzyme_pointer"}; - static const StringSet<> InactiveGlobals = { "small_typeof", "ompi_request_null", @@ -168,19 +155,26 @@ const llvm::StringMap MPIInactiveCommAllocators = { {"MPI_Comm_idup", 1}, {"MPI_Comm_join", 1}, }; +// clang-format on -// Instructions which themselves are inactive -// the returned value, however, may still be active -static const StringSet<> KnownInactiveFunctionInsts = { - "__dynamic_cast", - "_ZSt18_Rb_tree_decrementPKSt18_Rb_tree_node_base", - "_ZSt18_Rb_tree_incrementPKSt18_Rb_tree_node_base", - "_ZSt18_Rb_tree_decrementPSt18_Rb_tree_node_base", - "_ZSt18_Rb_tree_incrementPSt18_Rb_tree_node_base", - "jl_ptr_to_array", - "jl_ptr_to_array_1d"}; +/// Return whether the call is always inactive by definition. +bool isInactiveCall(CallBase &CI) { + + // clang-format off +const char *KnownInactiveFunctionsStartingWith[] = { + "f90io", + "$ss5print", + "_ZTv0_n24_NSoD", //"1Ev, 0Ev + "_ZNSt16allocator_traitsISaIdEE10deallocate", + "_ZNSaIcED1Ev", + "_ZNSaIcEC1Ev", +}; + +const char *KnownInactiveFunctionsContains[] = { + "__enzyme_float", "__enzyme_double", "__enzyme_integer", + "__enzyme_pointer"}; -static const StringSet<> KnownInactiveFunctions = { +const StringSet<> KnownInactiveFunctions = { "mpfr_greater_p", "__nv_isnand", "__nv_isnanf", @@ -293,7 +287,7 @@ static const StringSet<> KnownInactiveFunctions = { "floorl" }; -static const std::set KnownInactiveIntrinsics = { +const std::set KnownInactiveIntrinsics = { #if LLVM_VERSION_MAJOR >= 12 Intrinsic::experimental_noalias_scope_decl, #endif @@ -342,7 +336,7 @@ static const std::set KnownInactiveIntrinsics = { Intrinsic::is_constant, Intrinsic::memset}; -static const char *DemangledKnownInactiveFunctionsStartingWith[] = { +const char *DemangledKnownInactiveFunctionsStartingWith[] = { // TODO this returns allocated memory and thus can be an active value // "std::allocator", "std::chrono::_V2::steady_clock::now", @@ -417,13 +411,110 @@ static const char *DemangledKnownInactiveFunctionsStartingWith[] = { "std::__detail::_Prime_rehash_policy", "std::__detail::_Hash_code_base", }; -// clang-format on + // clang-format on + + if (CI.hasFnAttr("enzyme_inactive")) + return true; + + if (auto iasm = dyn_cast(CI.getCalledOperand())) { + if (StringRef(iasm->getAsmString()).contains("exit") || + StringRef(iasm->getAsmString()).contains("cpuid")) + return false; + } + + auto F = getFunctionFromCall(&CI); + + if (F == nullptr) + return false; + + if (F->hasFnAttribute("enzyme_inactive")) { + return true; + } + + auto Name = getFuncNameFromCall(&CI); + + std::string demangledName = llvm::demangle(Name.str()); + auto dName = StringRef(demangledName); + for (auto FuncName : DemangledKnownInactiveFunctionsStartingWith) { + if (startsWith(dName, FuncName)) { + return true; + } + } + + for (auto FuncName : KnownInactiveFunctionsStartingWith) { + if (startsWith(Name, FuncName)) { + return true; + } + } + + for (auto FuncName : KnownInactiveFunctionsContains) { + if (Name.contains(FuncName)) { + return true; + } + } + if (KnownInactiveFunctions.count(Name)) { + return true; + } + + if (MPIInactiveCommAllocators.find(Name) != MPIInactiveCommAllocators.end()) { + return true; + } + Intrinsic::ID ID; + if (isMemFreeLibMFunction(Name, &ID)) + if (KnownInactiveIntrinsics.count(ID)) { + return true; + } + + if (KnownInactiveIntrinsics.count(F->getIntrinsicID())) { + return true; + } + + return false; +} + +bool isInactiveCallInst(CallBase &CB, llvm::TargetLibraryInfo &TLI) { + // clang-format off +// Instructions which themselves are inactive +// the returned value, however, may still be active +static const StringSet<> KnownInactiveFunctionInsts = { + "__dynamic_cast", + "_ZSt18_Rb_tree_decrementPKSt18_Rb_tree_node_base", + "_ZSt18_Rb_tree_incrementPKSt18_Rb_tree_node_base", + "_ZSt18_Rb_tree_decrementPSt18_Rb_tree_node_base", + "_ZSt18_Rb_tree_incrementPSt18_Rb_tree_node_base", + "jl_ptr_to_array", + "jl_ptr_to_array_1d"}; + // clang-format on + if (isInactiveCall(CB)) + return true; + if (CB.hasFnAttr("enzyme_inactive_inst")) { + return true; + } + auto called = getFunctionFromCall(&CB); + + if (called) { + if (called->hasFnAttribute("enzyme_inactive_inst")) { + return true; + } + } + + auto funcName = getFuncNameFromCall(&CB); + if (KnownInactiveFunctionInsts.count(funcName)) + return true; + + if (isAllocationFunction(funcName, TLI) || + isDeallocationFunction(funcName, TLI)) { + return true; + } + + return false; +} /// Is the use of value val as an argument of call CI known to be inactive /// This tool can only be used when in DOWN mode bool ActivityAnalyzer::isFunctionArgumentConstant(CallInst *CI, Value *val) { assert(directions & DOWN); - if (CI->hasFnAttr("enzyme_inactive")) + if (isInactiveCall(*CI)) return true; auto F = getFunctionFromCall(CI); @@ -453,10 +544,6 @@ bool ActivityAnalyzer::isFunctionArgumentConstant(CallInst *CI, Value *val) { if (F == nullptr) return false; - if (F->hasFnAttribute("enzyme_inactive")) { - return true; - } - auto Name = getFuncNameFromCall(CI); // Only the 1-th arg impacts activity @@ -468,43 +555,6 @@ bool ActivityAnalyzer::isFunctionArgumentConstant(CallInst *CI, Value *val) { if (isAllocationFunction(Name, TLI) || isDeallocationFunction(Name, TLI)) return true; - std::string demangledName = llvm::demangle(Name.str()); - auto dName = StringRef(demangledName); - for (auto FuncName : DemangledKnownInactiveFunctionsStartingWith) { - if (startsWith(dName, FuncName)) { - return true; - } - } - if (demangledName == Name.str()) { - // Either demangeling failed - // or they are equal but matching failed - // if (!startsWith(Name, "llvm.")) - // llvm::errs() << "matching failed: " << Name.str() << " " - // << demangledName << "\n"; - } - - for (auto FuncName : KnownInactiveFunctionsStartingWith) { - if (startsWith(Name, FuncName)) { - return true; - } - } - - for (auto FuncName : KnownInactiveFunctionsContains) { - if (Name.contains(FuncName)) { - return true; - } - } - if (KnownInactiveFunctions.count(Name)) { - return true; - } - - if (MPIInactiveCommAllocators.find(Name) != MPIInactiveCommAllocators.end()) { - return true; - } - if (KnownInactiveIntrinsics.count(F->getIntrinsicID())) { - return true; - } - /// Only the first argument (magnitude) of copysign is active if (F->getIntrinsicID() == Intrinsic::copysign && CI->getArgOperand(0) != val) { @@ -551,6 +601,8 @@ bool ActivityAnalyzer::isFunctionArgumentConstant(CallInst *CI, Value *val) { static inline void propagateArgumentInformation( TargetLibraryInfo &TLI, CallInst &CI, llvm::function_ref propagateFromOperand) { + if (isInactiveCall(CI)) + return; // These functions are known to only have the first argument impact // the activity of the call instruction @@ -598,12 +650,6 @@ static inline void propagateArgumentInformation( return; } - // Certain intrinsics are inactive by definition - // and have nothing to propagate. - if (KnownInactiveIntrinsics.count(F->getIntrinsicID())) { - return; - } - if (F->getIntrinsicID() == Intrinsic::memcpy || F->getIntrinsicID() == Intrinsic::memmove) { propagateFromOperand(CI.getOperand(0)); @@ -720,13 +766,6 @@ bool ActivityAnalyzer::isConstantInstruction(TypeResults const &TR, ActiveInstructions.insert(I); return false; } - if (CI->hasFnAttr("enzyme_inactive") || - CI->hasFnAttr("enzyme_inactive_inst")) { - if (EnzymePrintActivity) - llvm::errs() << "forced inactive " << *I << "\n"; - InsertConstantInstruction(TR, I); - return true; - } auto called = getFunctionFromCall(CI); if (called) { @@ -737,27 +776,17 @@ bool ActivityAnalyzer::isConstantInstruction(TypeResults const &TR, ActiveInstructions.insert(I); return false; } - if (called->hasFnAttribute("enzyme_inactive") || - called->hasFnAttribute("enzyme_inactive_inst")) { - if (EnzymePrintActivity) - llvm::errs() << "forced inactive " << *I << "\n"; - InsertConstantInstruction(TR, I); - return true; - } } - if (KnownInactiveFunctionInsts.count(getFuncNameFromCall(CI))) { + if (isInactiveCallInst(*CI, TLI)) { + if (EnzymePrintActivity) + llvm::errs() << "known inactive instruction from call " << *I << "\n"; InsertConstantInstruction(TR, I); return true; } } if (auto II = dyn_cast(I)) { - if (KnownInactiveIntrinsics.count(II->getIntrinsicID())) { - if (EnzymePrintActivity) - llvm::errs() << "known inactive intrinsic " << *I << "\n"; - InsertConstantInstruction(TR, I); - return true; - } else if (isIntelSubscriptIntrinsic(*II)) { + if (isIntelSubscriptIntrinsic(*II)) { // The intrinsic "llvm.intel.subscript" does not propogate deriviative // information directly. But its returned pointer may be active. InsertConstantInstruction(TR, I); @@ -1066,13 +1095,6 @@ bool ActivityAnalyzer::isConstantValue(TypeResults const &TR, Value *Val) { return true; } - if (auto II = dyn_cast(Val)) { - if (KnownInactiveIntrinsics.count(II->getIntrinsicID())) { - InsertConstantValue(TR, Val); - return true; - } - } - // All arguments must be marked constant/nonconstant ahead of time if (isa(Val) && !cast(Val)->hasByValAttr()) { llvm::errs() << *(cast(Val)->getParent()) << "\n"; @@ -1348,6 +1370,12 @@ bool ActivityAnalyzer::isConstantValue(TypeResults const &TR, Value *Val) { return true; } } + if (isInactiveCall(*CI)) { + if (EnzymePrintActivity) + llvm::errs() << "known inactive val from call" << *Val << "\n"; + InsertConstantValue(TR, Val); + return true; + } } if (auto BO = dyn_cast(Val)) { // x & 0b100000 is definitionally inactive @@ -1536,8 +1564,7 @@ bool ActivityAnalyzer::isConstantValue(TypeResults const &TR, Value *Val) { } } } else if (auto op = dyn_cast(TmpOrig)) { - if (op->hasFnAttr("enzyme_inactive") || - op->hasFnAttr("enzyme_inactive_val") || + if (isInactiveCall(*op) || op->hasFnAttr("enzyme_inactive_val") || op->getAttributes().hasAttribute(llvm::AttributeList::ReturnIndex, "enzyme_inactive")) { InsertConstantValue(TR, Val); @@ -1549,8 +1576,7 @@ bool ActivityAnalyzer::isConstantValue(TypeResults const &TR, Value *Val) { StringRef funcName = getFuncNameFromCall(op); if (called && - (called->hasFnAttribute("enzyme_inactive") || - called->hasFnAttribute("enzyme_inactive_val") || + (called->hasFnAttribute("enzyme_inactive_val") || called->getAttributes().hasAttribute( llvm::AttributeList::ReturnIndex, "enzyme_inactive"))) { InsertConstantValue(TR, Val); @@ -1564,45 +1590,6 @@ bool ActivityAnalyzer::isConstantValue(TypeResults const &TR, Value *Val) { return true; } - auto dName = demangle(funcName.str()); - for (auto FuncName : DemangledKnownInactiveFunctionsStartingWith) { - if (startsWith(dName, FuncName)) { - InsertConstantValue(TR, Val); - insertConstantsFrom(TR, *UpHypothesis); - return true; - } - } - - for (auto FuncName : KnownInactiveFunctionsStartingWith) { - if (startsWith(funcName, FuncName)) { - InsertConstantValue(TR, Val); - insertConstantsFrom(TR, *UpHypothesis); - return true; - } - } - - for (auto FuncName : KnownInactiveFunctionsContains) { - if (funcName.contains(FuncName)) { - InsertConstantValue(TR, Val); - insertConstantsFrom(TR, *UpHypothesis); - return true; - } - } - - if (KnownInactiveFunctions.count(funcName) || - MPIInactiveCommAllocators.find(funcName) != - MPIInactiveCommAllocators.end()) { - InsertConstantValue(TR, Val); - insertConstantsFrom(TR, *UpHypothesis); - return true; - } - - if (called && called->getIntrinsicID() == Intrinsic::trap) { - InsertConstantValue(TR, Val); - insertConstantsFrom(TR, *UpHypothesis); - return true; - } - // If requesting empty unknown functions to be considered inactive, // abide by those rules if (called && EnzymeEmptyFnInactive && called->empty() && @@ -1857,57 +1844,14 @@ bool ActivityAnalyzer::isConstantValue(TypeResults const &TR, Value *Val) { // If this is a malloc or free, this doesn't impact the activity if (auto CI = dyn_cast(I)) { - if (CI->hasFnAttr("enzyme_inactive") || - CI->hasFnAttr("enzyme_inactive_inst")) + if (isInactiveCallInst(*CI, TLI)) return false; - if (auto iasm = dyn_cast(CI->getCalledOperand())) { - if (StringRef(iasm->getAsmString()).contains("exit") || - StringRef(iasm->getAsmString()).contains("cpuid")) - return false; - } - - auto F = getFunctionFromCall(CI); StringRef funcName = getFuncNameFromCall(CI); - - if (F && (F->hasFnAttribute("enzyme_inactive") || - F->hasFnAttribute("enzyme_inactive_inst"))) { - return false; - } - if (isAllocationFunction(funcName, TLI) || - isDeallocationFunction(funcName, TLI)) { - return false; - } - if (KnownInactiveFunctions.count(funcName) || - MPIInactiveCommAllocators.find(funcName) != - MPIInactiveCommAllocators.end()) { - return false; - } - if (KnownInactiveFunctionInsts.count(funcName)) { - return false; - } if (isMemFreeLibMFunction(funcName)) { return false; } - auto dName = demangle(funcName.str()); - for (auto FuncName : DemangledKnownInactiveFunctionsStartingWith) { - if (startsWith(dName, FuncName)) { - return false; - } - } - - for (auto FuncName : KnownInactiveFunctionsStartingWith) { - if (startsWith(funcName, FuncName)) { - return false; - } - } - for (auto FuncName : KnownInactiveFunctionsContains) { - if (funcName.contains(FuncName)) { - return false; - } - } - if (funcName == "__cxa_guard_acquire" || funcName == "__cxa_guard_release" || funcName == "__cxa_guard_abort" || funcName == "posix_memalign" || @@ -1917,12 +1861,6 @@ bool ActivityAnalyzer::isConstantValue(TypeResults const &TR, Value *Val) { funcName == "cudaMallocFromPoolAsync") { return false; } - - if (F) { - if (KnownInactiveIntrinsics.count(F->getIntrinsicID())) { - return false; - } - } } Value *memval = Val; @@ -2538,8 +2476,10 @@ bool ActivityAnalyzer::isInstructionInactiveFromOrigin(TypeResults const &TR, } if (auto op = dyn_cast(inst)) { - if (op->hasFnAttr("enzyme_inactive") || - op->hasFnAttr("enzyme_inactive_val")) { + if (isInactiveCall(*op)) + return true; + + if (op->hasFnAttr("enzyme_inactive_val")) { return true; } // Calls to print/assert/cxa guard are definitionally inactive @@ -2548,8 +2488,7 @@ bool ActivityAnalyzer::isInstructionInactiveFromOrigin(TypeResults const &TR, StringRef funcName = getFuncNameFromCall(op); auto called = getFunctionFromCall(op); - if (called && (called->hasFnAttribute("enzyme_inactive") || - called->hasFnAttribute("enzyme_inactive_val"))) { + if (called && (called->hasFnAttribute("enzyme_inactive_val"))) { return true; } if (funcName == "free" || funcName == "_ZdlPv" || funcName == "_ZdlPvm" || @@ -2557,37 +2496,6 @@ bool ActivityAnalyzer::isInstructionInactiveFromOrigin(TypeResults const &TR, return true; } - auto dName = demangle(funcName.str()); - for (auto FuncName : DemangledKnownInactiveFunctionsStartingWith) { - if (startsWith(dName, FuncName)) { - return true; - } - } - - for (auto FuncName : KnownInactiveFunctionsStartingWith) { - if (startsWith(funcName, FuncName)) { - return true; - } - } - - for (auto FuncName : KnownInactiveFunctionsContains) { - if (funcName.contains(FuncName)) { - return true; - } - } - - if (KnownInactiveFunctions.count(funcName) || - MPIInactiveCommAllocators.find(funcName) != - MPIInactiveCommAllocators.end()) { - if (EnzymePrintActivity) - llvm::errs() << "constant(" << (int)directions - << ") up-knowninactivecall " << *inst << "\n"; - return true; - } - - if (called && called->getIntrinsicID() == Intrinsic::trap) - return true; - // If requesting empty unknown functions to be considered inactive, abide // by those rules if (called && EnzymeEmptyFnInactive && called->empty() && @@ -2609,12 +2517,6 @@ bool ActivityAnalyzer::isInstructionInactiveFromOrigin(TypeResults const &TR, } // Intrinsics known always to be inactive if (auto II = dyn_cast(inst)) { - if (KnownInactiveIntrinsics.count(II->getIntrinsicID())) { - if (EnzymePrintActivity) - llvm::errs() << "constant(" << (int)directions << ") up-intrinsic " - << *inst << "\n"; - return true; - } if (isIntelSubscriptIntrinsic(*II)) { // The only argument that can make an llvm.intel.subscript intrinsic // active is the pointer operand From 1d7347865e88c1e6a306bb665493d35948401a9a Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 11 Feb 2024 16:57:58 -0500 Subject: [PATCH 030/106] Mark nvfmodf and improve error message for wrong arg num (#1695) --- enzyme/Enzyme/EnzymeLogic.cpp | 49 +++++++++++++++++++++++-- enzyme/Enzyme/InstructionDerivatives.td | 2 +- 2 files changed, 46 insertions(+), 5 deletions(-) diff --git a/enzyme/Enzyme/EnzymeLogic.cpp b/enzyme/Enzyme/EnzymeLogic.cpp index d60ac934634b..17ab0c2f4eca 100644 --- a/enzyme/Enzyme/EnzymeLogic.cpp +++ b/enzyme/Enzyme/EnzymeLogic.cpp @@ -1961,11 +1961,7 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal( assert(!todiff->getReturnType()->isEmptyTy() && !todiff->getReturnType()->isVoidTy()); - assert(_overwritten_args.size() == todiff->arg_size()); - FnTypeInfo oldTypeInfo = preventTypeAnalysisLoops(oldTypeInfo_, todiff); - - assert(constant_args.size() == todiff->getFunctionType()->getNumParams()); AugmentedCacheKey tup = {todiff, retType, constant_args, _overwritten_args, returnUsed, shadowReturnUsed, @@ -1973,6 +1969,51 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal( AtomicAdd, omp, width}; + if (_overwritten_args.size() != todiff->arg_size()) { + std::string s; + llvm::raw_string_ostream ss(s); + ss << " overwritten_args.size() [" << _overwritten_args.size() + << "] != todiff->arg_size()\n"; + ss << "todiff: " << *todiff << "\n"; + llvm::Value *toshow = todiff; + if (context.req) { + toshow = context.req; + ss << " at context: " << *context.req; + } else { + ss << *todiff << "\n"; + } + if (CustomErrorHandler) { + CustomErrorHandler(ss.str().c_str(), wrap(toshow), + ErrorType::NoDerivative, nullptr, wrap(todiff), + wrap(context.ip)); + auto newFunc = todiff; + std::map returnMapping; + return insert_or_assign( + AugmentedCachedFunctions, tup, + AugmentedReturn(newFunc, nullptr, {}, returnMapping, {}, {}, + constant_args)) + ->second; + } + if (context.req) { + EmitFailure("NoDerivative", context.req->getDebugLoc(), context.req, + ss.str()); + auto newFunc = todiff; + std::map returnMapping; + return insert_or_assign( + AugmentedCachedFunctions, tup, + AugmentedReturn(newFunc, nullptr, {}, returnMapping, {}, {}, + constant_args)) + ->second; + } + llvm::errs() << "mod: " << *todiff->getParent() << "\n"; + llvm::errs() << *todiff << "\n"; + llvm_unreachable( + "attempting to differentiate function with wrong overwritten count"); + } + + assert(_overwritten_args.size() == todiff->arg_size()); + assert(constant_args.size() == todiff->getFunctionType()->getNumParams()); + auto found = AugmentedCachedFunctions.find(tup); if (found != AugmentedCachedFunctions.end()) { return found->second; diff --git a/enzyme/Enzyme/InstructionDerivatives.td b/enzyme/Enzyme/InstructionDerivatives.td index 00be8fe77e55..58731be4a532 100644 --- a/enzyme/Enzyme/InstructionDerivatives.td +++ b/enzyme/Enzyme/InstructionDerivatives.td @@ -434,7 +434,7 @@ def : CallPattern<(Op $x), >; def : CallPattern<(Op $x, $y), - ["fmod", "fmodf", "fmodl"], + ["fmod", "fmodf", "fmodl", "__nv_fmod", "__nv_fmodf", "__nv_fmodl"], [ (DiffeRet), (CheckedMul (DiffeRet), (FNeg (Intrinsic<"copysign"> (Intrinsic<"floor"> (Intrinsic<"fabs"> (FDiv $x, $y):$div)), $div))) From 6b8960d17a76e5f7bbbfdb44a57f153fee1c0681 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 11 Feb 2024 18:06:26 -0500 Subject: [PATCH 031/106] Simplify and improve differential use analysis (#1693) * Simplify and improve differential use analysis * fix lambda --------- Co-authored-by: Tim Gymnich --- enzyme/Enzyme/AdjointGenerator.h | 16 +- enzyme/Enzyme/DifferentialUseAnalysis.cpp | 78 ++-------- enzyme/Enzyme/DifferentialUseAnalysis.h | 147 +++++++++++++++---- enzyme/Enzyme/EnzymeLogic.cpp | 4 +- enzyme/Enzyme/GradientUtils.cpp | 52 ++++--- enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp | 56 +++++++ enzyme/Enzyme/TypeAnalysis/TypeAnalysis.h | 10 ++ enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp | 21 ++- 8 files changed, 252 insertions(+), 132 deletions(-) diff --git a/enzyme/Enzyme/AdjointGenerator.h b/enzyme/Enzyme/AdjointGenerator.h index de9bbf8e38a6..ec9cd5c51c9a 100644 --- a/enzyme/Enzyme/AdjointGenerator.h +++ b/enzyme/Enzyme/AdjointGenerator.h @@ -756,7 +756,7 @@ class AdjointGenerator : public llvm::InstVisitor { auto alignment = LI.getAlign(); auto &DL = gutils->newFunc->getParent()->getDataLayout(); - bool constantval = parseTBAA(LI, DL, nullptr).Inner0().isIntegral(); + bool constantval = parseTBAA(LI, DL, nullptr)[{-1}].isIntegral(); visitLoadLike(LI, alignment, constantval); eraseIfUnused(LI); } @@ -992,7 +992,7 @@ class AdjointGenerator : public llvm::InstVisitor { NewI->setMetadata(LLVMContext::MD_noalias, noscope); bool constantval = gutils->isConstantValue(orig_val) || - parseTBAA(I, DL, nullptr).Inner0().isIntegral(); + parseTBAA(I, DL, nullptr)[{-1}].isIntegral(); IRBuilder<> BuilderZ(NewI); BuilderZ.setFastMathFlags(getFast()); @@ -3272,7 +3272,7 @@ class AdjointGenerator : public llvm::InstVisitor { // copying into nullptr is invalid (not sure why it exists here), but we // shouldn't do it in reverse pass or shadow if (isa(orig_dst) || - TR.query(orig_dst).Inner0() == BaseType::Anything) { + TR.query(orig_dst)[{-1}] == BaseType::Anything) { eraseIfUnused(MTI); return; } @@ -3689,7 +3689,7 @@ class AdjointGenerator : public llvm::InstVisitor { auto align0 = cast(I.getOperand(1))->getZExtValue(); auto align = MaybeAlign(align0); auto &DL = gutils->newFunc->getParent()->getDataLayout(); - bool constantval = parseTBAA(I, DL, nullptr).Inner0().isIntegral(); + bool constantval = parseTBAA(I, DL, nullptr)[{-1}].isIntegral(); visitLoadLike(I, align, constantval, /*mask*/ gutils->getNewFromOriginal(I.getOperand(2)), /*orig_maskInit*/ I.getOperand(3)); @@ -4068,7 +4068,7 @@ class AdjointGenerator : public llvm::InstVisitor { assert(whatType(argType, Mode) == DIFFE_TYPE::DUP_ARG || whatType(argType, Mode) == DIFFE_TYPE::CONSTANT); } else { - assert(TR.query(call.getArgOperand(i)).Inner0().isFloat()); + assert(TR.query(call.getArgOperand(i))[{-1}].isFloat()); OutTypes.push_back(call.getArgOperand(i)); OutFPTypes.push_back(argType); assert(whatType(argType, Mode) == DIFFE_TYPE::OUT_DIFF || @@ -5437,8 +5437,7 @@ class AdjointGenerator : public llvm::InstVisitor { assert(dcall); if (!gutils->isConstantValue(&call)) { - if (!call.getType()->isFPOrFPVectorTy() && - TR.query(&call).Inner0().isPossiblePointer()) { + if (!call.getType()->isFPOrFPVectorTy() && TR.anyPointer(&call)) { } else if (Mode != DerivativeMode::ReverseModePrimal) { ((DiffeGradientUtils *)gutils)->differentials[dcall] = ((DiffeGradientUtils *)gutils)->differentials[newCall]; @@ -5898,8 +5897,7 @@ class AdjointGenerator : public llvm::InstVisitor { gutils->originalToNewFn[&call] = dcall; gutils->newToOriginalFn.erase(newCall); gutils->newToOriginalFn[dcall] = &call; - if (!call.getType()->isFPOrFPVectorTy() && - TR.query(&call).Inner0().isPossiblePointer()) { + if (!call.getType()->isFPOrFPVectorTy() && TR.anyPointer(&call)) { } else { ((DiffeGradientUtils *)gutils)->differentials[dcall] = ((DiffeGradientUtils *)gutils)->differentials[newCall]; diff --git a/enzyme/Enzyme/DifferentialUseAnalysis.cpp b/enzyme/Enzyme/DifferentialUseAnalysis.cpp index 3be0cdfc2895..d20ec4eefbc4 100644 --- a/enzyme/Enzyme/DifferentialUseAnalysis.cpp +++ b/enzyme/Enzyme/DifferentialUseAnalysis.cpp @@ -293,47 +293,6 @@ bool DifferentialUseAnalysis::is_use_directly_needed_in_reverse( return false; } - if (shadow) { - if (auto IVI = dyn_cast(user)) - if (isa(IVI) || isa(IVI)) { - if (IVI->getOperand(1) == val) { - SmallVector todo; - todo.push_back(IVI); - SmallVector, 1> - done; - while (todo.size()) { - auto cur = todo.pop_back_val(); - for (auto u : cur->users()) { - if (auto IVI2 = dyn_cast(u)) { - todo.push_back(IVI2); - continue; - } - if (auto IVI2 = dyn_cast(u)) { - todo.push_back(IVI2); - continue; - } - done.emplace_back(cast(u), cur); - } - } - for (auto &pair : done) { - - bool direct = is_use_directly_needed_in_reverse( - gutils, pair.second, mode, pair.first, oldUnreachable, - QueryType::Shadow, recursiveUse); - if (direct) { - - if (EnzymePrintDiffUse) - llvm::errs() - << " Need (partial) direct " << to_string(qtype) << " of " - << *val << " in reverse from insertelem " << *user - << " via " << *pair.second << " in " << *pair.first << "\n"; - return true; - } - } - } - } - } - if (!shadow) if (auto IEI = dyn_cast(user)) { // Only need the index in the reverse, so if the value is not @@ -800,33 +759,26 @@ int DifferentialUseAnalysis::cmpLoopNest(Loop *prev, Loop *next) { return -1; } -void DifferentialUseAnalysis::minCut( - const DataLayout &DL, LoopInfo &OrigLI, - const SetVector &Recomputes, - const SetVector &Intermediates, SetVector &Required, - SetVector &MinReq, - const ValueMap - &rematerializableAllocations, - llvm::TargetLibraryInfo &TLI) { +void DifferentialUseAnalysis::minCut(const DataLayout &DL, LoopInfo &OrigLI, + const SetVector &Recomputes, + const SetVector &Intermediates, + SetVector &Required, + SetVector &MinReq, + const GradientUtils *gutils, + llvm::TargetLibraryInfo &TLI) { Graph G; for (auto V : Intermediates) { G[Node(V, false)].insert(Node(V, true)); - for (auto U : V->users()) { - if (auto I = dyn_cast(U)) { - for (auto pair : rematerializableAllocations) { - if (Intermediates.count(pair.first) && pair.second.stores.count(I)) { - if (V != pair.first) - G[Node(V, true)].insert(Node(pair.first, false)); + forEachDifferentialUser( + [&](Value *U) { + if (Intermediates.count(U)) { + if (V != U) + G[Node(V, true)].insert(Node(U, false)); } - } - } - if (Intermediates.count(U)) { - if (V != U) - G[Node(V, true)].insert(Node(U, false)); - } - } + }, + gutils, V); } - for (auto pair : rematerializableAllocations) { + for (auto pair : gutils->rematerializableAllocations) { if (Intermediates.count(pair.first)) { for (LoadInst *L : pair.second.loads) { if (Intermediates.count(L)) { diff --git a/enzyme/Enzyme/DifferentialUseAnalysis.h b/enzyme/Enzyme/DifferentialUseAnalysis.h index d9d689eeee64..565e18724869 100644 --- a/enzyme/Enzyme/DifferentialUseAnalysis.h +++ b/enzyme/Enzyme/DifferentialUseAnalysis.h @@ -166,6 +166,60 @@ inline bool is_value_needed_in_reverse( } } + if (!TR.anyFloat(const_cast(inst))) + if (auto IVI = dyn_cast(user)) { + bool inserted = false; + if (auto II = dyn_cast(IVI)) + inserted = II->getInsertedValueOperand() == inst || + II->getAggregateOperand() == inst; + if (auto II = dyn_cast(IVI)) + inserted = II->getAggregateOperand() == inst; + if (auto II = dyn_cast(IVI)) + inserted = II->getOperand(1) == inst || II->getOperand(0) == inst; + if (auto II = dyn_cast(IVI)) + inserted = II->getOperand(0) == inst; + if (inserted) { + SmallVector todo; + todo.push_back(IVI); + while (todo.size()) { + auto cur = todo.pop_back_val(); + for (auto u : cur->users()) { + if (auto IVI2 = dyn_cast(u)) { + todo.push_back(IVI2); + continue; + } + if (auto IVI2 = dyn_cast(u)) { + todo.push_back(IVI2); + continue; + } + if (auto IVI2 = dyn_cast(u)) { + todo.push_back(IVI2); + continue; + } + if (auto IVI2 = dyn_cast(u)) { + todo.push_back(IVI2); + continue; + } + + bool partial = false; + if (!gutils->isConstantValue(const_cast(cur))) { + partial = is_value_needed_in_reverse( + gutils, user, mode, seen, oldUnreachable); + } + if (partial) { + + if (EnzymePrintDiffUse) + llvm::errs() + << " Need (partial) direct " << to_string(VT) << " of " + << *inst << " in reverse from insertelem " << *user + << " via " << *cur << " in " << *u << "\n"; + return seen[idx] = true; + } + } + } + } + } + if (VT != QueryType::Primal) continue; } @@ -332,36 +386,14 @@ inline bool is_value_needed_in_reverse( primalUsedInShadowPointer = false; } } - if (auto IVI = dyn_cast(user)) { - bool valueIsIndex = false; - for (unsigned i = 2; i < IVI->getNumOperands(); ++i) { - if (IVI->getOperand(i) == inst) { - if (inst == IVI->getInsertedValueOperand() && - TR.query( - const_cast(IVI->getInsertedValueOperand()))[{-1}] - .isFloat()) { - continue; - } - valueIsIndex = true; - } - } - primalUsedInShadowPointer = valueIsIndex; - } - if (auto EVI = dyn_cast(user)) { - bool valueIsIndex = false; - for (unsigned i = 1; i < EVI->getNumOperands(); ++i) { - if (EVI->getOperand(i) == inst) { - valueIsIndex = true; - } - } - primalUsedInShadowPointer = valueIsIndex; - } + // No need for insert/extractvalue since indices are unsigned + // not llvm runtime values + if (isa(user) || isa(user)) + primalUsedInShadowPointer = false; if (primalUsedInShadowPointer) if (!user->getType()->isVoidTy() && - TR.query(const_cast(user)) - .Inner0() - .isPossiblePointer()) { + TR.anyPointer(const_cast(user))) { if (is_value_needed_in_reverse( gutils, user, mode, seen, oldUnreachable)) { if (EnzymePrintDiffUse) @@ -433,11 +465,66 @@ void minCut(const llvm::DataLayout &DL, llvm::LoopInfo &OrigLI, const llvm::SetVector &Recomputes, const llvm::SetVector &Intermediates, llvm::SetVector &Required, - llvm::SetVector &MinReq, - const llvm::ValueMap - &rematerializableAllocations, + llvm::SetVector &MinReq, const GradientUtils *gutils, llvm::TargetLibraryInfo &TLI); +__attribute__((always_inline)) static inline void +forEachDirectInsertUser(llvm::function_ref f, + const GradientUtils *gutils, llvm::Instruction *IVI, + llvm::Value *val, bool useCheck) { + using namespace llvm; + if (!gutils->isConstantValue(IVI)) + return; + bool inserted = false; + if (auto II = dyn_cast(IVI)) + inserted = II->getInsertedValueOperand() == val || + II->getAggregateOperand() == val; + if (auto II = dyn_cast(IVI)) + inserted = II->getAggregateOperand() == val; + if (auto II = dyn_cast(IVI)) + inserted = II->getOperand(1) == val || II->getOperand(0) == val; + if (auto II = dyn_cast(IVI)) + inserted = II->getOperand(0) == val; + if (inserted) { + SmallVector todo; + todo.push_back(IVI); + while (todo.size()) { + auto cur = todo.pop_back_val(); + for (auto u : cur->users()) { + if (isa(u) || isa(u) || + isa(u) || isa(u)) { + auto I2 = cast(u); + bool subCheck = useCheck; + if (!subCheck) { + subCheck = is_value_needed_in_reverse( + gutils, I2, gutils->mode, gutils->notForAnalysis); + } + if (subCheck) + f(I2); + todo.push_back(I2); + continue; + } + } + } + } +} + +__attribute__((always_inline)) static inline void +forEachDifferentialUser(llvm::function_ref f, + const GradientUtils *gutils, llvm::Value *V, + bool useCheck = false) { + for (auto V2 : V->users()) { + if (auto Inst = llvm::dyn_cast(V2)) { + for (const auto &pair : gutils->rematerializableAllocations) { + if (pair.second.stores.count(Inst)) { + f(llvm::cast(pair.first)); + } + } + f(Inst); + forEachDirectInsertUser(f, gutils, Inst, V, useCheck); + } + } +} }; // namespace DifferentialUseAnalysis #endif diff --git a/enzyme/Enzyme/EnzymeLogic.cpp b/enzyme/Enzyme/EnzymeLogic.cpp index 17ab0c2f4eca..850a92543db3 100644 --- a/enzyme/Enzyme/EnzymeLogic.cpp +++ b/enzyme/Enzyme/EnzymeLogic.cpp @@ -1282,7 +1282,7 @@ bool shouldAugmentCall(CallInst *op, const GradientUtils *gutils) { } if (!op->getType()->isFPOrFPVectorTy() && !gutils->isConstantValue(op) && - gutils->TR.query(op).Inner0().isPossiblePointer()) { + gutils->TR.anyPointer(op)) { modifyPrimal = true; #ifdef PRINT_AUGCALL @@ -1315,7 +1315,7 @@ bool shouldAugmentCall(CallInst *op, const GradientUtils *gutils) { if (!argType->isFPOrFPVectorTy() && !gutils->isConstantValue(op->getArgOperand(i)) && - gutils->TR.query(op->getArgOperand(i)).Inner0().isPossiblePointer()) { + gutils->TR.anyPointer(op->getArgOperand(i))) { if (!isReadOnly(op, i)) { modifyPrimal = true; #ifdef PRINT_AUGCALL diff --git a/enzyme/Enzyme/GradientUtils.cpp b/enzyme/Enzyme/GradientUtils.cpp index ec15cd2e76ff..734d95989442 100644 --- a/enzyme/Enzyme/GradientUtils.cpp +++ b/enzyme/Enzyme/GradientUtils.cpp @@ -3280,7 +3280,7 @@ BasicBlock *GradientUtils::prepRematerializedLoopEntry(LoopContext &lc) { auto &DL = newFunc->getParent()->getDataLayout(); bool constantval = isConstantValue(orig_val) || - parseTBAA(I, DL, nullptr).Inner0().isIntegral(); + parseTBAA(I, DL, nullptr)[{-1}].isIntegral(); // TODO allow recognition of other types that could contain // pointers [e.g. {void*, void*} or <2 x i64> ] @@ -4321,8 +4321,7 @@ DIFFE_TYPE GradientUtils::getReturnDiffeType(llvm::Value *orig, subretType = DIFFE_TYPE::DUP_ARG; shadowReturnUsed = true; } else { - if (!orig->getType()->isFPOrFPVectorTy() && - TR.query(orig).Inner0().isPossiblePointer()) { + if (!orig->getType()->isFPOrFPVectorTy() && TR.anyPointer(orig)) { if (DifferentialUseAnalysis::is_value_needed_in_reverse< QueryType::Shadow>(this, orig, cmode, notForAnalysis)) { subretType = DIFFE_TYPE::DUP_ARG; @@ -4359,8 +4358,7 @@ DIFFE_TYPE GradientUtils::getDiffeType(Value *v, bool foreignFunction) const { auto argType = v->getType(); - if (!argType->isFPOrFPVectorTy() && - (TR.query(v).Inner0().isPossiblePointer() || foreignFunction)) { + if (!argType->isFPOrFPVectorTy() && (TR.anyPointer(v) || foreignFunction)) { if (argType->isPointerTy()) { auto at = getBaseObject(v); if (auto arg = dyn_cast(at)) { @@ -5105,9 +5103,21 @@ Value *GradientUtils::invertPointerM(Value *const oval, IRBuilder<> &BuilderM, return applyChainRule(oval->getType(), BuilderM, rule); } - if (isConstantValue(oval) && !isa(oval) && - !isa(oval) && !isa(oval) && - !isa(oval)) { + bool shouldNullShadow = isConstantValue(oval); + if (shouldNullShadow) { + if (isa(oval) || isa(oval) || + isa(oval) || isa(oval)) { + shouldNullShadow = false; + auto orig = cast(oval); + if (knownRecomputeHeuristic.count(orig)) { + if (!knownRecomputeHeuristic[orig]) { + shouldNullShadow = true; + } + } + } + } + + if (shouldNullShadow) { // NOTE, this is legal and the correct resolution, however, our activity // analysis honeypot no longer exists @@ -8011,9 +8021,9 @@ void GradientUtils::computeMinCache() { todo.pop_front(); if (Intermediates.count(V)) continue; - if (!DifferentialUseAnalysis::is_value_needed_in_reverse< - QueryType::Primal>(this, V, minCutMode, FullSeen, - notForAnalysis)) { + bool multiLevel = DifferentialUseAnalysis::is_value_needed_in_reverse< + QueryType::Primal>(this, V, minCutMode, FullSeen, notForAnalysis); + if (!multiLevel) { continue; } if (!Recomputes.count(V)) { @@ -8033,27 +8043,21 @@ void GradientUtils::computeMinCache() { } } Intermediates.insert(V); - if (DifferentialUseAnalysis::is_value_needed_in_reverse< - QueryType::Primal, /*OneLevel*/ true>( - this, V, minCutMode, OneLevelSeen, notForAnalysis)) { + bool singleLevel = DifferentialUseAnalysis::is_value_needed_in_reverse< + QueryType::Primal, /*OneLevel*/ true>(this, V, minCutMode, + OneLevelSeen, notForAnalysis); + if (singleLevel) { Required.insert(V); } else { - for (auto V2 : V->users()) { - if (auto Inst = dyn_cast(V2)) - for (auto pair : rematerializableAllocations) { - if (pair.second.stores.count(Inst)) { - todo.push_back(pair.first); - } - } - todo.push_back(V2); - } + DifferentialUseAnalysis::forEachDifferentialUser( + [&](Value *V2) { todo.push_back(V2); }, this, V); } } SetVector MinReq; DifferentialUseAnalysis::minCut(oldFunc->getParent()->getDataLayout(), OrigLI, Recomputes, Intermediates, Required, - MinReq, rematerializableAllocations, TLI); + MinReq, this, TLI); SmallPtrSet NeedGraph; for (Value *V : MinReq) NeedGraph.insert(V); diff --git a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp index b9e670326fd4..d5bf445ee429 100644 --- a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp +++ b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp @@ -5773,6 +5773,62 @@ TypeTree TypeResults::query(Value *val) const { return analyzer.getAnalysis(val); } +bool TypeResults::anyFloat(Value *val) const { + assert(val); + assert(val->getType()); + auto q = query(val); + auto dt = q[{-1}]; + if (dt != BaseType::Anything && dt != BaseType::Unknown) + return dt.isFloat(); + + size_t ObjSize = 1; + auto &dl = analyzer.fntypeinfo.Function->getParent()->getDataLayout(); + if (val->getType()->isSized()) + ObjSize = (dl.getTypeSizeInBits(val->getType()) + 7) / 8; + + for (size_t i = 0; i < ObjSize;) { + dt = q[{(int)i}]; + if (dt == BaseType::Integer) { + i++; + continue; + } + if (dt == BaseType::Pointer) { + i += dl.getPointerSize(0); + continue; + } + return true; + } + return false; +} + +bool TypeResults::anyPointer(Value *val) const { + assert(val); + assert(val->getType()); + auto q = query(val); + auto dt = q[{-1}]; + if (dt != BaseType::Anything && dt != BaseType::Unknown) + return dt == BaseType::Pointer; + + size_t ObjSize = 1; + auto &dl = analyzer.fntypeinfo.Function->getParent()->getDataLayout(); + if (val->getType()->isSized()) + ObjSize = (dl.getTypeSizeInBits(val->getType()) + 7) / 8; + + for (size_t i = 0; i < ObjSize;) { + dt = q[{(int)i}]; + if (dt == BaseType::Integer) { + i++; + continue; + } + if (auto FT = dt.isFloat()) { + i += (dl.getTypeSizeInBits(FT) + 7) / 8; + continue; + } + return true; + } + return false; +} + void TypeResults::dump(llvm::raw_ostream &ss) const { analyzer.dump(ss); } ConcreteType TypeResults::intType(size_t num, Value *val, bool errIfNotFound, diff --git a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.h b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.h index fc7bc1d24dd7..6e036454143a 100644 --- a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.h +++ b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.h @@ -174,6 +174,16 @@ class TypeResults { /// The TypeTree of a particular Value TypeTree query(llvm::Value *val) const; + /// Whether any part of the top level register can contain a float + /// e.g. { i64, float } can contain a float, but { i64, i8* } would not. + // Of course, here we compute with type analysis rather than llvm type + bool anyFloat(llvm::Value *val) const; + + /// Whether any part of the top level register can contain a pointer + /// e.g. { i64, i8* } can contain a pointer, but { i64, float } would not. + // Of course, here we compute with type analysis rather than llvm type + bool anyPointer(llvm::Value *val) const; + /// The TypeInfo calling convention FnTypeInfo getAnalyzedTypeInfo() const; diff --git a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp index bf6a00d07cd2..7eb015121a3e 100644 --- a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp +++ b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp @@ -510,6 +510,13 @@ bool handle(const Twine &curIndent, const Twine &argPattern, raw_ostream &os, "{(llvm::Constant*)ConstantFP::get(ST->getElementType(0), \"" << rvalue->getValue() << "\"), (llvm::Constant*)ConstantFP::get(ST->getElementType(1), \"" + << ivalue->getValue() << "\")});\n" + << "} else if (auto AT = dyn_cast(ty)) {\n" + << curIndent << INDENT << INDENT + << "ret = ConstantArray::get(AT, " + "{(llvm::Constant*)ConstantFP::get(AT->getElementType(), \"" + << rvalue->getValue() + << "\"), (llvm::Constant*)ConstantFP::get(AT->getElementType(), \"" << ivalue->getValue() << "\")});\n"; os << curIndent << INDENT << "} else assert(0 && \"unhandled cfp\");\n"; os << curIndent << INDENT << "ret;\n"; @@ -2115,14 +2122,20 @@ void emitDiffUse(const RecordKeeper &recordKeeper, raw_ostream &os, StringRef name = cast(lst->getValues()[0])->getValue(); if (lst->size() >= 2) { auto min = cast(lst->getValues()[1])->getValue(); - int min_int; - min.getAsInteger(10, min_int); + int min_int = 0; + if (min.size() != 0 && min.getAsInteger(10, min_int)) { + PrintFatalError(pattern->getLoc(), + "Could not parse min llvm version as int"); + } if (min.size() != 0 && LLVM_VERSION_MAJOR < min_int) continue; if (lst->size() >= 3) { auto max = cast(lst->getValues()[2])->getValue(); - int max_int; - max.getAsInteger(10, max_int); + int max_int = 0; + if (max.size() != 0 && max.getAsInteger(10, max_int)) { + PrintFatalError(pattern->getLoc(), + "Could not parse max llvm version as int"); + } if (max.size() != 0 && LLVM_VERSION_MAJOR > max_int) continue; } From 20221d266f6058fb24650c2c499f408d114b2766 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 11 Feb 2024 21:28:17 -0500 Subject: [PATCH 032/106] Enable c++ warning for runtime activity (#1699) --- enzyme/Enzyme/AdjointGenerator.h | 22 ++++++++++------ enzyme/Enzyme/EnzymeLogic.cpp | 22 ++++++++++------ enzyme/Enzyme/GradientUtils.cpp | 44 ++++++++++++++++++++------------ 3 files changed, 56 insertions(+), 32 deletions(-) diff --git a/enzyme/Enzyme/AdjointGenerator.h b/enzyme/Enzyme/AdjointGenerator.h index ec9cd5c51c9a..ced5c498b6d0 100644 --- a/enzyme/Enzyme/AdjointGenerator.h +++ b/enzyme/Enzyme/AdjointGenerator.h @@ -1051,7 +1051,7 @@ class AdjointGenerator : public llvm::InstVisitor { } Value *diff = nullptr; - if (!EnzymeRuntimeActivityCheck && CustomErrorHandler && constantval) { + if (!EnzymeRuntimeActivityCheck && constantval) { if (dt.isPossiblePointer() && vd[{-1, -1}] != BaseType::Integer) { if (!isa(orig_val) && !isa(orig_val)) { @@ -1059,9 +1059,12 @@ class AdjointGenerator : public llvm::InstVisitor { raw_string_ostream ss(str); ss << "Mismatched activity for: " << I << " const val: " << *orig_val; - diff = unwrap(CustomErrorHandler( - str.c_str(), wrap(&I), ErrorType::MixedActivityError, gutils, - wrap(orig_val), wrap(&BuilderZ))); + if (CustomErrorHandler) + diff = unwrap(CustomErrorHandler( + str.c_str(), wrap(&I), ErrorType::MixedActivityError, gutils, + wrap(orig_val), wrap(&BuilderZ))); + else + EmitWarning("MixedActivityError", I, ss.str()); } } } @@ -1280,7 +1283,7 @@ class AdjointGenerator : public llvm::InstVisitor { Value *valueop = nullptr; if (constantval) { - if (!EnzymeRuntimeActivityCheck && CustomErrorHandler) { + if (!EnzymeRuntimeActivityCheck) { if (dt.isPossiblePointer() && vd[{-1, -1}] != BaseType::Integer) { if (!isa(orig_val) && !isa(orig_val)) { @@ -1288,9 +1291,12 @@ class AdjointGenerator : public llvm::InstVisitor { raw_string_ostream ss(str); ss << "Mismatched activity for: " << I << " const val: " << *orig_val; - valueop = unwrap(CustomErrorHandler( - str.c_str(), wrap(&I), ErrorType::MixedActivityError, - gutils, wrap(orig_val), wrap(&BuilderZ))); + if (CustomErrorHandler) + valueop = unwrap(CustomErrorHandler( + str.c_str(), wrap(&I), ErrorType::MixedActivityError, + gutils, wrap(orig_val), wrap(&BuilderZ))); + else + EmitWarning("MixedActivityError", I, ss.str()); } } } diff --git a/enzyme/Enzyme/EnzymeLogic.cpp b/enzyme/Enzyme/EnzymeLogic.cpp index 850a92543db3..fbd078ce23c7 100644 --- a/enzyme/Enzyme/EnzymeLogic.cpp +++ b/enzyme/Enzyme/EnzymeLogic.cpp @@ -2548,7 +2548,7 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal( IRBuilder<> BuilderZ(newri); Value *invertri = nullptr; if (gutils->isConstantValue(orig_oldval)) { - if (!EnzymeRuntimeActivityCheck && CustomErrorHandler && + if (!EnzymeRuntimeActivityCheck && gutils->TR.query(orig_oldval)[{-1}].isPossiblePointer()) { if (!isa(orig_oldval) && !isa(orig_oldval)) { @@ -2556,9 +2556,12 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal( raw_string_ostream ss(str); ss << "Mismatched activity for: " << *ri << " const val: " << *orig_oldval; - invertri = unwrap(CustomErrorHandler( - str.c_str(), wrap(ri), ErrorType::MixedActivityError, - gutils, wrap(orig_oldval), wrap(&BuilderZ))); + if (CustomErrorHandler) + invertri = unwrap(CustomErrorHandler( + str.c_str(), wrap(ri), ErrorType::MixedActivityError, + gutils, wrap(orig_oldval), wrap(&BuilderZ))); + else + EmitWarning("MixedActivityError", *ri, ss.str()); } } } @@ -3083,16 +3086,19 @@ void createTerminator(DiffeGradientUtils *gutils, BasicBlock *oBB, if (!ret->getType()->isFPOrFPVectorTy() && TR.getReturnAnalysis().Inner0().isPossiblePointer()) { if (gutils->isConstantValue(ret)) { - if (!EnzymeRuntimeActivityCheck && CustomErrorHandler && + if (!EnzymeRuntimeActivityCheck && TR.query(ret)[{-1}].isPossiblePointer()) { if (!isa(ret) && !isa(ret)) { std::string str; raw_string_ostream ss(str); ss << "Mismatched activity for: " << *inst << " const val: " << *ret; - invertedPtr = unwrap(CustomErrorHandler( - str.c_str(), wrap(inst), ErrorType::MixedActivityError, gutils, - wrap(ret), wrap(&nBuilder))); + if (CustomErrorHandler) + invertedPtr = unwrap(CustomErrorHandler( + str.c_str(), wrap(inst), ErrorType::MixedActivityError, + gutils, wrap(ret), wrap(&nBuilder))); + else + EmitWarning("MixedActivityError", *inst, ss.str()); } } } diff --git a/enzyme/Enzyme/GradientUtils.cpp b/enzyme/Enzyme/GradientUtils.cpp index 734d95989442..0fd01c83a05e 100644 --- a/enzyme/Enzyme/GradientUtils.cpp +++ b/enzyme/Enzyme/GradientUtils.cpp @@ -5695,15 +5695,18 @@ Value *GradientUtils::invertPointerM(Value *const oval, IRBuilder<> &BuilderM, Value *itval = nullptr; { auto tval = arg->getTrueValue(); - if (!EnzymeRuntimeActivityCheck && CustomErrorHandler && + if (!EnzymeRuntimeActivityCheck && TR.query(arg)[{-1}].isPossiblePointer() && !isa(tval) && !isa(tval) && isConstantValue(tval)) { std::string str; raw_string_ostream ss(str); ss << "Mismatched activity for: " << *arg << " const val: " << *tval; - itval = unwrap(CustomErrorHandler(str.c_str(), wrap(arg), - ErrorType::MixedActivityError, this, - wrap(tval), wrap(&bb))); + if (CustomErrorHandler) + itval = unwrap(CustomErrorHandler(str.c_str(), wrap(arg), + ErrorType::MixedActivityError, this, + wrap(tval), wrap(&bb))); + else + EmitWarning("MixedActivityError", *arg, ss.str()); } if (!itval) { itval = invertPointerM(tval, bb, nullShadow); @@ -5712,15 +5715,18 @@ Value *GradientUtils::invertPointerM(Value *const oval, IRBuilder<> &BuilderM, Value *ifval = nullptr; { auto fval = arg->getFalseValue(); - if (!EnzymeRuntimeActivityCheck && CustomErrorHandler && + if (!EnzymeRuntimeActivityCheck && TR.query(arg)[{-1}].isPossiblePointer() && !isa(fval) && !isa(fval) && isConstantValue(fval)) { std::string str; raw_string_ostream ss(str); ss << "Mismatched activity for: " << *arg << " const val: " << *fval; - ifval = unwrap(CustomErrorHandler(str.c_str(), wrap(arg), - ErrorType::MixedActivityError, this, - wrap(fval), wrap(&bb))); + if (CustomErrorHandler) + ifval = unwrap(CustomErrorHandler(str.c_str(), wrap(arg), + ErrorType::MixedActivityError, this, + wrap(fval), wrap(&bb))); + else + EmitWarning("MixedActivityError", *arg, ss.str()); } if (!ifval) { ifval = invertPointerM(fval, bb, nullShadow); @@ -6064,7 +6070,7 @@ Value *GradientUtils::invertPointerM(Value *const oval, IRBuilder<> &BuilderM, Value *preval = phi->getIncomingValue(j); Value *val = nullptr; - if (!EnzymeRuntimeActivityCheck && CustomErrorHandler && + if (!EnzymeRuntimeActivityCheck && TR.query(phi)[{-1}].isPossiblePointer() && !isa(preval) && !isa(preval) && isConstantValue(preval)) { @@ -6072,9 +6078,12 @@ Value *GradientUtils::invertPointerM(Value *const oval, IRBuilder<> &BuilderM, raw_string_ostream ss(str); ss << "Mismatched activity for: " << *phi << " const val: " << *preval; - val = unwrap(CustomErrorHandler(str.c_str(), wrap(phi), - ErrorType::MixedActivityError, this, - wrap(preval), wrap(&pre))); + if (CustomErrorHandler) + val = unwrap(CustomErrorHandler(str.c_str(), wrap(phi), + ErrorType::MixedActivityError, + this, wrap(preval), wrap(&pre))); + else + EmitWarning("MixedActivityError", *phi, ss.str()); } if (!val) { val = invertPointerM(preval, pre, nullShadow); @@ -6124,7 +6133,7 @@ Value *GradientUtils::invertPointerM(Value *const oval, IRBuilder<> &BuilderM, Value *preval = phi->getIncomingValue(i); Value *val = nullptr; - if (!EnzymeRuntimeActivityCheck && CustomErrorHandler && + if (!EnzymeRuntimeActivityCheck && TR.query(phi)[{-1}].isPossiblePointer() && !isa(preval) && !isa(preval) && isConstantValue(preval)) { @@ -6132,9 +6141,12 @@ Value *GradientUtils::invertPointerM(Value *const oval, IRBuilder<> &BuilderM, raw_string_ostream ss(str); ss << "Mismatched activity for: " << *phi << " const val: " << *preval; - val = unwrap(CustomErrorHandler(str.c_str(), wrap(phi), - ErrorType::MixedActivityError, this, - wrap(preval), wrap(&pre))); + if (CustomErrorHandler) + val = unwrap(CustomErrorHandler(str.c_str(), wrap(phi), + ErrorType::MixedActivityError, + this, wrap(preval), wrap(&pre))); + else + EmitWarning("MixedActivityError", *phi, ss.str()); } if (!val) { val = invertPointerM(preval, pre, nullShadow); From a309cc083f64fb6c17689a81feffa804bc7fcb3d Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 11 Feb 2024 21:29:31 -0500 Subject: [PATCH 033/106] Mark memcpy of 1 byte as inactive (#1698) * Mark memcpy of 1 byte as inactive * Update ActivityAnalysis.cpp --- enzyme/Enzyme/ActivityAnalysis.cpp | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/enzyme/Enzyme/ActivityAnalysis.cpp b/enzyme/Enzyme/ActivityAnalysis.cpp index e6b5a565c112..459da4928560 100644 --- a/enzyme/Enzyme/ActivityAnalysis.cpp +++ b/enzyme/Enzyme/ActivityAnalysis.cpp @@ -468,6 +468,13 @@ const char *DemangledKnownInactiveFunctionsStartingWith[] = { if (KnownInactiveIntrinsics.count(F->getIntrinsicID())) { return true; } + // Copies of size 1 are inactive [cannot move differentiable data in one byte] + if (auto MTI = dyn_cast(&CI)) { + if (auto sz = dyn_cast(MTI->getOperand(2))) { + if (sz->getValue() == 1) + return true; + } + } return false; } From c356fbe124a5f916ee70964d71b67036f248822d Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 11 Feb 2024 21:51:01 -0500 Subject: [PATCH 034/106] Improve nofree demangle support (#1696) * Improve nofree demangle support * fix * fix * fix --- enzyme/Enzyme/ActivityAnalysis.cpp | 15 ++- enzyme/Enzyme/AdjointGenerator.h | 4 +- enzyme/Enzyme/CallDerivatives.cpp | 3 +- enzyme/Enzyme/EnzymeLogic.cpp | 105 +++++++++++++++--- enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp | 6 +- .../test/Integration/ReverseMode/blas_gemm2.c | 3 - .../Integration/ReverseMode/customlog1p.c | 4 - .../test/Integration/ReverseMode/inactivefn.c | 4 - enzyme/test/Integration/ReverseMode/invsqrt.c | 9 +- enzyme/test/Integration/ReverseMode/mycos.c | 18 ++- enzyme/test/Integration/ReverseMode/omp.c | 4 - enzyme/test/Integration/ReverseMode/omp2.c | 4 - enzyme/test/Integration/ReverseMode/omp3.c | 2 - enzyme/test/Integration/ReverseMode/omp6.c | 4 - enzyme/test/Integration/ReverseMode/omp_two.c | 4 - .../test/Integration/ReverseMode/ompbound.c | 4 - enzyme/test/Integration/test_utils.h | 4 +- 17 files changed, 121 insertions(+), 76 deletions(-) diff --git a/enzyme/Enzyme/ActivityAnalysis.cpp b/enzyme/Enzyme/ActivityAnalysis.cpp index 459da4928560..ef6c7a14f811 100644 --- a/enzyme/Enzyme/ActivityAnalysis.cpp +++ b/enzyme/Enzyme/ActivityAnalysis.cpp @@ -115,9 +115,11 @@ static const StringSet<> InactiveGlobals = { "_ZSt3cin", "_ZSt4cout", "_ZNSt3__14coutE", + "_ZNSt3__15wcoutE", "_ZNSt3__113basic_ostreamIcNS_11char_traitsIcEEE6sentryC1ERS3_", "_ZSt5wcout", "_ZSt4cerr", + "_ZNSt3__14cerrE", "_ZTVNSt7__cxx1115basic_stringbufIcSt11char_traitsIcESaIcEEE", "_ZTVSt15basic_streambufIcSt11char_traitsIcEE", "_ZTVSt9basic_iosIcSt11char_traitsIcEE", @@ -284,13 +286,17 @@ const StringSet<> KnownInactiveFunctions = { "cuDevicePrimaryCtxRetain", "floor", "floorf", - "floorl" + "floorl", + "\01_fopen", + "fopen", + "fclose", }; const std::set KnownInactiveIntrinsics = { #if LLVM_VERSION_MAJOR >= 12 Intrinsic::experimental_noalias_scope_decl, #endif + Intrinsic::objectsize, Intrinsic::floor, Intrinsic::ceil, Intrinsic::trunc, @@ -407,9 +413,16 @@ const char *DemangledKnownInactiveFunctionsStartingWith[] = { "std::__1::shuffle_order_engine", "std::__1::basic_streambuf", "std::__1::basic_stringbuf", + "std::__1::basic_istream", + "std::__1::basic_filebuf", + "std::__1::basic_iostream", + "std::__1::basic_ios", + "virtual thunk to std::__1::basic_istream", + "virtual thunk to std::__1::basic_ostream", "std::__detail::_Prime_rehash_policy", "std::__detail::_Hash_code_base", + }; // clang-format on diff --git a/enzyme/Enzyme/AdjointGenerator.h b/enzyme/Enzyme/AdjointGenerator.h index ced5c498b6d0..60902eafd2a1 100644 --- a/enzyme/Enzyme/AdjointGenerator.h +++ b/enzyme/Enzyme/AdjointGenerator.h @@ -4400,7 +4400,6 @@ class AdjointGenerator : public llvm::InstVisitor { } } } - size_t freeCount = 0; for (auto LI : geps) { CallInst *freeCall = nullptr; for (auto LU : LI->users()) { @@ -4428,7 +4427,6 @@ class AdjointGenerator : public llvm::InstVisitor { } if (freeCall) { freeCall->eraseFromParent(); - freeCount++; } } } @@ -4599,7 +4597,7 @@ class AdjointGenerator : public llvm::InstVisitor { EmitFailure("CannotDeduceType", call.getDebugLoc(), &call, "failed to deduce type of copy ", call); } -#if LLVM_VERSION_MAJOR < 18 +#if LLVM_VERSION_MAJOR < 17 knownF: #endif unsigned start = 0; diff --git a/enzyme/Enzyme/CallDerivatives.cpp b/enzyme/Enzyme/CallDerivatives.cpp index c8b2cdaa4e4b..155967e24a33 100644 --- a/enzyme/Enzyme/CallDerivatives.cpp +++ b/enzyme/Enzyme/CallDerivatives.cpp @@ -3572,7 +3572,8 @@ bool AdjointGenerator::handleKnownCallDerivatives( ConstantInt::getFalse(call.getContext())); return true; } - if (funcName == "memset" || funcName == "memset_pattern16") { + if (funcName == "memset" || funcName == "memset_pattern16" || + funcName == "__memset_chk") { visitMemSetCommon(call); return true; } diff --git a/enzyme/Enzyme/EnzymeLogic.cpp b/enzyme/Enzyme/EnzymeLogic.cpp index fbd078ce23c7..bc4856566909 100644 --- a/enzyme/Enzyme/EnzymeLogic.cpp +++ b/enzyme/Enzyme/EnzymeLogic.cpp @@ -2336,7 +2336,18 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal( if (todiff->empty()) { std::string s; llvm::raw_string_ostream ss(s); - ss << "No augmented forward pass found for " + todiff->getName() << "\n"; + ss << "No augmented forward pass found for " + todiff->getName(); + { + std::string demangledName = llvm::demangle(todiff->getName().str()); + // replace all '> >' with '>>' + size_t start = 0; + while ((start = demangledName.find("> >", start)) != std::string::npos) { + demangledName.replace(start, 3, ">>"); + } + if (demangledName != todiff->getName()) + ss << "(" << demangledName << ")"; + } + ss << "\n"; llvm::Value *toshow = todiff; if (context.req) { toshow = context.req; @@ -5889,37 +5900,47 @@ llvm::Value *EnzymeLogic::CreateNoFree(RequestContext context, cast(CreateNoFree(context, castinst->getOperand(0)))}; return castinst->getWithOperands(reps); } - if (CustomErrorHandler) { - std::string s; - llvm::raw_string_ostream ss(s); - ss << "No create nofree of unknown value\n"; - ss << *todiff << "\n"; - if (context.req) { - ss << " at context: " << *context.req; + if (EnzymeAssumeUnknownNoFree) { + return todiff; + } + + std::string s; + llvm::raw_string_ostream ss(s); + ss << "No create nofree of unknown value\n"; + ss << *todiff << "\n"; + if (context.req) { + ss << " at context: " << *context.req; + } + if (auto I = dyn_cast(todiff)) { + auto fname = I->getParent()->getParent()->getName(); + if (startsWith(fname, "nofree_")) + fname = fname.substr(7); + std::string demangledName = llvm::demangle(fname.str()); + // replace all '> >' with '>>' + size_t start = 0; + while ((start = demangledName.find("> >", start)) != std::string::npos) { + demangledName.replace(start, 3, ">>"); } + ss << " within func " << fname << " (" << demangledName << ")\n"; + } + if (CustomErrorHandler) { CustomErrorHandler(ss.str().c_str(), wrap(context.req), ErrorType::NoDerivative, nullptr, wrap(todiff), wrap(context.ip)); return todiff; } - if (EnzymeAssumeUnknownNoFree) { - return todiff; - } - if (context.req) { - EmitFailure("IllegalNoFree", context.req->getDebugLoc(), context.req, - "Cannot create nofree of instruction-created value: ", *todiff); + EmitFailure("IllegalNoFree", context.req->getDebugLoc(), context.req, s); return todiff; } if (auto arg = dyn_cast(todiff)) { auto loc = arg->getDebugLoc(); - EmitFailure("IllegalNoFree", loc, arg, - "Cannot create nofree of instruction-created value: ", *todiff); + EmitFailure("IllegalNoFree", loc, arg, s); return todiff; } - llvm::errs() << " unhandled, create no free of: " << *todiff << "\n"; + llvm::errs() << s; llvm_unreachable("unhandled, create no free"); } @@ -6001,6 +6022,7 @@ llvm::Function *EnzymeLogic::CreateNoFree(RequestContext context, Function *F) { "std::__1::basic_string, std::__1::allocator>::data() const", "std::__1::basic_ostream>::sentry::sentry(std::__1::basic_ostream>&)", "std::__1::basic_ostream>::sentry::~sentry()", + "std::__1::basic_ostream>::flush()", "std::__1::ios_base::__set_badbit_and_consider_rethrow()", "char* std::__1::addressof(char&)", "char const* std::__1::addressof(char const&)", @@ -6011,6 +6033,40 @@ llvm::Function *EnzymeLogic::CreateNoFree(RequestContext context, Function *F) { "std::__1::ios_base::ios_base()", "std::__1::ios_base::getloc() const", "std::__1::ios_base::clear(unsigned int)", + "std::__1::basic_iostream>::~basic_iostream()", + "std::__1::basic_ios>::~basic_ios()", + "std::__1::basic_streambuf>::basic_streambuf()", + "std::__1::basic_streambuf>::~basic_streambuf()", + "std::__1::basic_streambuf>::imbue(std::__1::locale const&)", + "std::__1::basic_streambuf>::setbuf(char*, long)", + "std::__1::basic_streambuf>::sync()", + "std::__1::basic_streambuf>::showmanyc()", + "std::__1::basic_streambuf>::xsgetn(char*, long)", + "std::__1::basic_streambuf>::uflow()", + "std::__1::basic_filebuf>::basic_filebuf()", + "std::__1::basic_filebuf>::~basic_filebuf()", + "std::__1::basic_filebuf>::open(char const*, unsigned int)", + "std::__1::basic_filebuf>::close()", + "std::__1::basic_filebuf>::sync()", + "std::__1::basic_istream>::~basic_istream()", + "virtual thunk to std::__1::basic_istream>::~basic_istream()", + "virtual thunk to std::__1::basic_ostream>::~basic_ostream()", + "std::__1::basic_ifstream>::~basic_ifstream()", + "std::__1::ios_base::init(void*)", + "std::__1::basic_istream>::read(char*, long)", + "std::__1::basic_ostream>::~basic_ostream()", + "std::__1::basic_string, std::__1::allocator>::__init(unsigned long, char)", + "std::__1::basic_ostream>::write(char const*, long)", + }; + const char* NoFreeDemanglesStartsWith[] = { + "std::__1::basic_ostream>::operator<<", + "std::__1::ios_base::imbue", + "std::__1::basic_streambuf>::pubimbue", + "std::__1::basic_stringbuf, std::__1::allocator>::__init_buf_ptrs", + "std::__1::basic_stringbuf, std::__1::allocator>::basic_stringbuf", + "std::__1::basic_string, std::__1::allocator>::operator=", + "std::__1::ctype::widen", + "std::__1::basic_streambuf>::sputn", }; // clang-format on @@ -6032,6 +6088,10 @@ llvm::Function *EnzymeLogic::CreateNoFree(RequestContext context, Function *F) { if (NoFreeDemangles.count(demangledName)) return F; + for (auto Name : NoFreeDemanglesStartsWith) + if (startsWith(demangledName, Name)) + return F; + switch (F->getIntrinsicID()) { case Intrinsic::lifetime_start: case Intrinsic::lifetime_end: @@ -6055,6 +6115,17 @@ llvm::Function *EnzymeLogic::CreateNoFree(RequestContext context, Function *F) { << F->getName() << ")\n"; if (context.req) { ss << " at context: " << *context.req; + if (auto CB = dyn_cast(context.req)) { + if (auto F = CB->getCalledFunction()) { + std::string demangleF = llvm::demangle(F->getName().str()); + // replace all '> >' with '>>' + size_t start = 0; + while ((start = demangleF.find("> >", start)) != std::string::npos) { + demangleF.replace(start, 3, ">>"); + } + ss << " (" << demangleF << ")"; + } + } } else { ss << *F << "\n"; } diff --git a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp index d5bf445ee429..c0b54a4eccbb 100644 --- a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp +++ b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp @@ -4507,8 +4507,10 @@ void TypeAnalyzer::visitCallBase(CallBase &call) { return; } - if (startsWith(funcName, "_ZNKSt3__112basic_stringIcNS_11char_traitsIcEENS_" - "9allocatorIcEEE13__get_pointer")) { + if (startsWith(funcName, "_ZNKSt3__112basic_string") || + startsWith(funcName, "_ZNSt3__112basic_string") || + startsWith(funcName, "_ZNSt3__112__hash_table") || + startsWith(funcName, "_ZNKSt3__115basic_stringbuf")) { return; } diff --git a/enzyme/test/Integration/ReverseMode/blas_gemm2.c b/enzyme/test/Integration/ReverseMode/blas_gemm2.c index a417e0418cd9..b06474282634 100644 --- a/enzyme/test/Integration/ReverseMode/blas_gemm2.c +++ b/enzyme/test/Integration/ReverseMode/blas_gemm2.c @@ -8,9 +8,6 @@ // RUN: if [ %llvmver -ge 12 ]; then %clang -O2 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-lapack-copy=0 | %lli - ; fi // RUN: if [ %llvmver -ge 12 ]; then %clang -O3 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-lapack-copy=0 | %lli - ; fi -#include -#include -#include #include "../test_utils.h" #include "../blas_inline.h" diff --git a/enzyme/test/Integration/ReverseMode/customlog1p.c b/enzyme/test/Integration/ReverseMode/customlog1p.c index 95c0697fff77..6f563030f9b3 100644 --- a/enzyme/test/Integration/ReverseMode/customlog1p.c +++ b/enzyme/test/Integration/ReverseMode/customlog1p.c @@ -17,10 +17,6 @@ // RUN: if [ %llvmver -ge 12 ]; then %clang -std=c11 -O2 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-inline=1 -S | %lli - ; fi // RUN: if [ %llvmver -ge 12 ]; then %clang -std=c11 -O3 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-inline=1 -S | %lli - ; fi -#include -#include -#include - #include "../test_utils.h" double __enzyme_autodiff(void*, ...); diff --git a/enzyme/test/Integration/ReverseMode/inactivefn.c b/enzyme/test/Integration/ReverseMode/inactivefn.c index e48edd3486ed..bbd6696a04d9 100644 --- a/enzyme/test/Integration/ReverseMode/inactivefn.c +++ b/enzyme/test/Integration/ReverseMode/inactivefn.c @@ -17,10 +17,6 @@ // RUN: if [ %llvmver -ge 12 ]; then %clang -std=c11 -O2 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-inline=1 -S | %lli - ; fi // RUN: if [ %llvmver -ge 12 ]; then %clang -std=c11 -O3 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-inline=1 -S | %lli - ; fi -#include -#include -#include - #include "../test_utils.h" double __enzyme_autodiff(void*, ...); diff --git a/enzyme/test/Integration/ReverseMode/invsqrt.c b/enzyme/test/Integration/ReverseMode/invsqrt.c index c14ce79ab88d..09ada6757e45 100644 --- a/enzyme/test/Integration/ReverseMode/invsqrt.c +++ b/enzyme/test/Integration/ReverseMode/invsqrt.c @@ -17,13 +17,8 @@ // RUN: if [ %llvmver -ge 12 ]; then %clang -std=c11 -O2 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-inline=1 -S | %lli - ; fi // RUN: if [ %llvmver -ge 12 ]; then %clang -std=c11 -O3 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-inline=1 -S | %lli - ; fi -#include -#include -#include -#include -#include - #include "../test_utils.h" +#include // Fast inverse sqrt // Code taken from https://en.wikipedia.org/wiki/Fast_inverse_square_root @@ -74,10 +69,8 @@ int main(int argc, char *argv[]) { double *A = (double*)malloc(sizeof(double) * n); for(int i=0; i -#include -#include - #include "../test_utils.h" +double pow(double, double); + __attribute__((noinline)) -uint64_t factorial(uint64_t x) { +unsigned long long factorial(unsigned long long x) { if (x == 0) return 1; return x * factorial(x-1); } double my_sin(double x) { double result = 0; - uint64_t N = 12; - for(uint64_t i=0; i<=N; i++) { + unsigned long long N = 12; + for(unsigned long long i=0; i<=N; i++) { if (i % 2 == 0) continue; result += pow(x, i) / factorial(i) * (i % 4 == 1 ? 1 : -1); } @@ -38,14 +36,14 @@ double my_sin(double x) { } -uint64_t __enzyme_iter(uint64_t, uint64_t); +unsigned long long __enzyme_iter(unsigned long long, unsigned long long); double __enzyme_autodiff(void*, double); double my_sin2(double x) { double result = 0; - uint64_t N = __enzyme_iter(12, 1); - for(uint64_t i=0; i<=N; i++) { + unsigned long long N = __enzyme_iter(12, 1); + for(unsigned long long i=0; i<=N; i++) { if (i % 2 == 0) continue; result += pow(x, i) / factorial(i) * (i % 4 == 1 ? 1 : -1); } diff --git a/enzyme/test/Integration/ReverseMode/omp.c b/enzyme/test/Integration/ReverseMode/omp.c index 4d6f0bd164c5..5d4f8ae09970 100644 --- a/enzyme/test/Integration/ReverseMode/omp.c +++ b/enzyme/test/Integration/ReverseMode/omp.c @@ -9,10 +9,6 @@ // RUN: %clang -fopenmp -std=c11 -O2 -fno-vectorize -fno-unroll-loops %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %clang -fopenmp -x ir - -o %s.out && %s.out // RUN: %clang -fopenmp -std=c11 -O3 -fno-vectorize -fno-unroll-loops %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %clang -fopenmp -x ir - -o %s.out && %s.out -#include -#include -#include - #include "../test_utils.h" double __enzyme_autodiff(void*, ...); diff --git a/enzyme/test/Integration/ReverseMode/omp2.c b/enzyme/test/Integration/ReverseMode/omp2.c index 8cdd06545fd1..44f315fe28f9 100644 --- a/enzyme/test/Integration/ReverseMode/omp2.c +++ b/enzyme/test/Integration/ReverseMode/omp2.c @@ -8,10 +8,6 @@ // RUN: %clang -fopenmp -std=c11 -O2 -fno-vectorize -fno-unroll-loops %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %clang -fopenmp -x ir - -o %s.out && %s.out // RUN: %clang -fopenmp -std=c11 -O3 -fno-vectorize -fno-unroll-loops %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %clang -fopenmp -x ir - -o %s.out && %s.out -#include -#include -#include - #include "../test_utils.h" double __enzyme_autodiff(void*, ...); diff --git a/enzyme/test/Integration/ReverseMode/omp3.c b/enzyme/test/Integration/ReverseMode/omp3.c index 59bdca34cf9a..970348a7cedd 100644 --- a/enzyme/test/Integration/ReverseMode/omp3.c +++ b/enzyme/test/Integration/ReverseMode/omp3.c @@ -9,8 +9,6 @@ // RUN: %clang -fopenmp -std=c11 -fno-vectorize -fno-unroll-loops -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %clang -fopenmp -x ir - -o %s.out && %s.out // RUN: %clang -fopenmp -std=c11 -fno-vectorize -fno-unroll-loops -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %clang -fopenmp -x ir - -o %s.out && %s.out -# include -# include #include "../test_utils.h" void msg(double* in, int *len, unsigned int slen) { diff --git a/enzyme/test/Integration/ReverseMode/omp6.c b/enzyme/test/Integration/ReverseMode/omp6.c index 8a6f34fe235f..abd170d07dfd 100644 --- a/enzyme/test/Integration/ReverseMode/omp6.c +++ b/enzyme/test/Integration/ReverseMode/omp6.c @@ -9,10 +9,6 @@ // RUN: %clang -fopenmp -std=c11 -fno-vectorize -fno-unroll-loops -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %clang -fopenmp -x ir - -o %s.out && %s.out // RUN: %clang -fopenmp -std=c11 -fno-vectorize -fno-unroll-loops -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %clang -fopenmp -x ir - -o %s.out && %s.out -# include -# include -#include - #include "../test_utils.h" __attribute__((noinline)) diff --git a/enzyme/test/Integration/ReverseMode/omp_two.c b/enzyme/test/Integration/ReverseMode/omp_two.c index 1f95e93f2520..8d3dac92ca75 100644 --- a/enzyme/test/Integration/ReverseMode/omp_two.c +++ b/enzyme/test/Integration/ReverseMode/omp_two.c @@ -9,10 +9,6 @@ // RUN: %clang -fopenmp -std=c11 -O2 -fno-vectorize -fno-unroll-loops %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %clang -fopenmp -x ir - -o %s.out && %s.out // RUN: %clang -fopenmp -std=c11 -O3 -fno-vectorize -fno-unroll-loops %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %clang -fopenmp -x ir - -o %s.out && %s.out -#include -#include -#include - #include "../test_utils.h" void __enzyme_autodiff(void*, ...); diff --git a/enzyme/test/Integration/ReverseMode/ompbound.c b/enzyme/test/Integration/ReverseMode/ompbound.c index 4d6f0bd164c5..5d4f8ae09970 100644 --- a/enzyme/test/Integration/ReverseMode/ompbound.c +++ b/enzyme/test/Integration/ReverseMode/ompbound.c @@ -9,10 +9,6 @@ // RUN: %clang -fopenmp -std=c11 -O2 -fno-vectorize -fno-unroll-loops %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %clang -fopenmp -x ir - -o %s.out && %s.out // RUN: %clang -fopenmp -std=c11 -O3 -fno-vectorize -fno-unroll-loops %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %clang -fopenmp -x ir - -o %s.out && %s.out -#include -#include -#include - #include "../test_utils.h" double __enzyme_autodiff(void*, ...); diff --git a/enzyme/test/Integration/test_utils.h b/enzyme/test/Integration/test_utils.h index 881aa5f8f703..268922598f54 100644 --- a/enzyme/test/Integration/test_utils.h +++ b/enzyme/test/Integration/test_utils.h @@ -1,9 +1,11 @@ -#ifdef __cplusplus +#if defined(__cplusplus) || defined(__APPLE__) #include #include #include +#include +#include #else struct _IO_FILE; extern struct _IO_FILE* stderr; From 1285a39808c276ddedbe9a6fc491e7eeda64c488 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 11 Feb 2024 21:51:25 -0500 Subject: [PATCH 035/106] Ensure custom derivative functions aren't deleted (#1697) --- enzyme/Enzyme/EnzymeLogic.cpp | 22 ++++++++++ enzyme/Enzyme/PreserveNVVM.cpp | 40 +++++++++++-------- .../test/Enzyme/ReverseMode/custom-sret3.ll | 2 +- 3 files changed, 47 insertions(+), 17 deletions(-) diff --git a/enzyme/Enzyme/EnzymeLogic.cpp b/enzyme/Enzyme/EnzymeLogic.cpp index bc4856566909..ebdbb937733d 100644 --- a/enzyme/Enzyme/EnzymeLogic.cpp +++ b/enzyme/Enzyme/EnzymeLogic.cpp @@ -4490,7 +4490,29 @@ Function *EnzymeLogic::CreateForwardDiff( "unknown derivative for function -- metadata incorrect"); } auto md2 = cast(md); + assert(md2); assert(md2->getNumOperands() == 1); + if (!md2->getOperand(0)) { + std::string s; + llvm::raw_string_ostream ss(s); + ss << "Failed to use custom forward mode derivative for " + << todiff->getName() << "\n"; + ss << " found metadata (but null op0) " << *md2 << "\n"; + EmitFailure("NoDerivative", context.req->getDebugLoc(), context.req, + ss.str()); + return ForwardCachedFunctions[tup] = nullptr; + } + if (!isa(md2->getOperand(0))) { + std::string s; + llvm::raw_string_ostream ss(s); + ss << "Failed to use custom forward mode derivative for " + << todiff->getName() << "\n"; + ss << " found metadata (but not constantasmetadata) " + << *md2->getOperand(0) << "\n"; + EmitFailure("NoDerivative", context.req->getDebugLoc(), context.req, + ss.str()); + return ForwardCachedFunctions[tup] = nullptr; + } auto gvemd = cast(md2->getOperand(0)); auto foundcalled = cast(gvemd->getValue()); diff --git a/enzyme/Enzyme/PreserveNVVM.cpp b/enzyme/Enzyme/PreserveNVVM.cpp index 73d13a009abb..a7df39596f0a 100644 --- a/enzyme/Enzyme/PreserveNVVM.cpp +++ b/enzyme/Enzyme/PreserveNVVM.cpp @@ -56,6 +56,25 @@ using namespace llvm; #define addAttribute addAttributeAtIndex #endif +//! Returns whether changed. +bool preserveLinkage(bool Begin, Function &F, bool Inlining = true) { + if (Begin && !F.hasFnAttribute("prev_fixup")) { + F.addFnAttr("prev_fixup"); + if (Inlining) { + if (F.hasFnAttribute(Attribute::AlwaysInline)) + F.addFnAttr("prev_always_inline"); + F.removeFnAttr(Attribute::AlwaysInline); + if (F.hasFnAttribute(Attribute::NoInline)) + F.addFnAttr("prev_no_inline"); + F.addFnAttr(Attribute::NoInline); + } + F.addFnAttr("prev_linkage", std::to_string(F.getLinkage())); + F.setLinkage(Function::LinkageTypes::ExternalLinkage); + return true; + } + return false; +} + template static void handleCustomDerivative(llvm::Module &M, llvm::GlobalVariable &g, @@ -237,26 +256,31 @@ handleCustomDerivative(llvm::Module &M, llvm::GlobalVariable &g, Fs[fn] = NewF; } + preserveLinkage(true, *Fs[1], false); Fs[0]->setMetadata( "enzyme_augment", llvm::MDTuple::get(Fs[0]->getContext(), {llvm::ValueAsMetadata::get(Fs[1])})); + preserveLinkage(true, *Fs[2], false); Fs[0]->setMetadata( "enzyme_gradient", llvm::MDTuple::get(Fs[0]->getContext(), {llvm::ValueAsMetadata::get(Fs[2])})); } else if (Mode == DerivativeMode::ForwardMode) { assert(numargs == 2); + preserveLinkage(true, *Fs[1], false); Fs[0]->setMetadata( "enzyme_derivative", llvm::MDTuple::get(Fs[0]->getContext(), {llvm::ValueAsMetadata::get(Fs[1])})); } else if (Mode == DerivativeMode::ForwardModeSplit) { assert(numargs == 3); + preserveLinkage(true, *Fs[1], false); Fs[0]->setMetadata( "enzyme_augment", llvm::MDTuple::get(Fs[0]->getContext(), {llvm::ValueAsMetadata::get(Fs[1])})); + preserveLinkage(true, *Fs[2], false); Fs[0]->setMetadata( "enzyme_splitderivative", llvm::MDTuple::get(Fs[0]->getContext(), @@ -282,22 +306,6 @@ handleCustomDerivative(llvm::Module &M, llvm::GlobalVariable &g, } globalsToErase.push_back(&g); } -//! Returns whether changed. -bool preserveLinkage(bool Begin, Function &F) { - if (Begin && !F.hasFnAttribute("prev_fixup")) { - F.addFnAttr("prev_fixup"); - if (F.hasFnAttribute(Attribute::AlwaysInline)) - F.addFnAttr("prev_always_inline"); - if (F.hasFnAttribute(Attribute::NoInline)) - F.addFnAttr("prev_no_inline"); - F.addFnAttr("prev_linkage", std::to_string(F.getLinkage())); - F.setLinkage(Function::LinkageTypes::ExternalLinkage); - F.addFnAttr(Attribute::NoInline); - F.removeFnAttr(Attribute::AlwaysInline); - return true; - } - return false; -} bool preserveNVVM(bool Begin, Function &F) { bool changed = false; diff --git a/enzyme/test/Enzyme/ReverseMode/custom-sret3.ll b/enzyme/test/Enzyme/ReverseMode/custom-sret3.ll index 195470291e2a..da82145bdbd8 100644 --- a/enzyme/test/Enzyme/ReverseMode/custom-sret3.ll +++ b/enzyme/test/Enzyme/ReverseMode/custom-sret3.ll @@ -118,7 +118,7 @@ attributes #4 = { nounwind } !10 = !{!11, !11, i64 0} !11 = !{!"int", !5, i64 0} -; CHECK: define internal void @fixbyval_myblas_cdot_rev(%struct.complex* %arg0, %struct.complex* %arg1, %struct.complex* %arg2, %struct.complex* %arg3, i32 %arg4, i32 %arg5, %struct.complex %arg6, i8* %arg7) +; CHECK: define dso_local void @fixbyval_myblas_cdot_rev(%struct.complex* %arg0, %struct.complex* %arg1, %struct.complex* %arg2, %struct.complex* %arg3, i32 %arg4, i32 %arg5, %struct.complex %arg6, i8* %arg7) ; CHECK-NEXT: entry: ; CHECK-NEXT: %0 = alloca %struct.complex ; CHECK-NEXT: store %struct.complex %arg6, %struct.complex* %0 From 5e453755a1affc6da3d40c4ece9dcebfe4a319ab Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 11 Feb 2024 22:22:57 -0500 Subject: [PATCH 036/106] Fix memcpy of size 1 [part 2] (#1701) --- enzyme/Enzyme/ActivityAnalysis.cpp | 19 ++++++++----------- enzyme/Enzyme/AdjointGenerator.h | 8 ++++++++ 2 files changed, 16 insertions(+), 11 deletions(-) diff --git a/enzyme/Enzyme/ActivityAnalysis.cpp b/enzyme/Enzyme/ActivityAnalysis.cpp index ef6c7a14f811..79a56d885507 100644 --- a/enzyme/Enzyme/ActivityAnalysis.cpp +++ b/enzyme/Enzyme/ActivityAnalysis.cpp @@ -432,16 +432,16 @@ const char *DemangledKnownInactiveFunctionsStartingWith[] = { if (auto iasm = dyn_cast(CI.getCalledOperand())) { if (StringRef(iasm->getAsmString()).contains("exit") || StringRef(iasm->getAsmString()).contains("cpuid")) - return false; + return true; } - auto F = getFunctionFromCall(&CI); - - if (F == nullptr) - return false; - - if (F->hasFnAttribute("enzyme_inactive")) { - return true; + if (auto F = getFunctionFromCall(&CI)) { + if (F->hasFnAttribute("enzyme_inactive")) { + return true; + } + if (KnownInactiveIntrinsics.count(F->getIntrinsicID())) { + return true; + } } auto Name = getFuncNameFromCall(&CI); @@ -478,9 +478,6 @@ const char *DemangledKnownInactiveFunctionsStartingWith[] = { return true; } - if (KnownInactiveIntrinsics.count(F->getIntrinsicID())) { - return true; - } // Copies of size 1 are inactive [cannot move differentiable data in one byte] if (auto MTI = dyn_cast(&CI)) { if (auto sz = dyn_cast(MTI->getOperand(2))) { diff --git a/enzyme/Enzyme/AdjointGenerator.h b/enzyme/Enzyme/AdjointGenerator.h index 60902eafd2a1..c94f03b4bb0b 100644 --- a/enzyme/Enzyme/AdjointGenerator.h +++ b/enzyme/Enzyme/AdjointGenerator.h @@ -3275,6 +3275,14 @@ class AdjointGenerator : public llvm::InstVisitor { return; } + // memcpy of size 1 cannot move differentiable data [single byte copy] + if (auto ci = dyn_cast(new_size)) { + if (ci->getValue() == 1) { + eraseIfUnused(MTI); + return; + } + } + // copying into nullptr is invalid (not sure why it exists here), but we // shouldn't do it in reverse pass or shadow if (isa(orig_dst) || From 89d20faf5beac95fc319d82d75dc9d39ce22ac3f Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 11 Feb 2024 22:51:30 -0500 Subject: [PATCH 037/106] Fix asan error and add mlir alloc/free fns (#1700) * Fix asan error and add mlir alloc/free fns * fix --- enzyme/Enzyme/DifferentialUseAnalysis.cpp | 28 ++--- enzyme/Enzyme/GradientUtils.cpp | 5 +- enzyme/Enzyme/LibraryFuncs.h | 4 + enzyme/test/Enzyme/ReverseMode/mlirmincut.ll | 107 +++++++++++++++++++ 4 files changed, 131 insertions(+), 13 deletions(-) create mode 100644 enzyme/test/Enzyme/ReverseMode/mlirmincut.ll diff --git a/enzyme/Enzyme/DifferentialUseAnalysis.cpp b/enzyme/Enzyme/DifferentialUseAnalysis.cpp index d20ec4eefbc4..f4eb07a01602 100644 --- a/enzyme/Enzyme/DifferentialUseAnalysis.cpp +++ b/enzyme/Enzyme/DifferentialUseAnalysis.cpp @@ -849,7 +849,9 @@ void DifferentialUseAnalysis::minCut(const DataLayout &DL, LoopInfo &OrigLI, assert(pair.first.outgoing == 0 && N.outgoing == 1); assert(pair.first.V == N.V); MinReq.insert(N.V); - todo.push_back(N.V); + if (Orig.find(Node(N.V, true)) != Orig.end()) { + todo.push_back(N.V); + } } } } @@ -860,20 +862,20 @@ void DifferentialUseAnalysis::minCut(const DataLayout &DL, LoopInfo &OrigLI, auto V = todo.front(); todo.pop_front(); auto found = Orig.find(Node(V, true)); - if (found->second.size() == 1 && !Required.count(V)) { + assert(found != Orig.end()); + const auto &mp = found->second; + if (mp.size() == 1 && !Required.count(V)) { bool potentiallyRecursive = - isa((*found->second.begin()).V) && - OrigLI.isLoopHeader( - cast((*found->second.begin()).V)->getParent()); + isa((*mp.begin()).V) && + OrigLI.isLoopHeader(cast((*mp.begin()).V)->getParent()); int moreOuterLoop = cmpLoopNest( OrigLI.getLoopFor(cast(V)->getParent()), - OrigLI.getLoopFor( - cast(((*found->second.begin()).V))->getParent())); + OrigLI.getLoopFor(cast(((*mp.begin()).V))->getParent())); if (potentiallyRecursive) continue; if (moreOuterLoop == -1) continue; - if (auto ASC = dyn_cast((*found->second.begin()).V)) { + if (auto ASC = dyn_cast((*mp.begin()).V)) { if (ASC->getDestAddressSpace() == 11 || ASC->getDestAddressSpace() == 13) continue; @@ -882,7 +884,7 @@ void DifferentialUseAnalysis::minCut(const DataLayout &DL, LoopInfo &OrigLI, } // If an allocation call, we cannot cache any "capturing" users if (isAllocationCall(V, TLI)) { - auto next = (*found->second.begin()).V; + auto next = (*mp.begin()).V; bool noncapture = false; if (isa(next)) { noncapture = true; @@ -918,10 +920,12 @@ void DifferentialUseAnalysis::minCut(const DataLayout &DL, LoopInfo &OrigLI, if (moreOuterLoop == 1 || (moreOuterLoop == 0 && DL.getTypeSizeInBits(V->getType()) >= - DL.getTypeSizeInBits((*found->second.begin()).V->getType()))) { + DL.getTypeSizeInBits((*mp.begin()).V->getType()))) { MinReq.remove(V); - MinReq.insert((*found->second.begin()).V); - todo.push_back((*found->second.begin()).V); + auto nnode = (*mp.begin()).V; + MinReq.insert(nnode); + if (Orig.find(Node(nnode, true)) != Orig.end()) + todo.push_back(nnode); } } } diff --git a/enzyme/Enzyme/GradientUtils.cpp b/enzyme/Enzyme/GradientUtils.cpp index 0fd01c83a05e..b1d8161576f2 100644 --- a/enzyme/Enzyme/GradientUtils.cpp +++ b/enzyme/Enzyme/GradientUtils.cpp @@ -9162,7 +9162,8 @@ llvm::CallInst *freeKnownAllocation(llvm::IRBuilder<> &builder, tofree = builder.CreateIntToPtr(tofree, getInt8PtrTy(tofree->getContext())); llvm::LibFunc libfunc; - if (allocationfn == "calloc" || allocationfn == "malloc") { + if (allocationfn == "calloc" || allocationfn == "malloc" || + allocationfn == "_mlir_memref_to_llvm_alloc") { libfunc = LibFunc_malloc; } else { bool res = TLI.getLibFunc(allocationfn, libfunc); @@ -9232,6 +9233,8 @@ llvm::CallInst *freeKnownAllocation(llvm::IRBuilder<> &builder, if (freename != "free") llvm_unreachable("illegal free"); } + if (allocationfn == "_mlir_memref_to_llvm_alloc") + freename = "_mlir_memref_to_llvm_free"; Type *VoidTy = Type::getVoidTy(tofree->getContext()); Type *IntPtrTy = getInt8PtrTy(tofree->getContext()); diff --git a/enzyme/Enzyme/LibraryFuncs.h b/enzyme/Enzyme/LibraryFuncs.h index 1a8ebce38eee..d18ecc346802 100644 --- a/enzyme/Enzyme/LibraryFuncs.h +++ b/enzyme/Enzyme/LibraryFuncs.h @@ -49,6 +49,8 @@ static inline bool isAllocationFunction(const llvm::StringRef name, return true; if (name == "calloc" || name == "malloc") return true; + if (name == "_mlir_memref_to_llvm_alloc") + return true; if (name == "swift_allocObject") return true; if (name == "__rust_alloc" || name == "__rust_alloc_zeroed") @@ -123,6 +125,8 @@ static inline bool isDeallocationFunction(const llvm::StringRef name, if (!TLI.getLibFunc(name, libfunc)) { if (name == "free") return true; + if (name == "_mlir_memref_to_llvm_free") + return true; if (name == "__rust_dealloc") return true; if (name == "swift_release") diff --git a/enzyme/test/Enzyme/ReverseMode/mlirmincut.ll b/enzyme/test/Enzyme/ReverseMode/mlirmincut.ll new file mode 100644 index 000000000000..6f214ddb2cdc --- /dev/null +++ b/enzyme/test/Enzyme/ReverseMode/mlirmincut.ll @@ -0,0 +1,107 @@ +; RUN: if [ %llvmver -eq 15 ]; then %opt < %s %newLoadEnzyme -enzyme-preopt=false -passes="enzyme,function(mem2reg,instsimplify,%simplifycfg)" -S | FileCheck %s; fi + +declare void @__enzyme_autodiff0(...) local_unnamed_addr + +declare void @_mlir_memref_to_llvm_free(ptr) + +declare ptr @_mlir_memref_to_llvm_alloc(i64) + +define void @jit_compiled(ptr %a) { + tail call void (...) @__enzyme_autodiff0(ptr nonnull @f, metadata !"enzyme_const", ptr %a, ptr %a, ptr %a, i64 0, metadata !"enzyme_const", ptr %a, metadata !"enzyme_const", ptr %a, i64 0, metadata !"enzyme_const", i64 1, metadata !"enzyme_const", ptr %a, metadata !"enzyme_dupnoneed", ptr %a, ptr %a, i64 0) + ret void +} + +define void @f(ptr %arg, ptr %arg1, i64 %arg2, ptr %arg3, ptr %arg4, i64 %arg5, i64 %arg6, ptr nocapture readnone %arg7, ptr nocapture writeonly %arg8, i64 %arg9) { + %.idx = shl i64 %arg6, 3 + %i = tail call ptr @_mlir_memref_to_llvm_alloc(i64 %.idx) + %i10 = load double, ptr %arg4, align 8 + %i11 = fcmp ogt double %i10, 1.500000e+00 + %i12 = load double, ptr %arg1, align 8 + %i13 = fmul double %i12, 2.000000e+00 + %storemerge = select i1 %i11, double %i13, double %i12 + store double %storemerge, ptr %i, align 8 + %i14 = alloca { ptr, ptr, i64 }, align 8 + store ptr %arg, ptr %i14, align 8 + %.repack2 = getelementptr inbounds { ptr, ptr, i64 }, ptr %i14, i64 0, i32 1 + store ptr %arg1, ptr %.repack2, align 8 + %.repack4 = getelementptr inbounds { ptr, ptr, i64 }, ptr %i14, i64 0, i32 2 + store i64 %arg2, ptr %.repack4, align 8 + %i15 = alloca { ptr, ptr, i64 }, align 8 + store ptr %arg3, ptr %i15, align 8 + %.repack6 = getelementptr inbounds { ptr, ptr, i64 }, ptr %i15, i64 0, i32 1 + store ptr %arg4, ptr %.repack6, align 8 + %.repack8 = getelementptr inbounds { ptr, ptr, i64 }, ptr %i15, i64 0, i32 2 + store i64 %arg5, ptr %.repack8, align 8 + %i16 = alloca { ptr, ptr, i64, [1 x i64], [1 x i64] }, align 8 + store ptr %i, ptr %i16, align 8 + %.repack10 = getelementptr inbounds { ptr, ptr, i64, [1 x i64], [1 x i64] }, ptr %i16, i64 0, i32 1 + store ptr %i, ptr %.repack10, align 8 + %.repack12 = getelementptr inbounds { ptr, ptr, i64, [1 x i64], [1 x i64] }, ptr %i16, i64 0, i32 2 + store i64 0, ptr %.repack12, align 8 + %.repack14 = getelementptr inbounds { ptr, ptr, i64, [1 x i64], [1 x i64] }, ptr %i16, i64 0, i32 3 + store i64 %arg6, ptr %.repack14, align 8 + %.repack16 = getelementptr inbounds { ptr, ptr, i64, [1 x i64], [1 x i64] }, ptr %i16, i64 0, i32 4 + store i64 1, ptr %.repack16, align 8 + %i17 = tail call ptr @_mlir_memref_to_llvm_alloc(i64 8) + %i18 = alloca { ptr, ptr, i64 }, align 8 + store ptr %i17, ptr %i18, align 8 + %.repack18 = getelementptr inbounds { ptr, ptr, i64 }, ptr %i18, i64 0, i32 1 + store ptr %i17, ptr %.repack18, align 8 + %.repack20 = getelementptr inbounds { ptr, ptr, i64 }, ptr %i18, i64 0, i32 2 + store i64 0, ptr %.repack20, align 8 + %i19 = load double, ptr %i17, align 8 + store double %i19, ptr %arg8, align 8 + ret void +} + +; CHECK: define internal void @diffef(ptr %arg, ptr %arg1, ptr %"arg1'", i64 %arg2, ptr %arg3, ptr %arg4, i64 %arg5, i64 %arg6, ptr nocapture readnone %arg7, ptr nocapture writeonly %arg8, ptr nocapture %"arg8'", i64 %arg9) +; CHECK-NEXT: invert: +; CHECK-NEXT: %.idx = shl i64 %arg6, 3 +; CHECK-NEXT: %i = tail call ptr @_mlir_memref_to_llvm_alloc(i64 %.idx) +; CHECK-NEXT: %i10 = load double, ptr %arg4, align 8 +; CHECK-NEXT: %i11 = fcmp ogt double %i10, 1.500000e+00 +; CHECK-NEXT: %i12 = load double, ptr %arg1, align 8 +; CHECK-NEXT: %i13 = fmul double %i12, 2.000000e+00 +; CHECK-NEXT: %storemerge = select i1 %i11, double %i13, double %i12 +; CHECK-NEXT: store double %storemerge, ptr %i, align 8 +; CHECK-NEXT: %i14 = alloca { ptr, ptr, i64 }, align 8 +; CHECK-NEXT: store ptr %arg, ptr %i14, align 8 +; CHECK-NEXT: %.repack2 = getelementptr inbounds { ptr, ptr, i64 }, ptr %i14, i64 0, i32 1 +; CHECK-NEXT: store ptr %arg1, ptr %.repack2, align 8 +; CHECK-NEXT: %.repack4 = getelementptr inbounds { ptr, ptr, i64 }, ptr %i14, i64 0, i32 2 +; CHECK-NEXT: store i64 %arg2, ptr %.repack4, align 8 +; CHECK-NEXT: %i15 = alloca { ptr, ptr, i64 }, align 8 +; CHECK-NEXT: store ptr %arg3, ptr %i15, align 8 +; CHECK-NEXT: %.repack6 = getelementptr inbounds { ptr, ptr, i64 }, ptr %i15, i64 0, i32 1 +; CHECK-NEXT: store ptr %arg4, ptr %.repack6, align 8 +; CHECK-NEXT: %.repack8 = getelementptr inbounds { ptr, ptr, i64 }, ptr %i15, i64 0, i32 2 +; CHECK-NEXT: store i64 %arg5, ptr %.repack8, align 8 +; CHECK-NEXT: %i16 = alloca { ptr, ptr, i64, [1 x i64], [1 x i64] }, align 8 +; CHECK-NEXT: store ptr %i, ptr %i16, align 8 +; CHECK-NEXT: %.repack10 = getelementptr inbounds { ptr, ptr, i64, [1 x i64], [1 x i64] }, ptr %i16, i64 0, i32 1 +; CHECK-NEXT: store ptr %i, ptr %.repack10, align 8 +; CHECK-NEXT: %.repack12 = getelementptr inbounds { ptr, ptr, i64, [1 x i64], [1 x i64] }, ptr %i16, i64 0, i32 2 +; CHECK-NEXT: store i64 0, ptr %.repack12, align 8 +; CHECK-NEXT: %.repack14 = getelementptr inbounds { ptr, ptr, i64, [1 x i64], [1 x i64] }, ptr %i16, i64 0, i32 3 +; CHECK-NEXT: store i64 %arg6, ptr %.repack14, align 8 +; CHECK-NEXT: %.repack16 = getelementptr inbounds { ptr, ptr, i64, [1 x i64], [1 x i64] }, ptr %i16, i64 0, i32 4 +; CHECK-NEXT: store i64 1, ptr %.repack16, align 8 +; CHECK-NEXT: %"i17'mi" = tail call noalias nonnull ptr @_mlir_memref_to_llvm_alloc(i64 8) +; CHECK-NEXT: call void @llvm.memset.p0.i64(ptr nonnull dereferenceable(8) dereferenceable_or_null(8) %"i17'mi", i8 0, i64 8, i1 false) +; CHECK-NEXT: %i17 = tail call ptr @_mlir_memref_to_llvm_alloc(i64 8) +; CHECK-NEXT: %i18 = alloca { ptr, ptr, i64 }, align 8 +; CHECK-NEXT: store ptr %i17, ptr %i18, align 8 +; CHECK-NEXT: %.repack18 = getelementptr inbounds { ptr, ptr, i64 }, ptr %i18, i64 0, i32 1 +; CHECK-NEXT: store ptr %i17, ptr %.repack18, align 8 +; CHECK-NEXT: %.repack20 = getelementptr inbounds { ptr, ptr, i64 }, ptr %i18, i64 0, i32 2 +; CHECK-NEXT: store i64 0, ptr %.repack20, align 8 +; CHECK-NEXT: %0 = load double, ptr %"arg8'", align 8 +; CHECK-NEXT: store double 0.000000e+00, ptr %"arg8'", align 8 +; CHECK-NEXT: %1 = load double, ptr %"i17'mi", align 8 +; CHECK-NEXT: %2 = fadd fast double %1, %0 +; CHECK-NEXT: store double %2, ptr %"i17'mi", align 8 +; CHECK-NEXT: call void @_mlir_memref_to_llvm_free(ptr nonnull %"i17'mi") +; CHECK-NEXT: call void @_mlir_memref_to_llvm_free(ptr %i17) +; CHECK-NEXT: call void @_mlir_memref_to_llvm_free(ptr %i) +; CHECK-NEXT: ret void +; CHECK-NEXT: } From 32c3786ceff49a285c6e1d7e548687b0203fd57c Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 11 Feb 2024 23:30:54 -0500 Subject: [PATCH 038/106] Erase atomic on err (#1702) --- enzyme/Enzyme/AdjointGenerator.h | 3 +++ 1 file changed, 3 insertions(+) diff --git a/enzyme/Enzyme/AdjointGenerator.h b/enzyme/Enzyme/AdjointGenerator.h index c94f03b4bb0b..ae2e72aceb83 100644 --- a/enzyme/Enzyme/AdjointGenerator.h +++ b/enzyme/Enzyme/AdjointGenerator.h @@ -898,6 +898,9 @@ class AdjointGenerator : public llvm::InstVisitor { setDiffe(&I, Constant::getNullValue(gutils->getShadowType(I.getType())), BuilderZ); } + gutils->replaceAWithB(gutils->getNewFromOriginal(&I), + UndefValue::get(I.getType())); + eraseIfUnused(I, /*erase*/ true, /*check*/ false); return; } From 012da6b546597bec8f345129857939e156564c0e Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 11 Feb 2024 23:48:40 -0500 Subject: [PATCH 039/106] Preserve parmtype (#1703) --- enzyme/Enzyme/FunctionUtils.cpp | 92 +++++++++++++++++++-------------- 1 file changed, 54 insertions(+), 38 deletions(-) diff --git a/enzyme/Enzyme/FunctionUtils.cpp b/enzyme/Enzyme/FunctionUtils.cpp index 71ceb9688eea..6bc443276b4b 100644 --- a/enzyme/Enzyme/FunctionUtils.cpp +++ b/enzyme/Enzyme/FunctionUtils.cpp @@ -2194,10 +2194,14 @@ Function *PreProcessCache::CloneFunctionWithReturns( // Attribute::ElementType)); #endif } - for (auto ty : PrimalParamAttrsToPreserve) - if (F->getAttributes().hasParamAttr(ii, ty)) { - auto attr = F->getAttributes().getParamAttr(ii, ty); - NewF->addParamAttr(jj, attr); + for (auto attr : {"enzymejl_parmtype", "enzymejl_parmtype_ref"}) + if (F->getAttributes().hasParamAttr(ii, attr)) { + NewF->addParamAttr(jj, F->getAttributes().getParamAttr(ii, attr)); + for (auto ty : PrimalParamAttrsToPreserve) + if (F->getAttributes().hasParamAttr(ii, ty)) { + auto attr = F->getAttributes().getParamAttr(ii, ty); + NewF->addParamAttr(jj, attr); + } } if (constant_args[ii] == DIFFE_TYPE::CONSTANT) { if (!i->hasByValAttr()) @@ -2212,8 +2216,8 @@ Function *PreProcessCache::CloneFunctionWithReturns( << " nonconstant arg " << *j << "\n"; } - // Always remove nonnull/noundef since the caller may choose to pass undef - // as an arg if provably it will not be used in the reverse pass + // Always remove nonnull/noundef since the caller may choose to pass + // undef as an arg if provably it will not be used in the reverse pass if (constant_args[ii] == DIFFE_TYPE::DUP_NONEED || mode == DerivativeMode::ReverseModeGradient) { if (F->hasParamAttribute(ii, Attribute::NonNull)) { @@ -2236,6 +2240,13 @@ Function *PreProcessCache::CloneFunctionWithReturns( NewF->addParamAttr(jj + 1, attr); } + for (auto attr : {"enzymejl_parmtype", "enzymejl_parmtype_ref"}) + if (F->getAttributes().hasParamAttr(ii, attr)) { + if (width == 1) + NewF->addParamAttr(jj + 1, + F->getAttributes().getParamAttr(ii, attr)); + } + if (F->getAttributes().hasParamAttr(ii, "enzymejl_returnRoots")) { if (width == 1) { NewF->addParamAttr(jj + 1, F->getAttributes().getParamAttr( @@ -2247,7 +2258,8 @@ Function *PreProcessCache::CloneFunctionWithReturns( #if LLVM_VERSION_MAJOR >= 13 // TODO // NewF->addParamAttr(jj + 1, - // F->getParamAttribute(ii, Attribute::ElementType)); + // F->getParamAttribute(ii, + // Attribute::ElementType)); #endif } @@ -2266,7 +2278,8 @@ Function *PreProcessCache::CloneFunctionWithReturns( // jj + 1, // Attribute::get(F->getContext(), // Attribute::AttrKind::ElementType, - // F->getParamAttribute(ii, Attribute::StructRet) + // F->getParamAttribute(ii, + // Attribute::StructRet) // .getValueAsType())); #endif } else { @@ -2283,7 +2296,8 @@ Function *PreProcessCache::CloneFunctionWithReturns( // jj + 1, // Attribute::get(F->getContext(), // Attribute::AttrKind::ElementType, - // F->getParamAttribute(ii, Attribute::StructRet) + // F->getParamAttribute(ii, + // Attribute::StructRet) // .getValueAsType())); #endif } @@ -3654,14 +3668,13 @@ std::optional fixSparse_inner(Instruction *cur, llvm::Function &F, } /* - // add (ext (x == expr )), ( ext (x == expr + 1)) -> -expr == c2 ) and c1 != - c2 -> false if (cur->getOpcode() == Instruction::Add) for (int j=0; j<2; j++) - if (auto c0 = dyn_cast(cur->getOperand(j))) - if (auto cmp0 = dyn_cast(c0->getOperand(0))) - if (auto c1 = dyn_cast(cur->getOperand(1-j))) - if (auto cmp1 = dyn_cast(c0->getOperand(0))) - if (cmp0->getPredicate() == ICmpInst::ICMP_EQ && - cmp1->getPredicate() == ICmpInst::ICMP_EQ) + // add (ext (x == expr )), ( ext (x == expr + 1)) -> -expr == c2 ) and c1 + != c2 -> false if (cur->getOpcode() == Instruction::Add) for (int j=0; j<2; + j++) if (auto c0 = dyn_cast(cur->getOperand(j))) if (auto cmp0 = + dyn_cast(c0->getOperand(0))) if (auto c1 = + dyn_cast(cur->getOperand(1-j))) if (auto cmp1 = + dyn_cast(c0->getOperand(0))) if (cmp0->getPredicate() == + ICmpInst::ICMP_EQ && cmp1->getPredicate() == ICmpInst::ICMP_EQ) { for (size_t i0 = 0; i0 < 2; i0++) for (size_t i1 = 0; i1 < 2; i1++) @@ -3694,7 +3707,8 @@ std::optional fixSparse_inner(Instruction *cur, llvm::Function &F, if (auto C = dyn_cast(fcmp->getOperand(i))) { if (C->isZero()) { // (a1*a2*...an) == 0 -> (a1 == 0) || (a2 == 0) || ... (a2 == 0) - // (a1*a2*...an) != 0 -> ![ (a1 == 0) || (a2 == 0) || ... (a2 == 0) + // (a1*a2*...an) != 0 -> ![ (a1 == 0) || (a2 == 0) || ... (a2 == + // 0) // ] if (auto P = isProduct(fcmp->getOperand(1 - i))) { Value *res = nullptr; @@ -3885,8 +3899,8 @@ std::optional fixSparse_inner(Instruction *cur, llvm::Function &F, cur->isExact()) if (auto C2 = dyn_cast(cur->getOperand(1))) if (auto mul = dyn_cast(cur->getOperand(0))) { - // (lshr exact (mul a, C1), C2), C -> mul a, (lhsr exact C1, C2) if C2 - // divides C1 + // (lshr exact (mul a, C1), C2), C -> mul a, (lhsr exact C1, C2) if + // C2 divides C1 if (mul->getOpcode() == Instruction::Mul) for (int i0 = 0; i0 < 2; i0++) if (auto C1 = dyn_cast(mul->getOperand(i0))) { @@ -3913,7 +3927,8 @@ std::optional fixSparse_inner(Instruction *cur, llvm::Function &F, return "IMulDivConst"; } } - // (lshr exact (add a, C1), C2), C -> add a, (lhsr exact C1, C2) if C2 + // (lshr exact (add a, C1), C2), C -> add a, (lhsr exact C1, C2) if + // C2 if (mul->getOpcode() == Instruction::Add) for (int i0 = 0; i0 < 2; i0++) if (auto C1 = dyn_cast(mul->getOperand(i0))) { @@ -4123,8 +4138,8 @@ std::optional fixSparse_inner(Instruction *cur, llvm::Function &F, // (a * b) != (c * b) -> (a != c) && b != 0 // auto S1 = SE.getSCEV(cur->getOperand(0)); // auto S2 = SE.getSCEV(cur->getOperand(1)); - // llvm::errs() <<" attempting push: " << *cur << " S1: " << *S1 << " S2: - // " << *S2 << " and " << *cur->getOperand(0) << " " << + // llvm::errs() <<" attempting push: " << *cur << " S1: " << *S1 << " + // S2: " << *S2 << " and " << *cur->getOperand(0) << " " << // *cur->getOperand(1) << "\n"; if (auto mul1 = dyn_cast(cur->getOperand(0))) if (auto mul2 = dyn_cast(cur->getOperand(1))) { @@ -4414,10 +4429,10 @@ std::optional fixSparse_inner(Instruction *cur, llvm::Function &F, for (int j=0; j<2; j++) if (auto CI = dyn_cast(SI->getOperand(1+j))) if (CI->isZero()) { - auto tval = (j == 0) ? CI : pushcse(B.CreateMul(SI->getTrueValue(), - cur->getOperand(1-i), "tval." + cur->getName(), cur->hasNoUnsignedWrap(), - cur->hasNoSignedWrap())); - auto fval = (j == 1) ? CI : pushcse(B.CreateMul(SI->getFalseValue(), + auto tval = (j == 0) ? CI : + pushcse(B.CreateMul(SI->getTrueValue(), cur->getOperand(1-i), "tval." + + cur->getName(), cur->hasNoUnsignedWrap(), cur->hasNoSignedWrap())); auto + fval = (j == 1) ? CI : pushcse(B.CreateMul(SI->getFalseValue(), cur->getOperand(1-i), "fval." + cur->getName(), cur->hasNoUnsignedWrap(), cur->hasNoSignedWrap())); @@ -4488,9 +4503,10 @@ std::optional fixSparse_inner(Instruction *cur, llvm::Function &F, (and1->getType()->isIntegerTy(1) && and2->getType()->isIntegerTy(1) && and1->getOpcode() == Instruction::And && and2->getOpcode() == Instruction::And) { bool done = false; for (int i1=0; i1<2; i1++) for (int - i2=0; i2<2; i2++) if (and1->getOperand(i1) == and2->getOperand(i2)) { auto c1 - = and1->getOperand(i1); auto x = and1->getOperand(1-i1); x = - pushcse(B.CreateZExt(x, inst1->getType())); auto y = and2->getOperand(1-i2); + i2=0; i2<2; i2++) if (and1->getOperand(i1) == and2->getOperand(i2)) { auto + c1 = and1->getOperand(i1); auto x = and1->getOperand(1-i1); x = + pushcse(B.CreateZExt(x, inst1->getType())); auto y = + and2->getOperand(1-i2); y = pushcse(B.CreateZExt(y, inst2->getType())); @@ -5181,8 +5197,8 @@ std::optional fixSparse_inner(Instruction *cur, llvm::Function &F, } } - // fmul a, (sitofp (imul c:const, b)) -> fmul (fmul (a, (sitofp c))), (sitofp - // b) + // fmul a, (sitofp (imul c:const, b)) -> fmul (fmul (a, (sitofp c))), + // (sitofp b) if (cur->getOpcode() == Instruction::FMul && cur->isFast()) { for (int i = 0; i < 2; i++) @@ -5407,8 +5423,8 @@ std::optional fixSparse_inner(Instruction *cur, llvm::Function &F, } Value *sel = pushcse( - B.CreateSelect(condition, ConstantFP::get(cur->getType(), 0.0), - fmul, "mulcsi." + cur->getName())); + B.CreateSelect(condition, ConstantFP::get(cur->getType(), + 0.0), fmul, "mulcsi." + cur->getName())); replaceAndErase(cur, sel); return "FMulSIToFPProp"; @@ -6535,8 +6551,8 @@ return true; auto div = ctx.SE.getUDivExpr(MinusX, Y); auto div_e = ctx.SE.getUDivExactExpr(MinusX, Y); - // in case of inexact division, check that these exactly equal for - // replacement + // in case of inexact division, check that these exactly equal + // for replacement if (div == div_e) { if (isEqual) { @@ -6807,8 +6823,8 @@ return true; if (rhs->ty == Type::Intersect || rhs->ty == Type::Compare) { return rhs->andB(shared_from_this(), ctx); } - // (m or a or b or d) and (m or a or c or e ...) -> m or a or ( (b or d) and - // (c or e)) + // (m or a or b or d) and (m or a or c or e ...) -> m or a or ( (b or d) + // and (c or e)) if (ty == Type::Union && rhs->ty == Type::Union) { if (*this == *rhs->notB(ctx)) { return Constraints::none(); From dd9f977bce5aef059a723d13f9bb2b22e363fd67 Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 12 Feb 2024 00:13:06 -0500 Subject: [PATCH 040/106] Handle constant of vector of i1 (#1704) --- enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp index c0b54a4eccbb..9ff5ba6384aa 100644 --- a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp +++ b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp @@ -747,6 +747,9 @@ void getConstantAnalysis(Constant *Val, TypeAnalyzer &TA, delete g2; int Off = (int)ai.getLimitedValue(); + if (auto VT = dyn_cast(Val->getType())) + if (VT->getElementType()->isIntegerTy(1)) + Off = i / 8; getConstantAnalysis(Op, TA, analysis); auto mid = analysis[Op]; From 82a6393901cb1110cfbc8e1879c6a8caaaa8224a Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 12 Feb 2024 15:25:08 -0500 Subject: [PATCH 041/106] Ensure alloca works well with minCut aliasing (#1705) --- enzyme/Enzyme/DifferentialUseAnalysis.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/enzyme/Enzyme/DifferentialUseAnalysis.cpp b/enzyme/Enzyme/DifferentialUseAnalysis.cpp index f4eb07a01602..8f2e27b3140e 100644 --- a/enzyme/Enzyme/DifferentialUseAnalysis.cpp +++ b/enzyme/Enzyme/DifferentialUseAnalysis.cpp @@ -883,7 +883,7 @@ void DifferentialUseAnalysis::minCut(const DataLayout &DL, LoopInfo &OrigLI, continue; } // If an allocation call, we cannot cache any "capturing" users - if (isAllocationCall(V, TLI)) { + if (isAllocationCall(V, TLI) || isa(V)) { auto next = (*mp.begin()).V; bool noncapture = false; if (isa(next)) { From a36ff5c39b131ed393fd737e7dc6897cf8948d4e Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 12 Feb 2024 16:57:21 -0500 Subject: [PATCH 042/106] Strengthen capturing alloc check on jlcall (#1706) --- enzyme/Enzyme/GradientUtils.cpp | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/enzyme/Enzyme/GradientUtils.cpp b/enzyme/Enzyme/GradientUtils.cpp index b1d8161576f2..d1dedbc70353 100644 --- a/enzyme/Enzyme/GradientUtils.cpp +++ b/enzyme/Enzyme/GradientUtils.cpp @@ -9304,6 +9304,16 @@ bool GradientUtils::needsCacheWholeAllocation( if (idx < CI->getNumArgOperands()) #endif { + + // Calling a non-empty function with a julia base object, this is fine. + // as GC will deal with any issues with. + if (auto PT = dyn_cast(CI->getArgOperand(idx)->getType())) + if (PT->getAddressSpace() == 10) + if (EnzymeJuliaAddrLoad) + if (auto F = getFunctionFromCall(CI)) + if (!F->empty()) + continue; + if (isNoCapture(CI, idx)) continue; From 437e0526ebe856c43d613d71c28480efdfc5fed6 Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 12 Feb 2024 18:09:28 -0500 Subject: [PATCH 043/106] BLAS fix gemm overwrite (#1707) --- enzyme/Enzyme/BlasDerivatives.td | 10 +-- .../test/Enzyme/ReverseMode/blas/gemm_f_c.ll | 8 +- .../blas/gemm_f_c_lacpy_runtime_act.ll | 6 +- .../Enzyme/ReverseMode/blas/gemm_f_c_loop.ll | 4 +- .../Enzyme/ReverseMode/blas/gemm_f_c_split.ll | 4 +- .../ReverseMode/blas/gemm_f_c_split_lacpy.ll | 4 +- .../blas/gemm_f_c_split_transpose_lacpy.ll | 4 +- .../blas/gemm_f_c_transpose_lacpy.ll | 8 +- .../ReverseMode/blas/gemm_f_change_ld.ll | 4 +- enzyme/test/Integration/ReverseMode/blas.cpp | 77 +++++++++++++++++++ 10 files changed, 103 insertions(+), 26 deletions(-) diff --git a/enzyme/Enzyme/BlasDerivatives.td b/enzyme/Enzyme/BlasDerivatives.td index 01894e87e453..f031aa7f191c 100644 --- a/enzyme/Enzyme/BlasDerivatives.td +++ b/enzyme/Enzyme/BlasDerivatives.td @@ -229,15 +229,15 @@ def gemm : CallBlasPattern<(Op $layout, $transa, $transb, $m, $n, $k, $alpha, $A [ /* alpha */ (Seq<["AB", "product", "m", "n"]> - (b<"gemm"> $layout, $transa, $transb, $m, $n, $k, Constant<"1.0">, $A, (ld $A, $transa, $lda, $m, $k), $B, (ld $B, $transb, $ldb, $k, $n), Constant<"0.0">, use<"AB">, $m),// TODO: check if last arg should be $m or $n + (b<"gemm"> $layout, $transa, $transb, $m, $n, $k, Constant<"1.0">, $A, (ld $A, $transa, $lda, $k, $m), $B, (ld $B, $transb, $ldb, $k, $n), Constant<"0.0">, use<"AB">, $m),// TODO: check if last arg should be $m or $n (FrobInnerProd<""> $m, $n, adj<"C">, use<"AB">)), /* A */ (b<"gemm"> $layout, (Rows $transa, (Concat $transa, transpose<"transb">, $m, $k), (Concat $transb, $transa, $k, $m)), $n, $alpha, (Rows $transa, - (Concat adj<"C">, $B, (ld $B, $transb, $ldb, $k, $n)), - (Concat $B, (ld $B, $transb, $ldb, $k, $n), adj<"C">)), + (Concat adj<"C">, $B, (ld $B, $transb, $ldb, $n, $k)), + (Concat $B, (ld $B, $transb, $ldb, $n, $k), adj<"C">)), Constant<"1.0">, adj<"A">), /* B */ (b<"gemm"> $layout, (Rows $transb, @@ -245,8 +245,8 @@ def gemm : CallBlasPattern<(Op $layout, $transa, $transb, $m, $n, $k, $alpha, $A (Concat $transb, $transa, $n, $k)), $m, $alpha, (Rows $transb, - (Concat $A, (ld $A, $transa, $lda, $m, $k), adj<"C">), - (Concat adj<"C">, $A, (ld $A, $transa, $lda, $m, $k))), + (Concat $A, (ld $A, $transa, $lda, $k, $m), adj<"C">), + (Concat adj<"C">, $A, (ld $A, $transa, $lda, $k, $m))), Constant<"1.0">, adj<"B">), /* beta */ (FrobInnerProd<""> $m, $n, adj<"C">, input<"C">), /* C */ (b<"lascl"> $layout, Char<"G">, ConstantInt<0>, ConstantInt<0>, Constant<"1.0">, $beta, $m, $n, adj<"C">, Alloca<1>) diff --git a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c.ll b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c.ll index e84c12e14701..633feb70203a 100644 --- a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c.ll +++ b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c.ll @@ -212,12 +212,12 @@ entry: ; CHECK-NEXT: %[[r65:.+]] = icmp eq i8 %loaded.trans4, 78 ; CHECK-NEXT: %[[r66:.+]] = icmp eq i8 %loaded.trans4, 110 ; CHECK-NEXT: %[[r67:.+]] = or i1 %[[r66]], %[[r65]] -; CHECK-NEXT: %[[r68:.+]] = select i1 %[[r67]], i8* %n_p, i8* %k_p +; CHECK-NEXT: %[[r68:.+]] = select i1 %[[r67]], i8* %k_p, i8* %n_p ; CHECK-NEXT: %loaded.trans5 = load i8, i8* %transb, align 1 ; CHECK-NEXT: %[[r69:.+]] = icmp eq i8 %loaded.trans5, 78 ; CHECK-NEXT: %[[r70:.+]] = icmp eq i8 %loaded.trans5, 110 ; CHECK-NEXT: %[[r71:.+]] = or i1 %[[r70]], %[[r69]] -; CHECK-NEXT: %[[r72:.+]] = select i1 %[[r71]], i8* %n_p, i8* %k_p +; CHECK-NEXT: %[[r72:.+]] = select i1 %[[r71]], i8* %k_p, i8* %n_p ; CHECK-NEXT: %ld.row.trans6 = load i8, i8* %transa, align 1 ; CHECK-NEXT: %[[r73:.+]] = icmp eq i8 %ld.row.trans6, 110 ; CHECK-NEXT: %[[r74:.+]] = icmp eq i8 %ld.row.trans6, 78 @@ -251,12 +251,12 @@ entry: ; CHECK-NEXT: %[[r87:.+]] = icmp eq i8 %[[loaded_trans10]], 78 ; CHECK-NEXT: %[[r88:.+]] = icmp eq i8 %[[loaded_trans10]], 110 ; CHECK-NEXT: %[[r89:.+]] = or i1 %[[r88]], %[[r87]] -; CHECK-NEXT: %[[r90:.+]] = select i1 %[[r89]], i8* %k_p, i8* %m_p +; CHECK-NEXT: %[[r90:.+]] = select i1 %[[r89]], i8* %m_p, i8* %k_p ; CHECK-NEXT: %[[loaded_trans11:.+]] = load i8, i8* %transa, align 1 ; CHECK-NEXT: %[[r91:.+]] = icmp eq i8 %[[loaded_trans11]], 78 ; CHECK-NEXT: %[[r92:.+]] = icmp eq i8 %[[loaded_trans11]], 110 ; CHECK-NEXT: %[[r93:.+]] = or i1 %[[r92]], %[[r91]] -; CHECK-NEXT: %[[r94:.+]] = select i1 %[[r93]], i8* %k_p, i8* %m_p +; CHECK-NEXT: %[[r94:.+]] = select i1 %[[r93]], i8* %m_p, i8* %k_p ; CHECK-NEXT: %[[ld_row_trans12:.+]] = load i8, i8* %transb, align 1 ; CHECK-NEXT: %[[r95:.+]] = icmp eq i8 %[[ld_row_trans12]], 110 ; CHECK-NEXT: %[[r96:.+]] = icmp eq i8 %[[ld_row_trans12]], 78 diff --git a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_lacpy_runtime_act.ll b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_lacpy_runtime_act.ll index 8f1c4f3e566e..78fc4959f21e 100644 --- a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_lacpy_runtime_act.ll +++ b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_lacpy_runtime_act.ll @@ -152,7 +152,7 @@ entry: ; CHECK-DAG: %[[i41:.+]] = icmp eq i8 %loaded.trans7, 78 ; CHECK-DAG: %[[i42:.+]] = icmp eq i8 %loaded.trans7, 110 ; CHECK-NEXT: %[[i43:.+]] = or i1 %[[i42]], %[[i41]] -; CHECK-NEXT: %[[i44:.+]] = select i1 %[[i43]], i8* %k_p, i8* %m_p +; CHECK-NEXT: %[[i44:.+]] = select i1 %[[i43]], i8* %m_p, i8* %k_p ; CHECK-NEXT: store double 0.000000e+00, double* %byref.constant.fp.0.0 ; CHECK-NEXT: %fpcast.constant.fp.0.0 = bitcast double* %byref.constant.fp.0.0 to i8* ; CHECK-NEXT: call void @dgemm_64_(i8* %transa, i8* %transb, i8* %m_p, i8* %n_p, i8* %k_p, i8* %fpcast.constant.fp.1.0, i8* %[[matA]], i8* %[[i44]], i8* %B, i8* %ldb_p, i8* %fpcast.constant.fp.0.0, i8* %[[i21]], i8* %m_p, i64 1, i64 1) @@ -268,12 +268,12 @@ entry: ; CHECK-NEXT: %[[r83:.+]] = icmp eq i8 %[[loaded_trans14]], 78 ; CHECK-NEXT: %[[r84:.+]] = icmp eq i8 %[[loaded_trans14]], 110 ; CHECK-NEXT: %[[r85:.+]] = or i1 %[[r84]], %[[r83]] -; CHECK-NEXT: %[[r86:.+]] = select i1 %[[r85]], i8* %k_p, i8* %m_p +; CHECK-NEXT: %[[r86:.+]] = select i1 %[[r85]], i8* %m_p, i8* %k_p ; CHECK-NEXT: %[[loaded_trans15:.+]] = load i8, i8* %transa, align 1 ; CHECK-NEXT: %[[r87:.+]] = icmp eq i8 %[[loaded_trans15]], 78 ; CHECK-NEXT: %[[r88:.+]] = icmp eq i8 %[[loaded_trans15]], 110 ; CHECK-NEXT: %[[r89:.+]] = or i1 %[[r88]], %[[r87]] -; CHECK-NEXT: %[[r90:.+]] = select i1 %[[r89]], i8* %k_p, i8* %m_p +; CHECK-NEXT: %[[r90:.+]] = select i1 %[[r89]], i8* %m_p, i8* %k_p ; CHECK-NEXT: %[[ld_row_trans16:.+]] = load i8, i8* %transb, align 1 ; CHECK-NEXT: %[[r91:.+]] = icmp eq i8 %[[ld_row_trans16]], 110 ; CHECK-NEXT: %[[r92:.+]] = icmp eq i8 %[[ld_row_trans16]], 78 diff --git a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_loop.ll b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_loop.ll index c66d4ff88d63..606ecb1e1c5b 100644 --- a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_loop.ll +++ b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_loop.ll @@ -271,12 +271,12 @@ entry: ; CHECK-NEXT: %[[r62:.+]] = icmp eq i8 %loaded.trans30, 78 ; CHECK-NEXT: %[[r63:.+]] = icmp eq i8 %loaded.trans30, 110 ; CHECK-NEXT: %[[r64:.+]] = or i1 %[[r63]], %[[r62]] -; CHECK-NEXT: %[[r65:.+]] = select i1 %[[r64]], i8* %cast.k, i8* %[[r37]] +; CHECK-NEXT: %[[r65:.+]] = select i1 %[[r64]], i8* %[[r37]], i8* %cast.k ; CHECK-NEXT: %loaded.trans31 = load i8, i8* %byref.transa, align 1 ; CHECK-NEXT: %[[r66:.+]] = icmp eq i8 %loaded.trans31, 78 ; CHECK-NEXT: %[[r67:.+]] = icmp eq i8 %loaded.trans31, 110 ; CHECK-NEXT: %[[r68:.+]] = or i1 %[[r67]], %[[r66]] -; CHECK-NEXT: %[[r69:.+]] = select i1 %[[r68]], i8* %cast.k, i8* %[[r37]] +; CHECK-NEXT: %[[r69:.+]] = select i1 %[[r68]], i8* %[[r37]], i8* %cast.k ; CHECK-NEXT: %ld.row.trans32 = load i8, i8* %byref.transb, align 1 ; CHECK-NEXT: %[[r70:.+]] = icmp eq i8 %ld.row.trans32, 110 ; CHECK-NEXT: %[[r71:.+]] = icmp eq i8 %ld.row.trans32, 78 diff --git a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_split.ll b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_split.ll index 0e30a9fdc7da..48589c0dd732 100644 --- a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_split.ll +++ b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_split.ll @@ -265,12 +265,12 @@ entry: ; CHECK-NEXT: %[[r55:.+]] = icmp eq i8 %loaded.trans, 78 ; CHECK-NEXT: %[[r56:.+]] = icmp eq i8 %loaded.trans, 110 ; CHECK-NEXT: %[[r57:.+]] = or i1 %[[r56]], %[[r55]] -; CHECK-NEXT: %[[r58:.+]] = select i1 %[[r57]], i8* %k_p, i8* %m_p +; CHECK-NEXT: %[[r58:.+]] = select i1 %[[r57]], i8* %m_p, i8* %k_p ; CHECK-NEXT: %[[loaded_trans5:.+]] = load i8, i8* %malloccall, align 1 ; CHECK-NEXT: %[[r59:.+]] = icmp eq i8 %[[loaded_trans5]], 78 ; CHECK-NEXT: %[[r60:.+]] = icmp eq i8 %[[loaded_trans5]], 110 ; CHECK-NEXT: %[[r61:.+]] = or i1 %[[r60]], %[[r59]] -; CHECK-NEXT: %[[r62:.+]] = select i1 %[[r61]], i8* %k_p, i8* %m_p +; CHECK-NEXT: %[[r62:.+]] = select i1 %[[r61]], i8* %m_p, i8* %k_p ; CHECK-NEXT: %[[ld_row_trans6:.+]] = load i8, i8* %malloccall1, align 1 ; CHECK-NEXT: %[[r63:.+]] = icmp eq i8 %[[ld_row_trans6]], 110 ; CHECK-NEXT: %[[r64:.+]] = icmp eq i8 %[[ld_row_trans6]], 78 diff --git a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_split_lacpy.ll b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_split_lacpy.ll index 64d862b240ff..132e5c14ba4d 100644 --- a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_split_lacpy.ll +++ b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_split_lacpy.ll @@ -239,12 +239,12 @@ entry: ; CHECK-NEXT: %[[r55:.+]] = icmp eq i8 %loaded.trans, 78 ; CHECK-NEXT: %[[r56:.+]] = icmp eq i8 %loaded.trans, 110 ; CHECK-NEXT: %[[r57:.+]] = or i1 %[[r56]], %[[r55]] -; CHECK-NEXT: %[[r58:.+]] = select i1 %[[r57]], i8* %k_p, i8* %m_p +; CHECK-NEXT: %[[r58:.+]] = select i1 %[[r57]], i8* %m_p, i8* %k_p ; CHECK-NEXT: %[[trans5:.+]] = load i8, i8* %malloccall, align 1 ; CHECK-NEXT: %[[r59:.+]] = icmp eq i8 %[[trans5]], 78 ; CHECK-NEXT: %[[r60:.+]] = icmp eq i8 %[[trans5]], 110 ; CHECK-NEXT: %[[r61:.+]] = or i1 %[[r60]], %[[r59]] -; CHECK-NEXT: %[[r62:.+]] = select i1 %[[r61]], i8* %k_p, i8* %m_p +; CHECK-NEXT: %[[r62:.+]] = select i1 %[[r61]], i8* %m_p, i8* %k_p ; CHECK-NEXT: %[[trans6:.+]] = load i8, i8* %malloccall1, align 1 ; CHECK-NEXT: %[[r63:.+]] = icmp eq i8 %[[trans6]], 110 ; CHECK-NEXT: %[[r64:.+]] = icmp eq i8 %[[trans6]], 78 diff --git a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_split_transpose_lacpy.ll b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_split_transpose_lacpy.ll index f7164d4926e8..c3a424a956a3 100644 --- a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_split_transpose_lacpy.ll +++ b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_split_transpose_lacpy.ll @@ -237,12 +237,12 @@ entry: ; CHECK-NEXT: %[[r55:.+]] = icmp eq i8 %loaded.trans, 78 ; CHECK-NEXT: %[[r56:.+]] = icmp eq i8 %loaded.trans, 110 ; CHECK-NEXT: %[[r57:.+]] = or i1 %[[r56]], %[[r55]] -; CHECK-NEXT: %[[r58:.+]] = select i1 %[[r57]], i8* %k_p, i8* %m_p +; CHECK-NEXT: %[[r58:.+]] = select i1 %[[r57]], i8* %m_p, i8* %k_p ; CHECK-NEXT: %[[loaded_trans5:.+]] = load i8, i8* %malloccall, align 1 ; CHECK-NEXT: %[[r59:.+]] = icmp eq i8 %[[loaded_trans5]], 78 ; CHECK-NEXT: %[[r60:.+]] = icmp eq i8 %[[loaded_trans5]], 110 ; CHECK-NEXT: %[[r61:.+]] = or i1 %[[r60]], %[[r59]] -; CHECK-NEXT: %[[r62:.+]] = select i1 %[[r61]], i8* %k_p, i8* %m_p +; CHECK-NEXT: %[[r62:.+]] = select i1 %[[r61]], i8* %m_p, i8* %k_p ; CHECK-NEXT: %[[ld_row_trans6:.+]] = load i8, i8* %malloccall1, align 1 ; CHECK-NEXT: %[[r63:.+]] = icmp eq i8 %[[ld_row_trans6]], 110 ; CHECK-NEXT: %[[r64:.+]] = icmp eq i8 %[[ld_row_trans6]], 78 diff --git a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_transpose_lacpy.ll b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_transpose_lacpy.ll index d0f9ed397b0d..ffc9f795f395 100644 --- a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_transpose_lacpy.ll +++ b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_transpose_lacpy.ll @@ -156,12 +156,12 @@ entry: ; CHECK-DAG: %[[i40:.+]] = icmp eq i8 %loaded.trans5, 78 ; CHECK-DAG: %[[i41:.+]] = icmp eq i8 %loaded.trans5, 110 ; CHECK-NEXT: %[[i42:.+]] = or i1 %[[i41]], %[[i40]] -; CHECK-NEXT: %[[i43:.+]] = select i1 %[[i42]], i8* %n_p, i8* %k_p +; CHECK-NEXT: %[[i43:.+]] = select i1 %[[i42]], i8* %k_p, i8* %n_p ; CHECK-NEXT: %loaded.trans6 = load i8, i8* %transb, align 1 ; CHECK-NEXT: %[[a49:.+]] = icmp eq i8 %loaded.trans6, 78 ; CHECK-NEXT: %[[a50:.+]] = icmp eq i8 %loaded.trans6, 110 ; CHECK-NEXT: %[[a51:.+]] = or i1 %[[a50]], %[[a49]] -; CHECK-NEXT: %[[a52:.+]] = select i1 %[[a51]], i8* %n_p, i8* %k_p +; CHECK-NEXT: %[[a52:.+]] = select i1 %[[a51]], i8* %k_p, i8* %n_p ; CHECK-NEXT: %ld.row.trans7 = load i8, i8* %transa, align 1 ; CHECK-NEXT: %[[a53:.+]] = icmp eq i8 %ld.row.trans7, 110 ; CHECK-NEXT: %[[a54:.+]] = icmp eq i8 %ld.row.trans7, 78 @@ -195,12 +195,12 @@ entry: ; CHECK-DAG: %[[i54:.+]] = icmp eq i8 %[[cachedtrans2]], 78 ; CHECK-DAG: %[[i55:.+]] = icmp eq i8 %[[cachedtrans2]], 110 ; CHECK-NEXT: %[[i56:.+]] = or i1 %[[i55]], %[[i54]] -; CHECK-NEXT: %[[i57:.+]] = select i1 %[[i56]], i8* %k_p, i8* %m_p +; CHECK-NEXT: %[[i57:.+]] = select i1 %[[i56]], i8* %m_p, i8* %k_p ; CHECK-NEXT: %[[loaded_trans12:.+]] = load i8, i8* %transa, align 1 ; CHECK-NEXT: %[[a71:.+]] = icmp eq i8 %[[loaded_trans12]], 78 ; CHECK-NEXT: %[[a72:.+]] = icmp eq i8 %[[loaded_trans12]], 110 ; CHECK-NEXT: %[[a73:.+]] = or i1 %[[a72]], %[[a71]] -; CHECK-NEXT: %[[a74:.+]] = select i1 %[[a73]], i8* %k_p, i8* %m_p +; CHECK-NEXT: %[[a74:.+]] = select i1 %[[a73]], i8* %m_p, i8* %k_p ; CHECK-NEXT: %[[ld_row_trans13:.+]] = load i8, i8* %transb, align 1 ; CHECK-NEXT: %[[a75:.+]] = icmp eq i8 %[[ld_row_trans13]], 110 ; CHECK-NEXT: %[[a76:.+]] = icmp eq i8 %[[ld_row_trans13]], 78 diff --git a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_change_ld.ll b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_change_ld.ll index 09189c1fe8d1..0c50fa797b3b 100644 --- a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_change_ld.ll +++ b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_change_ld.ll @@ -135,12 +135,12 @@ entry: ; CHECK-DAG: %[[r16:.+]] = icmp eq i8 %loaded.trans1, 78 ; CHECK-DAG: %[[r17:.+]] = icmp eq i8 %loaded.trans1, 110 ; CHECK-NEXT: %[[r18:.+]] = or i1 %[[r17]], %[[r16]] -; CHECK-NEXT: %[[r19:.+]] = select i1 %[[r18]], i8* %n_p, i8* %k_p +; CHECK-NEXT: %[[r19:.+]] = select i1 %[[r18]], i8* %k_p, i8* %n_p ; CHECK-NEXT: %loaded.trans2 = load i8, i8* %transb, align 1 ; CHECK-NEXT: %[[a38:.+]] = icmp eq i8 %loaded.trans2, 78 ; CHECK-NEXT: %[[a39:.+]] = icmp eq i8 %loaded.trans2, 110 ; CHECK-NEXT: %[[a40:.+]] = or i1 %[[a39]], %[[a38]] -; CHECK-NEXT: %[[a41:.+]] = select i1 %[[a40]], i8* %n_p, i8* %k_p +; CHECK-NEXT: %[[a41:.+]] = select i1 %[[a40]], i8* %k_p, i8* %n_p ; CHECK-NEXT: %ld.row.trans3 = load i8, i8* %transa, align 1 ; CHECK-NEXT: %[[a42:.+]] = icmp eq i8 %ld.row.trans3, 110 ; CHECK-NEXT: %[[a43:.+]] = icmp eq i8 %ld.row.trans3, 78 diff --git a/enzyme/test/Integration/ReverseMode/blas.cpp b/enzyme/test/Integration/ReverseMode/blas.cpp index 1c01178b874f..7cc7438cc932 100644 --- a/enzyme/test/Integration/ReverseMode/blas.cpp +++ b/enzyme/test/Integration/ReverseMode/blas.cpp @@ -42,6 +42,11 @@ void my_dgemm(char layout, char transA, char transB, int M, int N, int K, double inDerivative = true; } +void ow_dgemm(char layout, char transA, char transB, int M, int N, int K, double alpha, double* A, int lda, double* B, int ldb, double beta, double* C, int ldc) { + cblas_dgemm(layout, transA, transB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc); + inDerivative = true; +} + static void dotTests() { @@ -374,6 +379,78 @@ static void gemmTests() { // should be the same). checkMemoryTrace(inputs, "Found " + Test, foundCalls); + + Test = "GEMM overwrite"; + + init(); + __enzyme_autodiff((void*) ow_dgemm, + enzyme_const, layout, + enzyme_const, transA, + enzyme_const, transB, + enzyme_const, M, + enzyme_const, N, + enzyme_const, K, + enzyme_const, alpha, + enzyme_dup, A, dA, + enzyme_const, lda, + enzyme_dup, B, dB, + enzyme_const, incB, + enzyme_const, beta, + enzyme_dup, C, dC, + enzyme_const, incC); + foundCalls = calls; + init(); + + assert(foundCalls.size() > 2); + auto A_cache = (double*)foundCalls[0].pout_arg1; + cblas_dlacpy(layout, '\0', (!transA_bool) ? M : K, (!transA_bool) ? K : M, A, lda, A_cache, (!transA_bool) ? M : K); + inputs[4] = BlasInfo(A_cache, layout, (!transA_bool) ? M : K, (!transA_bool) ? K : M, (!transA_bool) ? M : K); + auto B_cache = (double*)foundCalls[1].pout_arg1; + cblas_dlacpy(layout, '\0', (!transB_bool) ? K : N, (!transB_bool) ? N : K, B, incB, B_cache, (!transB_bool) ? K : N); + inputs[5] = BlasInfo(B_cache, layout, (!transB_bool) ? K : N, (!transB_bool) ? N : K, (!transB_bool) ? K : N); + + ow_dgemm(layout, (char)transA, (char)transB, M, N, K, alpha, A, lda, B, incB, beta, C, incC); + + inDerivative = true; + + // dA = + my_dgemm(layout, + transA_bool ? (char)transB : (char)CBLAS_TRANSPOSE::CblasNoTrans, + transA_bool ? (char)CBLAS_TRANSPOSE::CblasTrans : (char)transpose(transB), + transA_bool ? K : M, + transA_bool ? M : K, + N, + alpha, + transA_bool ? B_cache : dC, + transA_bool ? ( (!transB_bool) ? K : N ) : incC, + transA_bool ? dC : B_cache, + transA_bool ? incC : ( (!transB_bool) ? K : N), + 1.0, dA, lda); + + // dB = + my_dgemm(layout, + transB_bool ? (char)CBLAS_TRANSPOSE::CblasTrans : (char)transpose(transA), + transB_bool ? (char)transA : (char)CBLAS_TRANSPOSE::CblasNoTrans, //transB, + transB_bool ? N : K, + transB_bool ? K : N, + M, + alpha, + transB_bool ? dC : A_cache, + transB_bool ? incC : ( (!transA_bool) ? M : K), + transB_bool ? A_cache : dC, + transB_bool ? ( (!transA_bool) ? M : K) : incC, + 1.0, dB, incB); + + cblas_dlascl(layout, 'G', 0, 0, 1.0, beta, M, N, dC, incC, 0 ); + + checkTest(Test); + + // Check memory of primal of expected derivative + checkMemoryTrace(inputs, "Expected " + Test, calls); + + // Check memory of primal of our derivative (if equal above, it + // should be the same). + checkMemoryTrace(inputs, "Found " + Test, foundCalls); } From e014a85f415189d097a4a5b2d8bbd49dcd5ffb8f Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 12 Feb 2024 20:07:53 -0500 Subject: [PATCH 044/106] BLAS fix vector mode (#1708) --- enzyme/tools/enzyme-tblgen/blas-tblgen.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/enzyme/tools/enzyme-tblgen/blas-tblgen.cpp b/enzyme/tools/enzyme-tblgen/blas-tblgen.cpp index c0e725eaf948..e8535370f5b6 100644 --- a/enzyme/tools/enzyme-tblgen/blas-tblgen.cpp +++ b/enzyme/tools/enzyme-tblgen/blas-tblgen.cpp @@ -855,7 +855,7 @@ void emit_fwd_rewrite_rules(const TGPattern &pattern, raw_ostream &os) { if (ty == ArgType::fp) { const auto name = nameVec[inputType.first]; os << " Value *d_" << name - << " = llvm::ConstantFP::get(fpType, 0.0);\n"; + << " = Constant::getNullValue(gutils->getShadowType(fpType));\n"; } } From 82f8cdf207cdb7b847ef445a1061fb44b5362f26 Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 12 Feb 2024 21:32:52 -0500 Subject: [PATCH 045/106] BCLoad: Still ignore if cblas lowering required (#1709) * BCLoad: Still ignore if cblas lowering required * inform which were replaced --- enzyme/BCLoad/BCLoader.cpp | 46 ++++++++++++++++++++++++++++++++------ 1 file changed, 39 insertions(+), 7 deletions(-) diff --git a/enzyme/BCLoad/BCLoader.cpp b/enzyme/BCLoad/BCLoader.cpp index 6aad07eb6739..02a8abb9223f 100644 --- a/enzyme/BCLoad/BCLoader.cpp +++ b/enzyme/BCLoad/BCLoader.cpp @@ -24,28 +24,30 @@ static inline bool endsWith(llvm::StringRef string, llvm::StringRef suffix) { #endif // LLVM_VERSION_MAJOR } -bool provideDefinitions(Module &M, std::set ignoreFunctions = {}) { +bool provideDefinitions(Module &M, std::set ignoreFunctions, + std::vector &replaced) { std::vector todo; bool seen32 = false; bool seen64 = false; for (auto &F : M) { if (!F.empty()) continue; + if (ignoreFunctions.count(F.getName().str())) + continue; int index = 0; for (auto postfix : {"", "_", "_64_"}) { std::string str; if (strlen(postfix) == 0) { str = F.getName().str(); - if (ignoreFunctions.count(str)) continue; } else if (endsWith(F.getName(), postfix)) { auto blasName = F.getName().substr(0, F.getName().size() - strlen(postfix)).str(); - if (ignoreFunctions.count(blasName)) continue; str = "cblas_" + blasName; } auto found = EnzymeBlasBC.find(str); if (found != EnzymeBlasBC.end()) { + replaced.push_back(F.getName().str()); todo.push_back(found->second); if (index == 1) seen32 = true; @@ -81,13 +83,23 @@ bool provideDefinitions(Module &M, std::set ignoreFunctions = {}) { }); #endif - if (!BC) + if (!BC) { Err.print("bcloader", llvm::errs()); + continue; + } assert(BC); SmallVector toReplace; for (auto &F : *BC) { if (F.empty()) continue; + if (ignoreFunctions.count(F.getName().str())) { +#if LLVM_VERSION_MAJOR >= 16 + F.erase(F.begin(), F.end()); +#else + F.getBasicBlockList().erase(F.begin(), F.end()); +#endif + continue; + } toReplace.push_back(F.getName().str()); } BC->setTargetTriple(""); @@ -106,12 +118,29 @@ bool provideDefinitions(Module &M, std::set ignoreFunctions = {}) { extern "C" { uint8_t EnzymeBitcodeReplacement(LLVMModuleRef M, char **FncsNamesToIgnore, - size_t numFncNames) { + size_t numFncNames, const char ***foundP, + size_t *foundLen) { std::set ignoreFunctions = {}; for (size_t i = 0; i < numFncNames; i++) { ignoreFunctions.insert(std::string(FncsNamesToIgnore[i])); } - return provideDefinitions(*unwrap(M), ignoreFunctions); + std::vector replaced; + auto res = provideDefinitions(*unwrap(M), ignoreFunctions, replaced); + + const char **found = nullptr; + if (replaced.size()) { + found = (const char **)malloc(replaced.size() * sizeof(const char **)); + for (size_t i = 0; i < replaced.size(); i++) { + char *data = (char *)malloc(replaced[i].size() + 1); + memcpy(data, replaced[i].data(), replaced[i].size()); + data[replaced[i].size()] = 0; + found[i] = data; + } + } + *foundP = found; + *foundLen = replaced.size(); + + return res; } } @@ -121,7 +150,10 @@ class BCLoader final : public ModulePass { static char ID; BCLoader() : ModulePass(ID) {} - bool runOnModule(Module &M) override { return provideDefinitions(M, {}); } + bool runOnModule(Module &M) override { + std::vector replaced; + return provideDefinitions(M, {}, replaced); + } }; } // namespace From 87a7b5dfef3543c0dbd2dbeb8e39f4f19357e291 Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 13 Feb 2024 12:20:10 -0500 Subject: [PATCH 046/106] BLAS fix erasure (#1711) --- enzyme/tools/enzyme-tblgen/blas-tblgen.cpp | 30 ++++++++++++++++------ 1 file changed, 22 insertions(+), 8 deletions(-) diff --git a/enzyme/tools/enzyme-tblgen/blas-tblgen.cpp b/enzyme/tools/enzyme-tblgen/blas-tblgen.cpp index e8535370f5b6..39cb55384ad9 100644 --- a/enzyme/tools/enzyme-tblgen/blas-tblgen.cpp +++ b/enzyme/tools/enzyme-tblgen/blas-tblgen.cpp @@ -146,14 +146,22 @@ void emit_handleBLAS(ArrayRef blasPatterns, raw_ostream &os) { << " llvm::errs() << \" fallback?\\n\"; \n" << " return false; \n" << " } \n" - << " } \n" - << " \n" - << " if (Mode == DerivativeMode::ReverseModeGradient) { \n" - << " eraseIfUnused(call, /*erase*/ true, /*check*/ false); \n" << " } else { \n" - << " eraseIfUnused(call); \n" - << " } \n" - << " \n" + << " if (Mode == DerivativeMode::ReverseModeGradient) { \n" + << " eraseIfUnused(call, /*erase*/ true, /*check*/ false); \n" + << " } else { \n" + << " eraseIfUnused(call); \n" + << " } \n" + << " if (gutils->knownRecomputeHeuristic.find(&call) !=\n" + << " gutils->knownRecomputeHeuristic.end()) {\n" + << " if (!gutils->knownRecomputeHeuristic[&call]) {\n" + << " auto newCall = gutils->getNewFromOriginal(&call);\n" + << " llvm::IRBuilder<> BuilderZ(newCall);\n" + << " gutils->cacheForReverse(BuilderZ, newCall,\n" + << " getIndex(&call, CacheType::Self, BuilderZ));\n" + << " }\n" + << " }\n" + << " }\n" << " return result; \n" << "} \n"; } @@ -229,10 +237,16 @@ void emit_free_and_ending(const TGPattern &pattern, raw_ostream &os) { os << " }\n" << " }\n" + << " \n" + << " if (Mode == DerivativeMode::ReverseModeGradient) { \n" + << " eraseIfUnused(call, /*erase*/ true, /*check*/ false); \n" + << " } else { \n" + << " eraseIfUnused(call); \n" + << " } \n" << " if (gutils->knownRecomputeHeuristic.find(&call) !=\n" << " gutils->knownRecomputeHeuristic.end()) {\n" << " if (!gutils->knownRecomputeHeuristic[&call]) {\n" - << " gutils->cacheForReverse(BuilderZ, newCall,\n" + << " auto cv = gutils->cacheForReverse(BuilderZ, newCall,\n" << " getIndex(&call, CacheType::Self, BuilderZ));\n" << " }\n" << " }\n" From 6ef7d1e134c708b8453f61a03ccd69a934b0f09e Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 13 Feb 2024 12:56:06 -0500 Subject: [PATCH 047/106] Drop all fn references on bc (#1710) --- enzyme/BCLoad/BCLoader.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/enzyme/BCLoad/BCLoader.cpp b/enzyme/BCLoad/BCLoader.cpp index 02a8abb9223f..286bfc2e421e 100644 --- a/enzyme/BCLoad/BCLoader.cpp +++ b/enzyme/BCLoad/BCLoader.cpp @@ -93,6 +93,7 @@ bool provideDefinitions(Module &M, std::set ignoreFunctions, if (F.empty()) continue; if (ignoreFunctions.count(F.getName().str())) { + F.dropAllReferences(); #if LLVM_VERSION_MAJOR >= 16 F.erase(F.begin(), F.end()); #else From 4623cb1e530f27b20c87cffae8efcc66de4c2dd3 Mon Sep 17 00:00:00 2001 From: Max Aehle Date: Tue, 13 Feb 2024 18:57:05 +0100 Subject: [PATCH 048/106] Support more C math.h functions (#1605) * Add testcases for coshf, coshl in the forward and reverse mode. - The coshf tests pass. - In the coshl tests, the opt call crashes. * Define derivatives of sinhl, coshl, tanhl * Fix frexpl and add tests * Define derivatives of modf, modff, modfl * Add tests for modf, modff, modfl * Activity and type analysis for modf, modff, modfl * clang-format * Define derivatives of fdim, fdimf, fdiml * Add tests for fdim, fdimf, fdiml * Define derivatives for inverse hyperbolic functions asinh, asinhf, asinhl, acosh, acoshf, acoshl, atanh, atanhf, atanhl * Fix nearbyint, nearbyintf, nearbyintl * Define derivatives for erff, erfl, erfcf, erfcl * Add tests for inverse hyperbolic functions --- enzyme/Enzyme/ActivityAnalysis.cpp | 3 +- enzyme/Enzyme/InstructionDerivatives.td | 75 +++++++- enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp | 3 + enzyme/test/Enzyme/ForwardMode/acosh.ll | 31 ++++ enzyme/test/Enzyme/ForwardMode/asinh.ll | 31 ++++ enzyme/test/Enzyme/ForwardMode/asinhf.ll | 31 ++++ enzyme/test/Enzyme/ForwardMode/asinhl.ll | 31 ++++ enzyme/test/Enzyme/ForwardMode/atanh.ll | 30 ++++ enzyme/test/Enzyme/ForwardMode/coshf.ll | 28 +++ enzyme/test/Enzyme/ForwardMode/coshl.ll | 28 +++ enzyme/test/Enzyme/ForwardMode/fdim.ll | 32 ++++ enzyme/test/Enzyme/ForwardMode/frexp.ll | 27 ++- enzyme/test/Enzyme/ForwardMode/modf.ll | 120 +++++++++++++ enzyme/test/Enzyme/ReverseMode/acosh.ll | 46 +++++ enzyme/test/Enzyme/ReverseMode/asinh.ll | 46 +++++ enzyme/test/Enzyme/ReverseMode/asinhf.ll | 46 +++++ enzyme/test/Enzyme/ReverseMode/asinhl.ll | 46 +++++ enzyme/test/Enzyme/ReverseMode/atanh.ll | 45 +++++ enzyme/test/Enzyme/ReverseMode/coshf.ll | 29 +++ enzyme/test/Enzyme/ReverseMode/coshl.ll | 29 +++ enzyme/test/Enzyme/ReverseMode/fdim.ll | 54 ++++++ enzyme/test/Enzyme/ReverseMode/frexp.ll | 28 ++- enzyme/test/Enzyme/ReverseMode/modf.ll | 189 ++++++++++++++++++++ 23 files changed, 1019 insertions(+), 9 deletions(-) create mode 100644 enzyme/test/Enzyme/ForwardMode/acosh.ll create mode 100644 enzyme/test/Enzyme/ForwardMode/asinh.ll create mode 100644 enzyme/test/Enzyme/ForwardMode/asinhf.ll create mode 100644 enzyme/test/Enzyme/ForwardMode/asinhl.ll create mode 100644 enzyme/test/Enzyme/ForwardMode/atanh.ll create mode 100644 enzyme/test/Enzyme/ForwardMode/coshf.ll create mode 100644 enzyme/test/Enzyme/ForwardMode/coshl.ll create mode 100644 enzyme/test/Enzyme/ForwardMode/fdim.ll create mode 100644 enzyme/test/Enzyme/ForwardMode/modf.ll create mode 100644 enzyme/test/Enzyme/ReverseMode/acosh.ll create mode 100644 enzyme/test/Enzyme/ReverseMode/asinh.ll create mode 100644 enzyme/test/Enzyme/ReverseMode/asinhf.ll create mode 100644 enzyme/test/Enzyme/ReverseMode/asinhl.ll create mode 100644 enzyme/test/Enzyme/ReverseMode/atanh.ll create mode 100644 enzyme/test/Enzyme/ReverseMode/coshf.ll create mode 100644 enzyme/test/Enzyme/ReverseMode/coshl.ll create mode 100644 enzyme/test/Enzyme/ReverseMode/fdim.ll create mode 100644 enzyme/test/Enzyme/ReverseMode/modf.ll diff --git a/enzyme/Enzyme/ActivityAnalysis.cpp b/enzyme/Enzyme/ActivityAnalysis.cpp index 79a56d885507..3175e277a397 100644 --- a/enzyme/Enzyme/ActivityAnalysis.cpp +++ b/enzyme/Enzyme/ActivityAnalysis.cpp @@ -900,7 +900,8 @@ bool ActivityAnalyzer::isConstantInstruction(TypeResults const &TR, if (isMemFreeLibMFunction(funcName)) { noActiveWrite = true; } else if (funcName == "frexp" || funcName == "frexpf" || - funcName == "frexpl") { + funcName == "frexpl" || funcName == "modf" || + funcName == "modff" || funcName == "modfl") { noActiveWrite = true; } } diff --git a/enzyme/Enzyme/InstructionDerivatives.td b/enzyme/Enzyme/InstructionDerivatives.td index 58731be4a532..77a67fafe325 100644 --- a/enzyme/Enzyme/InstructionDerivatives.td +++ b/enzyme/Enzyme/InstructionDerivatives.td @@ -210,7 +210,9 @@ def AssertingInactiveArg : InactiveArgSpec { class GlobalExpr : Operation{ string value = val; } -def MantissaMaskOfReturn : GlobalExprisX86_FP80Ty()) {\n" " tsize = 80;\n" " high = tsize - 1;\n" -" low = high - 15;\n" +" low = high - 16;\n" + // x86_fp80 has only 15 exponent bits, but we must also + // retain the most-significant bit of the mantissa as + // there is no implicit leading 1. " } else if (ty->isFP128Ty()) {\n" " tsize = 128;\n" " high = tsize - 1;\n" @@ -317,13 +322,18 @@ def : CallPattern<(Op $x), (ForwardFromSummedReverse), [ReadNone, NoUnwind] >; - def : CallPattern<(Op $x), ["tanhf"], [(FDiv (DiffeRet), (FMul(Call<(SameTypesFunc<"coshf">), [ReadNone,NoUnwind]> $x):$c, $c))], (ForwardFromSummedReverse), [ReadNone, NoUnwind] >; +def : CallPattern<(Op $x), + ["tanhl"], + [(FDiv (DiffeRet), (FMul(Call<(SameTypesFunc<"coshl">), [ReadNone,NoUnwind]> $x):$c, $c))], + (ForwardFromSummedReverse), + [ReadNone, NoUnwind] + >; def : CallPattern<(Op $x), ["cosh"], @@ -337,6 +347,12 @@ def : CallPattern<(Op $x), (ForwardFromSummedReverse), [ReadNone, NoUnwind] >; +def : CallPattern<(Op $x), + ["coshl"], + [(FMul (DiffeRet), (Call<(SameTypesFunc<"sinhl">), [ReadNone,NoUnwind]> $x))], + (ForwardFromSummedReverse), + [ReadNone, NoUnwind] + >; def : CallPattern<(Op $x), ["sinh"], @@ -350,6 +366,33 @@ def : CallPattern<(Op $x), (ForwardFromSummedReverse), [ReadNone, NoUnwind] >; +def : CallPattern<(Op $x), + ["sinhl"], + [(FMul (DiffeRet), (Call<(SameTypesFunc<"coshl">), [ReadNone,NoUnwind]> $x))], + (ForwardFromSummedReverse), + [ReadNone, NoUnwind] + >; + +def : CallPattern<(Op $x), + ["asinh", "asinhf", "asinhl", "__nv_asinh", "__nv_asinhf"], + [(FDiv (DiffeRet), (Intrinsic<"sqrt"> (FAdd (FMul $x, $x), (ConstantFP<"1.0"> $x))) )] , + (ForwardFromSummedReverse), + [ReadNone, NoUnwind] + >; + +def : CallPattern<(Op $x), + ["acosh", "acoshf", "acoshl", "__nv_acosh", "__nv_acoshf"], + [(FDiv (DiffeRet), (Intrinsic<"sqrt"> (FSub (FMul $x, $x), (ConstantFP<"1.0"> $x))) )] , + (ForwardFromSummedReverse), + [ReadNone, NoUnwind] + >; + +def : CallPattern<(Op $x), + ["atanh", "atanhf", "atanhl", "__nv_atanh", "__nv_atanhf"], + [(FDiv (DiffeRet), (FSub (ConstantFP<"1.0"> $x), (FMul $x, $x)) )] , + (ForwardFromSummedReverse), + [ReadNone, NoUnwind] + >; def : CallPattern<(Op $x), ["exp10"], @@ -443,6 +486,16 @@ def : CallPattern<(Op $x, $y), [ReadNone, NoUnwind] >; +def : CallPattern<(Op $x, $integral_part_ptr), + ["modf", "modff", "modfl"], + [ + (DiffeRet), + (InactiveArg) + ], + (ForwardFromSummedReverse), + [ReadOnly, NoUnwind] + >; + def : CallPattern<(Op $x), ["__fd_sincos_1", "__fd_sincos_1f", "__fd_sincos_1l"], [ @@ -564,7 +617,7 @@ def : CallPattern<(Op $n, $x), >; def : CallPattern<(Op $x), - ["erf"], + ["erf","erff","erfl"], [ (FMul (DiffeRet), (FMul (ConstantFP<"1.1283791670955125738961589031215451716881012586580"> $x), (Intrinsic<"exp"> (FNeg (FMul $x, $x))))) ], @@ -580,7 +633,7 @@ def : CallPattern<(Op $x), [ReadNone, NoUnwind] >; def : CallPattern<(Op $x), - ["erfc"], + ["erfc","erfcf","erfcl"], [ (FMul (DiffeRet), (FMul (ConstantFP<"-1.1283791670955125738961589031215451716881012586580"> $x), (Intrinsic<"exp"> (FNeg (FMul $x, $x))))) ], @@ -722,7 +775,7 @@ def : CallPattern<(Op $x, $expout), (DiffeRet), (FMul (BitCast - (And (MantissaMaskOfReturn):$mask, (BitCast $x, (TypeOf $mask)) ), + (And (MantissaMaskOfReturnForFrexp):$mask, (BitCast $x, (TypeOf $mask)) ), (TypeOf $x) ), (ConstantFP<"2"> $x) @@ -839,6 +892,16 @@ def : IntrPattern<(Op $x, $y), (Select (FCmpOLT $x, $y), (SelectIfActive $y, (Shadow $y), (Zero $y)), (SelectIfActive $x, (Shadow $x), (Zero $x))) >; +def : CallPattern<(Op $x, $y), + ["fdim", "fdimf", "fdiml"], + [ + (Select (FCmpOLT $x, $y), (ConstantFP<"0"> $x), (DiffeRet)), + (Select (FCmpOLT $x, $y), (ConstantFP<"0"> $y), (FNeg (DiffeRet))) + ], + (ForwardFromSummedReverse), + [ReadNone, NoUnwind] + >; + def : IntrPattern<(Op $x), [["fabs"]], [ diff --git a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp index 9ff5ba6384aa..b7b900783f48 100644 --- a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp +++ b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp @@ -175,6 +175,7 @@ const llvm::StringMap LIBM_FUNCTIONS = { {"trunc", Intrinsic::trunc}, {"round", Intrinsic::round}, {"rint", Intrinsic::rint}, + {"nearbyint", Intrinsic::nearbyint}, {"remainder", Intrinsic::not_intrinsic}, {"copysign", Intrinsic::copysign}, {"nextafter", Intrinsic::not_intrinsic}, @@ -5226,6 +5227,8 @@ void TypeAnalyzer::visitCallBase(CallBase &call) { CONSIDER(frexpl) CONSIDER2(ldexp, double, double, int) CONSIDER2(modf, double, double, double *) + CONSIDER(modff) + CONSIDER(modfl) CONSIDER2(remquo, double, double, double, int *) CONSIDER(remquof) diff --git a/enzyme/test/Enzyme/ForwardMode/acosh.ll b/enzyme/test/Enzyme/ForwardMode/acosh.ll new file mode 100644 index 000000000000..90108c74de17 --- /dev/null +++ b/enzyme/test/Enzyme/ForwardMode/acosh.ll @@ -0,0 +1,31 @@ +; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -S | FileCheck %s; fi +; RUN: %opt < %s %newLoadEnzyme -passes="enzyme" -enzyme-preopt=false -S | FileCheck %s + +; Function Attrs: nounwind readnone uwtable +define double @tester(double %x) { +entry: + %0 = tail call fast double @acosh(double %x) + ret double %0 +} + +define double @test_derivative(double %x) { +entry: + %0 = tail call double (double (double)*, ...) @__enzyme_fwddiff(double (double)* nonnull @tester, double %x, double 1.0) + ret double %0 +} + +; Function Attrs: nounwind readnone speculatable +declare double @acosh(double) + +; Function Attrs: nounwind +declare double @__enzyme_fwddiff(double (double)*, ...) + +; CHECK: define internal double @fwddiffetester(double %x, double %"x'") +; CHECK-NEXT: entry: +; CHECK-NEXT: %0 = fmul fast double %x, %x +; CHECK-NEXT: %1 = fsub fast double %0, 1.000000e+00 +; CHECK-NEXT: %2 = call fast double @llvm.sqrt.f64(double %1) +; CHECK-NEXT: %3 = fdiv fast double %"x'", %2 +; CHECK-NEXT: ret double %3 +; CHECK-NEXT: } + diff --git a/enzyme/test/Enzyme/ForwardMode/asinh.ll b/enzyme/test/Enzyme/ForwardMode/asinh.ll new file mode 100644 index 000000000000..73377c6edfd5 --- /dev/null +++ b/enzyme/test/Enzyme/ForwardMode/asinh.ll @@ -0,0 +1,31 @@ +; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -S | FileCheck %s; fi +; RUN: %opt < %s %newLoadEnzyme -passes="enzyme" -enzyme-preopt=false -S | FileCheck %s + +; Function Attrs: nounwind readnone uwtable +define double @tester(double %x) { +entry: + %0 = tail call fast double @asinh(double %x) + ret double %0 +} + +define double @test_derivative(double %x) { +entry: + %0 = tail call double (double (double)*, ...) @__enzyme_fwddiff(double (double)* nonnull @tester, double %x, double 1.0) + ret double %0 +} + +; Function Attrs: nounwind readnone speculatable +declare double @asinh(double) + +; Function Attrs: nounwind +declare double @__enzyme_fwddiff(double (double)*, ...) + +; CHECK: define internal double @fwddiffetester(double %x, double %"x'") +; CHECK-NEXT: entry: +; CHECK-NEXT: %0 = fmul fast double %x, %x +; CHECK-NEXT: %1 = fadd fast double %0, 1.000000e+00 +; CHECK-NEXT: %2 = call fast double @llvm.sqrt.f64(double %1) +; CHECK-NEXT: %3 = fdiv fast double %"x'", %2 +; CHECK-NEXT: ret double %3 +; CHECK-NEXT: } + diff --git a/enzyme/test/Enzyme/ForwardMode/asinhf.ll b/enzyme/test/Enzyme/ForwardMode/asinhf.ll new file mode 100644 index 000000000000..edd9959120d9 --- /dev/null +++ b/enzyme/test/Enzyme/ForwardMode/asinhf.ll @@ -0,0 +1,31 @@ +; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -S | FileCheck %s; fi +; RUN: %opt < %s %newLoadEnzyme -passes="enzyme" -enzyme-preopt=false -S | FileCheck %s + +; Function Attrs: nounwind readnone uwtable +define float @tester(float %x) { +entry: + %0 = tail call fast float @asinhf(float %x) + ret float %0 +} + +define float @test_derivative(float %x) { +entry: + %0 = tail call float (float (float)*, ...) @__enzyme_fwddiff(float (float)* nonnull @tester, float %x, float 1.0) + ret float %0 +} + +; Function Attrs: nounwind readnone speculatable +declare float @asinhf(float) + +; Function Attrs: nounwind +declare float @__enzyme_fwddiff(float (float)*, ...) + +; CHECK: define internal float @fwddiffetester(float %x, float %"x'") +; CHECK-NEXT: entry: +; CHECK-NEXT: %0 = fmul fast float %x, %x +; CHECK-NEXT: %1 = fadd fast float %0, 1.000000e+00 +; CHECK-NEXT: %2 = call fast float @llvm.sqrt.f32(float %1) +; CHECK-NEXT: %3 = fdiv fast float %"x'", %2 +; CHECK-NEXT: ret float %3 +; CHECK-NEXT: } + diff --git a/enzyme/test/Enzyme/ForwardMode/asinhl.ll b/enzyme/test/Enzyme/ForwardMode/asinhl.ll new file mode 100644 index 000000000000..80929fd665c2 --- /dev/null +++ b/enzyme/test/Enzyme/ForwardMode/asinhl.ll @@ -0,0 +1,31 @@ +; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -S | FileCheck %s; fi +; RUN: %opt < %s %newLoadEnzyme -passes="enzyme" -enzyme-preopt=false -S | FileCheck %s + +; Function Attrs: nounwind readnone uwtable +define x86_fp80 @tester(x86_fp80 %x) { +entry: + %0 = tail call fast x86_fp80 @asinhl(x86_fp80 %x) + ret x86_fp80 %0 +} + +define x86_fp80 @test_derivative(x86_fp80 %x) { +entry: + %0 = tail call x86_fp80 (x86_fp80 (x86_fp80)*, ...) @__enzyme_fwddiff(x86_fp80 (x86_fp80)* nonnull @tester, x86_fp80 %x, x86_fp80 0xK3FFF8000000000000000) + ret x86_fp80 %0 +} + +; Function Attrs: nounwind readnone speculatable +declare x86_fp80 @asinhl(x86_fp80) + +; Function Attrs: nounwind +declare x86_fp80 @__enzyme_fwddiff(x86_fp80 (x86_fp80)*, ...) + +; CHECK: define internal x86_fp80 @fwddiffetester(x86_fp80 %x, x86_fp80 %"x'") +; CHECK-NEXT: entry: +; CHECK-NEXT: %0 = fmul fast x86_fp80 %x, %x +; CHECK-NEXT: %1 = fadd fast x86_fp80 %0, 0xK3FFF8000000000000000 +; CHECK-NEXT: %2 = call fast x86_fp80 @llvm.sqrt.f80(x86_fp80 %1) +; CHECK-NEXT: %3 = fdiv fast x86_fp80 %"x'", %2 +; CHECK-NEXT: ret x86_fp80 %3 +; CHECK-NEXT: } + diff --git a/enzyme/test/Enzyme/ForwardMode/atanh.ll b/enzyme/test/Enzyme/ForwardMode/atanh.ll new file mode 100644 index 000000000000..358ab107e140 --- /dev/null +++ b/enzyme/test/Enzyme/ForwardMode/atanh.ll @@ -0,0 +1,30 @@ +; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -S | FileCheck %s; fi +; RUN: %opt < %s %newLoadEnzyme -passes="enzyme" -enzyme-preopt=false -S | FileCheck %s + +; Function Attrs: nounwind readnone uwtable +define double @tester(double %x) { +entry: + %0 = tail call fast double @atanh(double %x) + ret double %0 +} + +define double @test_derivative(double %x) { +entry: + %0 = tail call double (double (double)*, ...) @__enzyme_fwddiff(double (double)* nonnull @tester, double %x, double 1.0) + ret double %0 +} + +; Function Attrs: nounwind readnone speculatable +declare double @atanh(double) + +; Function Attrs: nounwind +declare double @__enzyme_fwddiff(double (double)*, ...) + +; CHECK: define internal double @fwddiffetester(double %x, double %"x'") +; CHECK-NEXT: entry: +; CHECK-NEXT: %0 = fmul fast double %x, %x +; CHECK-NEXT: %1 = fsub fast double 1.000000e+00, %0 +; CHECK-NEXT: %2 = fdiv fast double %"x'", %1 +; CHECK-NEXT: ret double %2 +; CHECK-NEXT: } + diff --git a/enzyme/test/Enzyme/ForwardMode/coshf.ll b/enzyme/test/Enzyme/ForwardMode/coshf.ll new file mode 100644 index 000000000000..b94270380e52 --- /dev/null +++ b/enzyme/test/Enzyme/ForwardMode/coshf.ll @@ -0,0 +1,28 @@ +; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -S | FileCheck %s; fi +; RUN: %opt < %s %newLoadEnzyme -passes="enzyme" -enzyme-preopt=false -S | FileCheck %s + +; Function Attrs: nounwind readnone uwtable +define float @tester(float %x) { +entry: + %0 = tail call fast float @coshf(float %x) + ret float %0 +} + +define float @test_derivative(float %x) { +entry: + %0 = tail call float (float (float)*, ...) @__enzyme_fwddiff(float (float)* nonnull @tester, float %x, float 1.0) + ret float %0 +} + +; Function Attrs: nounwind readnone speculatable +declare float @coshf(float) + +; Function Attrs: nounwind +declare float @__enzyme_fwddiff(float (float)*, ...) + +; CHECK: define internal float @fwddiffetester(float %x, float %"x'") +; CHECK-NEXT: entry: +; CHECK-NEXT: %0 = call fast float @sinhf(float %x) +; CHECK-NEXT: %1 = fmul fast float %"x'", %0 +; CHECK-NEXT: ret float %1 +; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ForwardMode/coshl.ll b/enzyme/test/Enzyme/ForwardMode/coshl.ll new file mode 100644 index 000000000000..4a6e3054f0fc --- /dev/null +++ b/enzyme/test/Enzyme/ForwardMode/coshl.ll @@ -0,0 +1,28 @@ +; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -S | FileCheck %s; fi +; RUN: %opt < %s %newLoadEnzyme -passes="enzyme" -enzyme-preopt=false -S | FileCheck %s + +; Function Attrs: nounwind readnone uwtable +define x86_fp80 @tester(x86_fp80 %x) { +entry: + %0 = tail call fast x86_fp80 @coshl(x86_fp80 %x) + ret x86_fp80 %0 +} + +define x86_fp80 @test_derivative(x86_fp80 %x) { +entry: + %0 = tail call x86_fp80 (x86_fp80 (x86_fp80)*, ...) @__enzyme_fwddiff(x86_fp80 (x86_fp80)* nonnull @tester, x86_fp80 %x, x86_fp80 0xK3FFF8000000000000000) + ret x86_fp80 %0 +} + +; Function Attrs: nounwind readnone speculatable +declare x86_fp80 @coshl(x86_fp80) + +; Function Attrs: nounwind +declare x86_fp80 @__enzyme_fwddiff(x86_fp80 (x86_fp80)*, ...) + +; CHECK: define internal x86_fp80 @fwddiffetester(x86_fp80 %x, x86_fp80 %"x'") +; CHECK-NEXT: entry: +; CHECK-NEXT: %0 = call fast x86_fp80 @sinhl(x86_fp80 %x) +; CHECK-NEXT: %1 = fmul fast x86_fp80 %"x'", %0 +; CHECK-NEXT: ret x86_fp80 %1 +; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ForwardMode/fdim.ll b/enzyme/test/Enzyme/ForwardMode/fdim.ll new file mode 100644 index 000000000000..7d333686f7f4 --- /dev/null +++ b/enzyme/test/Enzyme/ForwardMode/fdim.ll @@ -0,0 +1,32 @@ +; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -S | FileCheck %s; fi +; RUN: if [ %llvmver -ge 12 ]; then %opt < %s %newLoadEnzyme -passes="enzyme" -enzyme-preopt=false -S | FileCheck %s ; fi + +declare double @fdim(double, double) + +define double @tester(double %x, double %y) { +entry: + %0 = call double @fdim(double %x, double %y) + ret double %0 +} + +define double @test_derivative(double %x, double %y) { +entry: + %0 = tail call double (double (double, double)*, ...) @__enzyme_fwddiff(double (double, double)* nonnull @tester, double %x, double 10.0, double %y, double 1.0) + ret double %0 +} + +; Function Attrs: nounwind +declare double @__enzyme_fwddiff(double (double, double)*, ...) + +; CHECK: define internal double @fwddiffetester(double %x, double %"x'", double %y, double %"y'") +; CHECK-NEXT: entry: +; CHECK-NEXT: %0 = fcmp fast olt double %x, %y +; CHECK-NEXT: %1 = select fast i1 %0, double 0.000000e+00, double %"x'" +; CHECK-NEXT: %2 = fcmp fast olt double %x, %y +; CHECK-NEXT: %3 = fneg fast double %"y'" +; CHECK-NEXT: %4 = select fast i1 %2, double 0.000000e+00, double %3 +; CHECK-NEXT: %5 = fadd fast double %1, %4 +; CHECK-NEXT: ret double %5 +; CHECK-NEXT: } + + diff --git a/enzyme/test/Enzyme/ForwardMode/frexp.ll b/enzyme/test/Enzyme/ForwardMode/frexp.ll index 9e421ec21e3a..f9f0916a09b9 100644 --- a/enzyme/test/Enzyme/ForwardMode/frexp.ll +++ b/enzyme/test/Enzyme/ForwardMode/frexp.ll @@ -1,10 +1,11 @@ ; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -S | FileCheck %s; fi ; RUN: %opt < %s %newLoadEnzyme -passes="enzyme" -enzyme-preopt=false -S | FileCheck %s -declare double @frexp(double, i32*) declare double @__enzyme_fwddiff(i8*, ...) declare float @__enzyme_fwddifff(i8*, ...) +declare x86_fp80 @__enzyme_fwddiffl(i8*, ...) +declare double @frexp(double, i32*) define double @test(double %x) { entry: %exp = alloca i32, align 4 @@ -32,6 +33,20 @@ entry: ret float %call } +declare x86_fp80 @frexpl(x86_fp80, i32*) +define x86_fp80 @testl(x86_fp80 %x) { +entry: + %exp = alloca i32, align 4 + %call = call x86_fp80 @frexpl(x86_fp80 %x, i32* %exp) + ret x86_fp80 %call +} + +define x86_fp80 @dtestl(x86_fp80 %x, x86_fp80 %dx) { +entry: + %call = call x86_fp80 (i8*, ...) @__enzyme_fwddiffl(i8* bitcast (x86_fp80 (x86_fp80)* @testl to i8*), x86_fp80 %x, x86_fp80 %dx) + ret x86_fp80 %call +} + ; CHECK: define internal double @fwddiffetest(double %x, double %"x'") ; CHECK-NEXT: entry: ; CHECK-NEXT: %0 = bitcast double %x to i64 @@ -51,3 +66,13 @@ entry: ; CHECK-NEXT: %4 = fdiv fast float %"x'", %3 ; CHECK-NEXT: ret float %4 ; CHECK-NEXT: } + +; CHECK: define internal x86_fp80 @fwddiffetestl(x86_fp80 %x, x86_fp80 %"x'") +; CHECK-NEXT: entry: +; CHECK-NEXT: %0 = bitcast x86_fp80 %x to i80 +; CHECK-NEXT: %1 = and i80 604453686435277732577280, %0 +; CHECK-NEXT: %2 = bitcast i80 %1 to x86_fp80 +; CHECK-NEXT: %3 = fmul fast x86_fp80 %2, 0xK40008000000000000000 +; CHECK-NEXT: %4 = fdiv fast x86_fp80 %"x'", %3 +; CHECK-NEXT: ret x86_fp80 %4 +; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ForwardMode/modf.ll b/enzyme/test/Enzyme/ForwardMode/modf.ll new file mode 100644 index 000000000000..d36c4d659315 --- /dev/null +++ b/enzyme/test/Enzyme/ForwardMode/modf.ll @@ -0,0 +1,120 @@ +; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -S | FileCheck %s; fi +; RUN: %opt < %s %newLoadEnzyme -passes="enzyme" -enzyme-preopt=false -S | FileCheck %s + +declare double @__enzyme_fwddiff(i8*, ...) +declare float @__enzyme_fwddifff(i8*, ...) +declare x86_fp80 @__enzyme_fwddiffl(i8*, ...) + +; double +declare double @modf(double, double*) +define double @testint(double %x) { +entry: + %integral_part = alloca double, align 8 + %fractional_part = call double @modf(double %x, double* %integral_part) + %ret = load double, double* %integral_part, align 8 + ret double %ret +} +define double @testfrac(double %x) { +entry: + %integral_part = alloca double, align 8 + %fractional_part = call double @modf(double %x, double* %integral_part) + ret double %fractional_part +} + +define double @dtestint(double %x, double %dx) { +entry: + %call = call double (i8*, ...) @__enzyme_fwddiff(i8* bitcast (double (double)* @testint to i8*), double %x, double %dx) + ret double %call +} +define double @dtestfrac(double %x, double %dx) { +entry: + %call = call double (i8*, ...) @__enzyme_fwddiff(i8* bitcast (double (double)* @testfrac to i8*), double %x, double %dx) + ret double %call +} + +; float +declare float @modff(float, float*) +define float @testintf(float %x) { +entry: + %integral_part = alloca float, align 4 + %fractional_part = call float @modff(float %x, float* %integral_part) + %ret = load float, float* %integral_part, align 4 + ret float %ret +} +define float @testfracf(float %x) { +entry: + %integral_part = alloca float, align 4 + %fractional_part = call float @modff(float %x, float* %integral_part) + ret float %fractional_part +} + +define float @dtestintf(float %x, float %dx) { +entry: + %call = call float (i8*, ...) @__enzyme_fwddifff(i8* bitcast (float (float)* @testintf to i8*), float %x, float %dx) + ret float %call +} +define float @dtestfracf(float %x, float %dx) { +entry: + %call = call float (i8*, ...) @__enzyme_fwddifff(i8* bitcast (float (float)* @testfracf to i8*), float %x, float %dx) + ret float %call +} + +; x86_fp80 +declare x86_fp80 @modfl(x86_fp80, x86_fp80*) +define x86_fp80 @testintl(x86_fp80 %x) { +entry: + %integral_part = alloca x86_fp80, align 8 + %fractional_part = call x86_fp80 @modfl(x86_fp80 %x, x86_fp80* %integral_part) + %ret = load x86_fp80, x86_fp80* %integral_part, align 8 + ret x86_fp80 %ret +} +define x86_fp80 @testfracl(x86_fp80 %x) { +entry: + %integral_part = alloca x86_fp80, align 8 + %fractional_part = call x86_fp80 @modfl(x86_fp80 %x, x86_fp80* %integral_part) + ret x86_fp80 %fractional_part +} + +define x86_fp80 @dtestintl(x86_fp80 %x, x86_fp80 %dx) { +entry: + %call = call x86_fp80 (i8*, ...) @__enzyme_fwddiffl(i8* bitcast (x86_fp80 (x86_fp80)* @testintl to i8*), x86_fp80 %x, x86_fp80 %dx) + ret x86_fp80 %call +} +define x86_fp80 @dtestfracl(x86_fp80 %x, x86_fp80 %dx) { +entry: + %call = call x86_fp80 (i8*, ...) @__enzyme_fwddiffl(i8* bitcast (x86_fp80 (x86_fp80)* @testfracl to i8*), x86_fp80 %x, x86_fp80 %dx) + ret x86_fp80 %call +} + +; tests + +; CHECK: define internal double @fwddiffetestint(double %x, double %"x'") +; CHECK-NEXT: entry: +; CHECK-NEXT: ret double 0.000000e+00 +; CHECK-NEXT: } + +; CHECK: define internal double @fwddiffetestfrac(double %x, double %"x'") +; CHECK-NEXT: entry: +; CHECK-NEXT: ret double %"x'" +; CHECK-NEXT: } + +; CHECK: define internal float @fwddiffetestintf(float %x, float %"x'") +; CHECK-NEXT: entry: +; CHECK-NEXT: ret float 0.000000e+00 +; CHECK-NEXT: } + +; CHECK: define internal float @fwddiffetestfracf(float %x, float %"x'") +; CHECK-NEXT: entry: +; CHECK-NEXT: ret float %"x'" +; CHECK-NEXT: } + +; CHECK: define internal x86_fp80 @fwddiffetestintl(x86_fp80 %x, x86_fp80 %"x'") +; CHECK-NEXT: entry: +; CHECK-NEXT: ret x86_fp80 0xK00000000000000000000 +; CHECK-NEXT: } + +; CHECK: define internal x86_fp80 @fwddiffetestfracl(x86_fp80 %x, x86_fp80 %"x'") +; CHECK-NEXT: entry: +; CHECK-NEXT: ret x86_fp80 %"x'" +; CHECK-NEXT: } + diff --git a/enzyme/test/Enzyme/ReverseMode/acosh.ll b/enzyme/test/Enzyme/ReverseMode/acosh.ll new file mode 100644 index 000000000000..b6ff83ea4f5f --- /dev/null +++ b/enzyme/test/Enzyme/ReverseMode/acosh.ll @@ -0,0 +1,46 @@ +; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -S | FileCheck %s; fi +; RUN: %opt < %s %newLoadEnzyme -passes="enzyme" -enzyme-preopt=false -S | FileCheck %s + +; Function Attrs: nounwind readnone uwtable +define double @tester(double %x) { +entry: + %0 = tail call fast double @acosh(double %x) + ret double %0 +} + +define double @test_derivative(double %x) { +entry: + %0 = tail call double (double (double)*, ...) @__enzyme_autodiff(double (double)* nonnull @tester, double %x) + ret double %0 +} + +; Function Attrs: nounwind readnone speculatable +declare double @acosh(double) + +; Function Attrs: nounwind +declare double @__enzyme_autodiff(double (double)*, ...) + +; CHECK: define internal { double } @diffetester(double %x, double %differeturn) +; CHECK-NEXT: entry: +; CHECK-NEXT: %"'de" = alloca double, align 8 +; CHECK-NEXT: store double 0.000000e+00, double* %"'de", align 8 +; CHECK-NEXT: %"x'de" = alloca double, align 8 +; CHECK-NEXT: store double 0.000000e+00, double* %"x'de", align 8 +; CHECK-NEXT: br label %invertentry + +; CHECK: invertentry: +; CHECK-NEXT: store double %differeturn, double* %"'de", align 8 +; CHECK-NEXT: %0 = load double, double* %"'de", align 8 +; CHECK-NEXT: store double 0.000000e+00, double* %"'de", align 8 +; CHECK-NEXT: %1 = fmul fast double %x, %x +; CHECK-NEXT: %2 = fsub fast double %1, 1.000000e+00 +; CHECK-NEXT: %3 = call fast double @llvm.sqrt.f64(double %2) +; CHECK-NEXT: %4 = fdiv fast double %0, %3 +; CHECK-NEXT: %5 = load double, double* %"x'de", align 8 +; CHECK-NEXT: %6 = fadd fast double %5, %4 +; CHECK-NEXT: store double %6, double* %"x'de", align 8 +; CHECK-NEXT: %7 = load double, double* %"x'de", align 8 +; CHECK-NEXT: %8 = insertvalue { double } undef, double %7, 0 +; CHECK-NEXT: ret { double } %8 +; CHECK-NEXT: } + diff --git a/enzyme/test/Enzyme/ReverseMode/asinh.ll b/enzyme/test/Enzyme/ReverseMode/asinh.ll new file mode 100644 index 000000000000..3c34fabc25bf --- /dev/null +++ b/enzyme/test/Enzyme/ReverseMode/asinh.ll @@ -0,0 +1,46 @@ +; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -S | FileCheck %s; fi +; RUN: %opt < %s %newLoadEnzyme -passes="enzyme" -enzyme-preopt=false -S | FileCheck %s + +; Function Attrs: nounwind readnone uwtable +define double @tester(double %x) { +entry: + %0 = tail call fast double @asinh(double %x) + ret double %0 +} + +define double @test_derivative(double %x) { +entry: + %0 = tail call double (double (double)*, ...) @__enzyme_autodiff(double (double)* nonnull @tester, double %x) + ret double %0 +} + +; Function Attrs: nounwind readnone speculatable +declare double @asinh(double) + +; Function Attrs: nounwind +declare double @__enzyme_autodiff(double (double)*, ...) + +; CHECK: define internal { double } @diffetester(double %x, double %differeturn) +; CHECK-NEXT: entry: +; CHECK-NEXT: %"'de" = alloca double, align 8 +; CHECK-NEXT: store double 0.000000e+00, double* %"'de", align 8 +; CHECK-NEXT: %"x'de" = alloca double, align 8 +; CHECK-NEXT: store double 0.000000e+00, double* %"x'de", align 8 +; CHECK-NEXT: br label %invertentry + +; CHECK: invertentry: +; CHECK-NEXT: store double %differeturn, double* %"'de", align 8 +; CHECK-NEXT: %0 = load double, double* %"'de", align 8 +; CHECK-NEXT: store double 0.000000e+00, double* %"'de", align 8 +; CHECK-NEXT: %1 = fmul fast double %x, %x +; CHECK-NEXT: %2 = fadd fast double %1, 1.000000e+00 +; CHECK-NEXT: %3 = call fast double @llvm.sqrt.f64(double %2) +; CHECK-NEXT: %4 = fdiv fast double %0, %3 +; CHECK-NEXT: %5 = load double, double* %"x'de", align 8 +; CHECK-NEXT: %6 = fadd fast double %5, %4 +; CHECK-NEXT: store double %6, double* %"x'de", align 8 +; CHECK-NEXT: %7 = load double, double* %"x'de", align 8 +; CHECK-NEXT: %8 = insertvalue { double } undef, double %7, 0 +; CHECK-NEXT: ret { double } %8 +; CHECK-NEXT: } + diff --git a/enzyme/test/Enzyme/ReverseMode/asinhf.ll b/enzyme/test/Enzyme/ReverseMode/asinhf.ll new file mode 100644 index 000000000000..c113c111abc1 --- /dev/null +++ b/enzyme/test/Enzyme/ReverseMode/asinhf.ll @@ -0,0 +1,46 @@ +; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -S | FileCheck %s; fi +; RUN: %opt < %s %newLoadEnzyme -passes="enzyme" -enzyme-preopt=false -S | FileCheck %s + +; Function Attrs: nounwind readnone uwtable +define float @tester(float %x) { +entry: + %0 = tail call fast float @asinhf(float %x) + ret float %0 +} + +define float @test_derivative(float %x) { +entry: + %0 = tail call float (float (float)*, ...) @__enzyme_autodiff(float (float)* nonnull @tester, float %x) + ret float %0 +} + +; Function Attrs: nounwind readnone speculatable +declare float @asinhf(float) + +; Function Attrs: nounwind +declare float @__enzyme_autodiff(float (float)*, ...) + +; CHECK: define internal { float } @diffetester(float %x, float %differeturn) +; CHECK-NEXT: entry: +; CHECK-NEXT: %"'de" = alloca float, align 4 +; CHECK-NEXT: store float 0.000000e+00, float* %"'de", align 4 +; CHECK-NEXT: %"x'de" = alloca float, align 4 +; CHECK-NEXT: store float 0.000000e+00, float* %"x'de", align 4 +; CHECK-NEXT: br label %invertentry + +; CHECK: invertentry: +; CHECK-NEXT: store float %differeturn, float* %"'de", align 4 +; CHECK-NEXT: %0 = load float, float* %"'de", align 4 +; CHECK-NEXT: store float 0.000000e+00, float* %"'de", align 4 +; CHECK-NEXT: %1 = fmul fast float %x, %x +; CHECK-NEXT: %2 = fadd fast float %1, 1.000000e+00 +; CHECK-NEXT: %3 = call fast float @llvm.sqrt.f32(float %2) +; CHECK-NEXT: %4 = fdiv fast float %0, %3 +; CHECK-NEXT: %5 = load float, float* %"x'de", align 4 +; CHECK-NEXT: %6 = fadd fast float %5, %4 +; CHECK-NEXT: store float %6, float* %"x'de", align 4 +; CHECK-NEXT: %7 = load float, float* %"x'de", align 4 +; CHECK-NEXT: %8 = insertvalue { float } undef, float %7, 0 +; CHECK-NEXT: ret { float } %8 +; CHECK-NEXT: } + diff --git a/enzyme/test/Enzyme/ReverseMode/asinhl.ll b/enzyme/test/Enzyme/ReverseMode/asinhl.ll new file mode 100644 index 000000000000..759d05bf3236 --- /dev/null +++ b/enzyme/test/Enzyme/ReverseMode/asinhl.ll @@ -0,0 +1,46 @@ +; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -S | FileCheck %s; fi +; RUN: %opt < %s %newLoadEnzyme -passes="enzyme" -enzyme-preopt=false -S | FileCheck %s + +; Function Attrs: nounwind readnone uwtable +define x86_fp80 @tester(x86_fp80 %x) { +entry: + %0 = tail call fast x86_fp80 @asinhl(x86_fp80 %x) + ret x86_fp80 %0 +} + +define x86_fp80 @test_derivative(x86_fp80 %x) { +entry: + %0 = tail call x86_fp80 (x86_fp80 (x86_fp80)*, ...) @__enzyme_autodiff(x86_fp80 (x86_fp80)* nonnull @tester, x86_fp80 %x) + ret x86_fp80 %0 +} + +; Function Attrs: nounwind readnone speculatable +declare x86_fp80 @asinhl(x86_fp80) + +; Function Attrs: nounwind +declare x86_fp80 @__enzyme_autodiff(x86_fp80 (x86_fp80)*, ...) + +; CHECK: define internal { x86_fp80 } @diffetester(x86_fp80 %x, x86_fp80 %differeturn) +; CHECK-NEXT: entry: +; CHECK-NEXT: %"'de" = alloca x86_fp80, align 16 +; CHECK-NEXT: store x86_fp80 0xK00000000000000000000, x86_fp80* %"'de", align 16 +; CHECK-NEXT: %"x'de" = alloca x86_fp80, align 16 +; CHECK-NEXT: store x86_fp80 0xK00000000000000000000, x86_fp80* %"x'de", align 16 +; CHECK-NEXT: br label %invertentry + +; CHECK: invertentry: +; CHECK-NEXT: store x86_fp80 %differeturn, x86_fp80* %"'de", align 16 +; CHECK-NEXT: %0 = load x86_fp80, x86_fp80* %"'de", align 16 +; CHECK-NEXT: store x86_fp80 0xK00000000000000000000, x86_fp80* %"'de", align 16 +; CHECK-NEXT: %1 = fmul fast x86_fp80 %x, %x +; CHECK-NEXT: %2 = fadd fast x86_fp80 %1, 0xK3FFF8000000000000000 +; CHECK-NEXT: %3 = call fast x86_fp80 @llvm.sqrt.f80(x86_fp80 %2) +; CHECK-NEXT: %4 = fdiv fast x86_fp80 %0, %3 +; CHECK-NEXT: %5 = load x86_fp80, x86_fp80* %"x'de", align 16 +; CHECK-NEXT: %6 = fadd fast x86_fp80 %5, %4 +; CHECK-NEXT: store x86_fp80 %6, x86_fp80* %"x'de", align 16 +; CHECK-NEXT: %7 = load x86_fp80, x86_fp80* %"x'de", align 16 +; CHECK-NEXT: %8 = insertvalue { x86_fp80 } undef, x86_fp80 %7, 0 +; CHECK-NEXT: ret { x86_fp80 } %8 +; CHECK-NEXT: } + diff --git a/enzyme/test/Enzyme/ReverseMode/atanh.ll b/enzyme/test/Enzyme/ReverseMode/atanh.ll new file mode 100644 index 000000000000..708c4da03bf3 --- /dev/null +++ b/enzyme/test/Enzyme/ReverseMode/atanh.ll @@ -0,0 +1,45 @@ +; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -S | FileCheck %s; fi +; RUN: %opt < %s %newLoadEnzyme -passes="enzyme" -enzyme-preopt=false -S | FileCheck %s + +; Function Attrs: nounwind readnone uwtable +define double @tester(double %x) { +entry: + %0 = tail call fast double @atanh(double %x) + ret double %0 +} + +define double @test_derivative(double %x) { +entry: + %0 = tail call double (double (double)*, ...) @__enzyme_autodiff(double (double)* nonnull @tester, double %x) + ret double %0 +} + +; Function Attrs: nounwind readnone speculatable +declare double @atanh(double) + +; Function Attrs: nounwind +declare double @__enzyme_autodiff(double (double)*, ...) + +; CHECK: define internal { double } @diffetester(double %x, double %differeturn) +; CHECK-NEXT: entry: +; CHECK-NEXT: %"'de" = alloca double, align 8 +; CHECK-NEXT: store double 0.000000e+00, double* %"'de", align 8 +; CHECK-NEXT: %"x'de" = alloca double, align 8 +; CHECK-NEXT: store double 0.000000e+00, double* %"x'de", align 8 +; CHECK-NEXT: br label %invertentry + +; CHECK: invertentry: +; CHECK-NEXT: store double %differeturn, double* %"'de", align 8 +; CHECK-NEXT: %0 = load double, double* %"'de", align 8 +; CHECK-NEXT: store double 0.000000e+00, double* %"'de", align 8 +; CHECK-NEXT: %1 = fmul fast double %x, %x +; CHECK-NEXT: %2 = fsub fast double 1.000000e+00, %1 +; CHECK-NEXT: %3 = fdiv fast double %0, %2 +; CHECK-NEXT: %4 = load double, double* %"x'de", align 8 +; CHECK-NEXT: %5 = fadd fast double %4, %3 +; CHECK-NEXT: store double %5, double* %"x'de", align 8 +; CHECK-NEXT: %6 = load double, double* %"x'de", align 8 +; CHECK-NEXT: %7 = insertvalue { double } undef, double %6, 0 +; CHECK-NEXT: ret { double } %7 +; CHECK-NEXT: } + diff --git a/enzyme/test/Enzyme/ReverseMode/coshf.ll b/enzyme/test/Enzyme/ReverseMode/coshf.ll new file mode 100644 index 000000000000..b028239a4ecf --- /dev/null +++ b/enzyme/test/Enzyme/ReverseMode/coshf.ll @@ -0,0 +1,29 @@ +; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme-preopt=false -enzyme -mem2reg -simplifycfg -instsimplify -S | FileCheck %s; fi +; RUN: %opt < %s %newLoadEnzyme -enzyme-preopt=false -passes="enzyme,function(mem2reg,%simplifycfg,instsimplify)" -S | FileCheck %s + +; Function Attrs: nounwind readnone uwtable +define float @tester(float %x) { +entry: + %0 = tail call fast float @coshf(float %x) + ret float %0 +} + +define float @test_derivative(float %x) { +entry: + %0 = tail call float (float (float)*, ...) @__enzyme_autodiff(float (float)* nonnull @tester, float %x) + ret float %0 +} + +; Function Attrs: nounwind readnone speculatable +declare float @coshf(float) + +; Function Attrs: nounwind +declare float @__enzyme_autodiff(float (float)*, ...) + +; CHECK: define internal { float } @diffetester(float %x, float %differeturn) +; CHECK-NEXT: entry: +; CHECK-NEXT: %0 = call fast float @sinhf(float %x) +; CHECK-NEXT: %1 = fmul fast float %differeturn, %0 +; CHECK-NEXT: %2 = insertvalue { float } undef, float %1, 0 +; CHECK-NEXT: ret { float } %2 +; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ReverseMode/coshl.ll b/enzyme/test/Enzyme/ReverseMode/coshl.ll new file mode 100644 index 000000000000..dd8af143ff48 --- /dev/null +++ b/enzyme/test/Enzyme/ReverseMode/coshl.ll @@ -0,0 +1,29 @@ +; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme-preopt=false -enzyme -mem2reg -simplifycfg -instsimplify -S | FileCheck %s; fi +; RUN: %opt < %s %newLoadEnzyme -enzyme-preopt=false -passes="enzyme,function(mem2reg,%simplifycfg,instsimplify)" -S | FileCheck %s + +; Function Attrs: nounwind readnone uwtable +define x86_fp80 @tester(x86_fp80 %x) { +entry: + %0 = tail call fast x86_fp80 @coshl(x86_fp80 %x) + ret x86_fp80 %0 +} + +define x86_fp80 @test_derivative(x86_fp80 %x) { +entry: + %0 = tail call x86_fp80 (x86_fp80 (x86_fp80)*, ...) @__enzyme_autodiff(x86_fp80 (x86_fp80)* nonnull @tester, x86_fp80 %x) + ret x86_fp80 %0 +} + +; Function Attrs: nounwind readnone speculatable +declare x86_fp80 @coshl(x86_fp80) + +; Function Attrs: nounwind +declare x86_fp80 @__enzyme_autodiff(x86_fp80 (x86_fp80)*, ...) + +; CHECK: define internal { x86_fp80 } @diffetester(x86_fp80 %x, x86_fp80 %differeturn) +; CHECK-NEXT: entry: +; CHECK-NEXT: %0 = call fast x86_fp80 @sinhl(x86_fp80 %x) +; CHECK-NEXT: %1 = fmul fast x86_fp80 %differeturn, %0 +; CHECK-NEXT: %2 = insertvalue { x86_fp80 } undef, x86_fp80 %1, 0 +; CHECK-NEXT: ret { x86_fp80 } %2 +; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ReverseMode/fdim.ll b/enzyme/test/Enzyme/ReverseMode/fdim.ll new file mode 100644 index 000000000000..113532464547 --- /dev/null +++ b/enzyme/test/Enzyme/ReverseMode/fdim.ll @@ -0,0 +1,54 @@ +; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -S | FileCheck %s; fi +; RUN: if [ %llvmver -ge 12 ]; then %opt < %s %newLoadEnzyme -passes="enzyme" -enzyme-preopt=false -S | FileCheck %s ; fi + +declare double @fdim(double, double) + +define double @tester(double %x, double %y) { +entry: + %0 = call double @fdim(double %x, double %y) + ret double %0 +} + +define double @test_derivative(double %x, double %y) { +entry: + %0 = tail call double (double (double, double)*, ...) @__enzyme_autodiff(double (double, double)* nonnull @tester, double %x, double %y) + ret double %0 +} + +; Function Attrs: nounwind +declare double @__enzyme_autodiff(double (double, double)*, ...) + +; CHECK: define internal { double, double } @diffetester(double %x, double %y, double %differeturn) +; CHECK-NEXT: entry: +; CHECK-NEXT: %"'de" = alloca double, align 8 +; CHECK-NEXT: store double 0.000000e+00, double* %"'de", align 8 +; CHECK-NEXT: %"x'de" = alloca double, align 8 +; CHECK-NEXT: store double 0.000000e+00, double* %"x'de", align 8 +; CHECK-NEXT: %"y'de" = alloca double, align 8 +; CHECK-NEXT: store double 0.000000e+00, double* %"y'de", align 8 +; CHECK-NEXT: br label %invertentry + +; CHECK: invertentry: +; CHECK-NEXT: store double %differeturn, double* %"'de", align 8 +; CHECK-NEXT: %0 = load double, double* %"'de", align 8 +; CHECK-NEXT: store double 0.000000e+00, double* %"'de", align 8 +; CHECK-NEXT: %1 = fcmp fast olt double %x, %y +; CHECK-NEXT: %2 = select fast i1 %1, double 0.000000e+00, double %0 +; CHECK-NEXT: %3 = load double, double* %"x'de", align 8 +; CHECK-NEXT: %4 = fadd fast double %3, %0 +; CHECK-NEXT: %5 = select fast i1 %1, double %3, double %4 +; CHECK-NEXT: store double %5, double* %"x'de", align 8 +; CHECK-NEXT: %6 = fcmp fast olt double %x, %y +; CHECK-NEXT: %7 = fneg fast double %0 +; CHECK-NEXT: %8 = select fast i1 %6, double 0.000000e+00, double %7 +; CHECK-NEXT: %9 = load double, double* %"y'de", align 8 +; CHECK-NEXT: %10 = fadd fast double %9, %7 +; CHECK-NEXT: %11 = select fast i1 %6, double %9, double %10 +; CHECK-NEXT: store double %11, double* %"y'de", align 8 +; CHECK-NEXT: %12 = load double, double* %"x'de", align 8 +; CHECK-NEXT: %13 = load double, double* %"y'de", align 8 +; CHECK-NEXT: %14 = insertvalue { double, double } undef, double %12, 0 +; CHECK-NEXT: %15 = insertvalue { double, double } %14, double %13, 1 +; CHECK-NEXT: ret { double, double } %15 +; CHECK-NEXT: } + diff --git a/enzyme/test/Enzyme/ReverseMode/frexp.ll b/enzyme/test/Enzyme/ReverseMode/frexp.ll index 2e74d6914105..fe134bdab00b 100644 --- a/enzyme/test/Enzyme/ReverseMode/frexp.ll +++ b/enzyme/test/Enzyme/ReverseMode/frexp.ll @@ -1,10 +1,11 @@ ; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme-preopt=false -enzyme -mem2reg -simplifycfg -instsimplify -adce -S | FileCheck %s; fi ; RUN: %opt < %s %newLoadEnzyme -enzyme-preopt=false -passes="enzyme,function(mem2reg,%simplifycfg,instsimplify,adce)" -S | FileCheck %s -declare double @frexp(double, i32*) declare double @__enzyme_autodiff(i8*, ...) declare float @__enzyme_autodifff(i8*, ...) +declare x86_fp80 @__enzyme_autodiffl(i8*, ...) +declare double @frexp(double, i32*) define double @test(double %x) { entry: %exp = alloca i32, align 4 @@ -32,6 +33,20 @@ entry: ret float %call } +declare x86_fp80 @frexpl(x86_fp80, i32*) +define x86_fp80 @testl(x86_fp80 %x) { +entry: + %exp = alloca i32, align 4 + %call = call x86_fp80 @frexpl(x86_fp80 %x, i32* %exp) + ret x86_fp80 %call +} + +define x86_fp80 @dtestl(x86_fp80 %x) { +entry: + %call = call x86_fp80 (i8*, ...) @__enzyme_autodiffl(i8* bitcast (x86_fp80 (x86_fp80)* @testl to i8*), x86_fp80 %x) + ret x86_fp80 %call +} + ; CHECK: define internal { double } @diffetest(double %x, double %differeturn) ; CHECK-NEXT: entry: ; CHECK-NEXT: %0 = bitcast double %x to i64 @@ -53,3 +68,14 @@ entry: ; CHECK-NEXT: %5 = insertvalue { float } {{(undef|poison)}}, float %4, 0 ; CHECK-NEXT: ret { float } %5 ; CHECK-NEXT: } + +; CHECK: define internal { x86_fp80 } @diffetestl(x86_fp80 %x, x86_fp80 %differeturn) +; CHECK-NEXT: entry: +; CHECK-NEXT: %0 = bitcast x86_fp80 %x to i80 +; CHECK-NEXT: %1 = and i80 604453686435277732577280, %0 +; CHECK-NEXT: %2 = bitcast i80 %1 to x86_fp80 +; CHECK-NEXT: %3 = fmul fast x86_fp80 %2, 0xK40008000000000000000 +; CHECK-NEXT: %4 = fdiv fast x86_fp80 %differeturn, %3 +; CHECK-NEXT: %5 = insertvalue { x86_fp80 } {{(undef|poison)}}, x86_fp80 %4, 0 +; CHECK-NEXT: ret { x86_fp80 } %5 +; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ReverseMode/modf.ll b/enzyme/test/Enzyme/ReverseMode/modf.ll new file mode 100644 index 000000000000..5c601a39e340 --- /dev/null +++ b/enzyme/test/Enzyme/ReverseMode/modf.ll @@ -0,0 +1,189 @@ +; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -S | FileCheck %s; fi +; RUN: %opt < %s %newLoadEnzyme -passes="enzyme" -enzyme-preopt=false -S | FileCheck %s + +declare double @__enzyme_autodiff(i8*, ...) +declare float @__enzyme_autodifff(i8*, ...) +declare x86_fp80 @__enzyme_autodiffl(i8*, ...) + +; double +declare double @modf(double, double*) +define double @testint(double %x) { +entry: + %integral_part = alloca double, align 8 + %fractional_part = call double @modf(double %x, double* %integral_part) + %ret = load double, double* %integral_part, align 8 + ret double %ret +} +define double @testfrac(double %x) { +entry: + %integral_part = alloca double, align 8 + %fractional_part = call double @modf(double %x, double* %integral_part) + ret double %fractional_part +} + +define double @dtestint(double %x, double %dx) { +entry: + %call = call double (i8*, ...) @__enzyme_autodiff(i8* bitcast (double (double)* @testint to i8*), double %x) + ret double %call +} +define double @dtestfrac(double %x, double %dx) { +entry: + %call = call double (i8*, ...) @__enzyme_autodiff(i8* bitcast (double (double)* @testfrac to i8*), double %x) + ret double %call +} + +; float +declare float @modff(float, float*) +define float @testintf(float %x) { +entry: + %integral_part = alloca float, align 4 + %fractional_part = call float @modff(float %x, float* %integral_part) + %ret = load float, float* %integral_part, align 4 + ret float %ret +} +define float @testfracf(float %x) { +entry: + %integral_part = alloca float, align 4 + %fractional_part = call float @modff(float %x, float* %integral_part) + ret float %fractional_part +} + +define float @dtestintf(float %x, float %dx) { +entry: + %call = call float (i8*, ...) @__enzyme_autodifff(i8* bitcast (float (float)* @testintf to i8*), float %x) + ret float %call +} +define float @dtestfracf(float %x, float %dx) { +entry: + %call = call float (i8*, ...) @__enzyme_autodifff(i8* bitcast (float (float)* @testfracf to i8*), float %x) + ret float %call +} + +; x86_fp80 +declare x86_fp80 @modfl(x86_fp80, x86_fp80*) +define x86_fp80 @testintl(x86_fp80 %x) { +entry: + %integral_part = alloca x86_fp80, align 8 + %fractional_part = call x86_fp80 @modfl(x86_fp80 %x, x86_fp80* %integral_part) + %ret = load x86_fp80, x86_fp80* %integral_part, align 8 + ret x86_fp80 %ret +} +define x86_fp80 @testfracl(x86_fp80 %x) { +entry: + %integral_part = alloca x86_fp80, align 8 + %fractional_part = call x86_fp80 @modfl(x86_fp80 %x, x86_fp80* %integral_part) + ret x86_fp80 %fractional_part +} + +define x86_fp80 @dtestintl(x86_fp80 %x, x86_fp80 %dx) { +entry: + %call = call x86_fp80 (i8*, ...) @__enzyme_autodiffl(i8* bitcast (x86_fp80 (x86_fp80)* @testintl to i8*), x86_fp80 %x) + ret x86_fp80 %call +} +define x86_fp80 @dtestfracl(x86_fp80 %x, x86_fp80 %dx) { +entry: + %call = call x86_fp80 (i8*, ...) @__enzyme_autodiffl(i8* bitcast (x86_fp80 (x86_fp80)* @testfracl to i8*), x86_fp80 %x) + ret x86_fp80 %call +} + +; double tests + +; CHECK: define internal { double } @diffetestint(double %x, double %differeturn) +; CHECK-NEXT: entry: +; CHECK-NEXT: %"x'de" = alloca double, align 8 +; CHECK-NEXT: store double 0.000000e+00, double* %"x'de", align 8 +; CHECK-NEXT: br label %invertentry + +; CHECK: invertentry: +; CHECK-NEXT: %0 = load double, double* %"x'de", align 8 +; CHECK-NEXT: %1 = insertvalue { double } undef, double %0, 0 +; CHECK-NEXT: ret { double } %1 +; CHECK-NEXT: } + +; CHECK: define internal { double } @diffetestfrac(double %x, double %differeturn) +; CHECK-NEXT: entry: +; CHECK-NEXT: %"fractional_part'de" = alloca double, align 8 +; CHECK-NEXT: store double 0.000000e+00, double* %"fractional_part'de", align 8 +; CHECK-NEXT: %"x'de" = alloca double, align 8 +; CHECK-NEXT: store double 0.000000e+00, double* %"x'de", align 8 +; CHECK-NEXT: br label %invertentry + +; CHECK: invertentry: +; CHECK-NEXT: store double %differeturn, double* %"fractional_part'de", align 8 +; CHECK-NEXT: %0 = load double, double* %"fractional_part'de", align 8 +; CHECK-NEXT: store double 0.000000e+00, double* %"fractional_part'de", align 8 +; CHECK-NEXT: %1 = load double, double* %"x'de", align 8 +; CHECK-NEXT: %2 = fadd fast double %1, %0 +; CHECK-NEXT: store double %2, double* %"x'de", align 8 +; CHECK-NEXT: %3 = load double, double* %"x'de", align 8 +; CHECK-NEXT: %4 = insertvalue { double } undef, double %3, 0 +; CHECK-NEXT: ret { double } %4 +; CHECK-NEXT: } + +; float tests + +; CHECK: define internal { float } @diffetestintf(float %x, float %differeturn) +; CHECK-NEXT: entry: +; CHECK-NEXT: %"x'de" = alloca float, align 4 +; CHECK-NEXT: store float 0.000000e+00, float* %"x'de", align 4 +; CHECK-NEXT: br label %invertentry + +; CHECK: invertentry: +; CHECK-NEXT: %0 = load float, float* %"x'de", align 4 +; CHECK-NEXT: %1 = insertvalue { float } undef, float %0, 0 +; CHECK-NEXT: ret { float } %1 +; CHECK-NEXT: } + +; CHECK: define internal { float } @diffetestfracf(float %x, float %differeturn) +; CHECK-NEXT: entry: +; CHECK-NEXT: %"fractional_part'de" = alloca float, align 4 +; CHECK-NEXT: store float 0.000000e+00, float* %"fractional_part'de", align 4 +; CHECK-NEXT: %"x'de" = alloca float, align 4 +; CHECK-NEXT: store float 0.000000e+00, float* %"x'de", align 4 +; CHECK-NEXT: br label %invertentry + +; CHECK: invertentry: +; CHECK-NEXT: store float %differeturn, float* %"fractional_part'de", align 4 +; CHECK-NEXT: %0 = load float, float* %"fractional_part'de", align 4 +; CHECK-NEXT: store float 0.000000e+00, float* %"fractional_part'de", align 4 +; CHECK-NEXT: %1 = load float, float* %"x'de", align 4 +; CHECK-NEXT: %2 = fadd fast float %1, %0 +; CHECK-NEXT: store float %2, float* %"x'de", align 4 +; CHECK-NEXT: %3 = load float, float* %"x'de", align 4 +; CHECK-NEXT: %4 = insertvalue { float } undef, float %3, 0 +; CHECK-NEXT: ret { float } %4 +; CHECK-NEXT: } + +; x86_fp80 tests + +; CHECK: define internal { x86_fp80 } @diffetestintl(x86_fp80 %x, x86_fp80 %differeturn) +; CHECK-NEXT: entry: +; CHECK-NEXT: %"x'de" = alloca x86_fp80, align 16 +; CHECK-NEXT: store x86_fp80 0xK00000000000000000000, x86_fp80* %"x'de", align 16 +; CHECK-NEXT: br label %invertentry + +; CHECK: invertentry: +; CHECK-NEXT: %0 = load x86_fp80, x86_fp80* %"x'de", align 16 +; CHECK-NEXT: %1 = insertvalue { x86_fp80 } undef, x86_fp80 %0, 0 +; CHECK-NEXT: ret { x86_fp80 } %1 +; CHECK-NEXT: } + +; CHECK: define internal { x86_fp80 } @diffetestfracl(x86_fp80 %x, x86_fp80 %differeturn) +; CHECK-NEXT: entry: +; CHECK-NEXT: %"fractional_part'de" = alloca x86_fp80, align 16 +; CHECK-NEXT: store x86_fp80 0xK00000000000000000000, x86_fp80* %"fractional_part'de", align 16 +; CHECK-NEXT: %"x'de" = alloca x86_fp80, align 16 +; CHECK-NEXT: store x86_fp80 0xK00000000000000000000, x86_fp80* %"x'de", align 16 +; CHECK-NEXT: br label %invertentry + +; CHECK: invertentry: +; CHECK-NEXT: store x86_fp80 %differeturn, x86_fp80* %"fractional_part'de", align 16 +; CHECK-NEXT: %0 = load x86_fp80, x86_fp80* %"fractional_part'de", align 16 +; CHECK-NEXT: store x86_fp80 0xK00000000000000000000, x86_fp80* %"fractional_part'de", align 16 +; CHECK-NEXT: %1 = load x86_fp80, x86_fp80* %"x'de", align 16 +; CHECK-NEXT: %2 = fadd fast x86_fp80 %1, %0 +; CHECK-NEXT: store x86_fp80 %2, x86_fp80* %"x'de", align 16 +; CHECK-NEXT: %3 = load x86_fp80, x86_fp80* %"x'de", align 16 +; CHECK-NEXT: %4 = insertvalue { x86_fp80 } undef, x86_fp80 %3, 0 +; CHECK-NEXT: ret { x86_fp80 } %4 +; CHECK-NEXT: } From 55d67c3fad170fabef9d83d10ba534a89fd69d3d Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 13 Feb 2024 13:36:28 -0500 Subject: [PATCH 049/106] Act conservative if not affine (#1712) --- enzyme/Enzyme/FunctionUtils.cpp | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/enzyme/Enzyme/FunctionUtils.cpp b/enzyme/Enzyme/FunctionUtils.cpp index 6bc443276b4b..160ccc740f8a 100644 --- a/enzyme/Enzyme/FunctionUtils.cpp +++ b/enzyme/Enzyme/FunctionUtils.cpp @@ -7197,11 +7197,13 @@ getSparseConditions(bool &legal, Value *val, } } if (scope) - EmitFailure("NoSparsification", I->getDebugLoc(), I, - "F: ", *I->getParent()->getParent(), "\n", + EmitWarning("NoSparsification", *I, " No sparsification: not sparse solvable(icmp): ", *I, " via ", *sub1); - legal = false; + if (SparseDebug) { + llvm::errs() << " getSparse(icmp_dflt, " << *I + << ") = " << *defaultFloat << "\n"; + } return defaultFloat; } From c0c1070533767e11c3e97e3337724f309c60331f Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 13 Feb 2024 14:01:19 -0500 Subject: [PATCH 050/106] Change compile time or type analysis err into runtime (#1713) --- enzyme/Enzyme/InstructionDerivatives.td | 9 +++++++ enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp | 28 ++++++++++++++++++++- 2 files changed, 36 insertions(+), 1 deletion(-) diff --git a/enzyme/Enzyme/InstructionDerivatives.td b/enzyme/Enzyme/InstructionDerivatives.td index 77a67fafe325..32b30bd8f9b4 100644 --- a/enzyme/Enzyme/InstructionDerivatives.td +++ b/enzyme/Enzyme/InstructionDerivatives.td @@ -823,6 +823,15 @@ def : CallPattern<(Op (Op $x, $y):$z), [ReadNone, NoUnwind] >; +def : CallPattern<(Op (Op $x, $y):$z), + ["cmplx_inv"], + [ + (CFDiv (CFNeg (DiffeRet)), (CFMul $z, $z)), + ], + (ForwardFromSummedReverse), + [ReadNone, NoUnwind] + >; + def : IntrPattern<(Op $x), [["sin"]], [(FMul (DiffeRet), (Intrinsic<"cos"> $x))] , diff --git a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp index b7b900783f48..851db780723d 100644 --- a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp +++ b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp @@ -157,6 +157,7 @@ const llvm::StringMap LIBM_FUNCTIONS = { {"__fd_sincos_1", Intrinsic::not_intrinsic}, {"sincospi", Intrinsic::not_intrinsic}, + {"cmplx_inv", Intrinsic::not_intrinsic}, // bessel functions {"j0", Intrinsic::not_intrinsic}, @@ -2937,7 +2938,32 @@ void TypeAnalyzer::visitBinaryOperation(const DataLayout &dl, llvm::Type *T, // If ^ against 0b10000000000, the result is a float bool validXor = containsOnlyAtMostTopBit(Args[i], FT, dl); if (validXor) { - ((i == 0) ? RHS : LHS) |= TypeTree(FT).Only(-1, nullptr); + bool Legal = true; + ((i == 0) ? RHS : LHS) + .checkedOrIn(TypeTree(FT).Only(-1, nullptr), + /*pointerintsame*/ false, Legal); + + if (!Legal) { + std::string str; + raw_string_ostream ss(str); + if (!CustomErrorHandler) { + llvm::errs() << *fntypeinfo.Function->getParent() << "\n"; + llvm::errs() << *fntypeinfo.Function << "\n"; + dump(ss); + } + ss << "Illegal updateBinop (xor up) Analysis " << *origin << "\n"; + ss << " (i=" << i << ") " << (i == 0 ? "RHS" : "LHS") << " " + << ((i == 0) ? RHS : LHS).str() << " FT from ret: " << *FT + << "\n"; + if (CustomErrorHandler) { + CustomErrorHandler(str.c_str(), wrap(origin), + ErrorType::IllegalTypeAnalysis, (void *)this, + wrap(origin), nullptr); + } + EmitFailure("IllegalUpdateAnalysis", origin->getDebugLoc(), + origin, ss.str()); + report_fatal_error("Performed illegal updateAnalysis"); + } } } break; From adfc693a187438f58c010ef9c12c04909f1e475e Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 13 Feb 2024 21:00:48 -0500 Subject: [PATCH 051/106] MLIR Make core mlir registration fn (#1715) --- enzyme/Enzyme/ActivityAnalysis.cpp | 6 ++++++ .../CoreDialectsAutoDiffImplementations.cpp | 14 ++++++++++++++ .../CoreDialectsAutoDiffImplementations.h | 2 ++ enzyme/Enzyme/MLIR/Interfaces/CloneFunction.cpp | 2 ++ enzyme/Enzyme/MLIR/enzymemlir-opt.cpp | 11 +---------- 5 files changed, 25 insertions(+), 10 deletions(-) diff --git a/enzyme/Enzyme/ActivityAnalysis.cpp b/enzyme/Enzyme/ActivityAnalysis.cpp index 3175e277a397..243e71dec04c 100644 --- a/enzyme/Enzyme/ActivityAnalysis.cpp +++ b/enzyme/Enzyme/ActivityAnalysis.cpp @@ -1432,6 +1432,12 @@ bool ActivityAnalyzer::isConstantValue(TypeResults const &TR, Value *Val) { if (containsPointer && !isValuePotentiallyUsedAsPointer(Val)) { containsPointer = false; + if (auto Arg = dyn_cast(Val)) { + assert(Arg->hasByValAttr()); + (void)Arg; + InsertConstantValue(TR, Val); + return true; + } } // We do this pointer dance here to ensure that any derived pointers from diff --git a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp index 857784472211..3bf48fc16b81 100644 --- a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp @@ -326,3 +326,17 @@ LogicalResult mlir::enzyme::detail::controlFlowForwardHandler( return success(); } + +void mlir::enzyme::registerCoreDialectAutodiffInterfaces( + DialectRegistry ®istry) { + enzyme::registerAffineDialectAutoDiffInterface(registry); + enzyme::registerArithDialectAutoDiffInterface(registry); + enzyme::registerBuiltinDialectAutoDiffInterface(registry); + enzyme::registerLLVMDialectAutoDiffInterface(registry); + enzyme::registerNVVMDialectAutoDiffInterface(registry); + enzyme::registerMathDialectAutoDiffInterface(registry); + enzyme::registerMemRefDialectAutoDiffInterface(registry); + enzyme::registerSCFDialectAutoDiffInterface(registry); + enzyme::registerCFDialectAutoDiffInterface(registry); + enzyme::registerLinalgDialectAutoDiffInterface(registry); +} diff --git a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h index 5a1fbc136708..974b888599d5 100644 --- a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h +++ b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h @@ -187,5 +187,7 @@ void registerSCFDialectAutoDiffInterface(DialectRegistry ®istry); void registerCFDialectAutoDiffInterface(DialectRegistry ®istry); void registerLinalgDialectAutoDiffInterface(DialectRegistry ®istry); void registerMathDialectAutoDiffInterface(DialectRegistry ®istry); + +void registerCoreDialectAutodiffInterfaces(DialectRegistry ®istry); } // namespace enzyme } // namespace mlir diff --git a/enzyme/Enzyme/MLIR/Interfaces/CloneFunction.cpp b/enzyme/Enzyme/MLIR/Interfaces/CloneFunction.cpp index 2b2760dfa8ff..ba953dbca5df 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/CloneFunction.cpp +++ b/enzyme/Enzyme/MLIR/Interfaces/CloneFunction.cpp @@ -238,6 +238,8 @@ FunctionOpInterface CloneFunctionWithReturns( { auto &blk = NewF.getFunctionBody().front(); + assert(F.getFunctionBody().front().getNumArguments() == + constant_args.size()); for (ssize_t i = constant_args.size() - 1; i >= 0; i--) { mlir::Value oval = F.getFunctionBody().front().getArgument(i); if (constant_args[i] == DIFFE_TYPE::CONSTANT) diff --git a/enzyme/Enzyme/MLIR/enzymemlir-opt.cpp b/enzyme/Enzyme/MLIR/enzymemlir-opt.cpp index 90038c1e3236..4a7b9231d1ee 100644 --- a/enzyme/Enzyme/MLIR/enzymemlir-opt.cpp +++ b/enzyme/Enzyme/MLIR/enzymemlir-opt.cpp @@ -97,16 +97,7 @@ int main(int argc, char **argv) { }); // Register the autodiff interface implementations for upstream dialects. - enzyme::registerAffineDialectAutoDiffInterface(registry); - enzyme::registerArithDialectAutoDiffInterface(registry); - enzyme::registerBuiltinDialectAutoDiffInterface(registry); - enzyme::registerLLVMDialectAutoDiffInterface(registry); - enzyme::registerNVVMDialectAutoDiffInterface(registry); - enzyme::registerMathDialectAutoDiffInterface(registry); - enzyme::registerMemRefDialectAutoDiffInterface(registry); - enzyme::registerSCFDialectAutoDiffInterface(registry); - enzyme::registerCFDialectAutoDiffInterface(registry); - enzyme::registerLinalgDialectAutoDiffInterface(registry); + enzyme::registerCoreDialectAutodiffInterfaces(registry); return mlir::asMainReturnCode(mlir::MlirOptMain( argc, argv, "Enzyme modular optimizer driver", registry)); From 9384fe20caec02bd30f302e32f4f1c1f7ccb7d9d Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 13 Feb 2024 23:08:18 -0500 Subject: [PATCH 052/106] Fix enzymemlir visibility and return (#1716) --- enzyme/BUILD | 3 ++- enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp | 5 ++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/enzyme/BUILD b/enzyme/BUILD index 0ad691e5aae0..9bf6fadabe9a 100644 --- a/enzyme/BUILD +++ b/enzyme/BUILD @@ -7,7 +7,7 @@ licenses(["notice"]) package( default_applicable_licenses = [], - default_visibility = ["//:__subpackages__"], + default_visibility = ["//visibility:public"], ) cc_library( @@ -29,6 +29,7 @@ cc_binary( "@llvm-project//llvm:TableGen", "@llvm-project//llvm:config", ], + visibility = ["//visibility:public"], ) gentbl( diff --git a/enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp b/enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp index d96c18a68c09..a5c4d45ced95 100644 --- a/enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp +++ b/enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp @@ -86,11 +86,14 @@ struct DifferentiateWrapperPass FunctionOpInterface newFunc = Logic.CreateForwardDiff( fn, retType, constants, TA, - /*should return*/ false, mode, freeMemory, width, + /*should return*/ (retType == DIFFE_TYPE::DUP_ARG), mode, freeMemory, + width, /*addedType*/ nullptr, type_args, volatile_args, /*augmented*/ nullptr); if (outfn == "") { fn->erase(); + SymbolTable::setSymbolName(cast(newFunc), + (std::string)infn); } else { SymbolTable::setSymbolName(cast(newFunc), (std::string)outfn); From 3dcc55a4ca93a95e269433d0ada1829c4cb55ce7 Mon Sep 17 00:00:00 2001 From: William Moses Date: Wed, 14 Feb 2024 20:27:28 -0500 Subject: [PATCH 053/106] MLIR fix broken test (#1718) --- enzyme/test/MLIR/ForwardMode/wrap.mlir | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/enzyme/test/MLIR/ForwardMode/wrap.mlir b/enzyme/test/MLIR/ForwardMode/wrap.mlir index 7cef99c691f8..5ff0e1540d13 100644 --- a/enzyme/test/MLIR/ForwardMode/wrap.mlir +++ b/enzyme/test/MLIR/ForwardMode/wrap.mlir @@ -7,10 +7,10 @@ module { } } -// CHECK: func.func private @dsq(%[[arg0:.+]]: f64, %[[arg1:.+]]: f64) -> f64 { +// CHECK: func.func private @dsq(%[[arg0:.+]]: f64, %[[arg1:.+]]: f64) -> (f64, f64) { // CHECK-NEXT: %[[i0:.+]] = arith.mulf %[[arg1]], %[[arg0]] : f64 // CHECK-NEXT: %[[i1:.+]] = arith.mulf %[[arg1]], %[[arg0]] : f64 // CHECK-NEXT: %[[i2:.+]] = arith.addf %[[i0]], %[[i1]] : f64 // CHECK-NEXT: %[[i3:.+]] = arith.mulf %[[arg0]], %[[arg0]] : f64 -// CHECK-NEXT: return %[[i2]] : f64 +// CHECK-NEXT: return %[[i3]], %[[i2]] : f64, f64 // CHECK-NEXT: } From eaa72381f5134f4f0f43cbc75fdbb2abfe4f9603 Mon Sep 17 00:00:00 2001 From: William Moses Date: Wed, 14 Feb 2024 20:36:16 -0500 Subject: [PATCH 054/106] MLIR change visibility for wrapper pass (#1719) --- enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp b/enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp index a5c4d45ced95..327c3d254911 100644 --- a/enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp +++ b/enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp @@ -92,6 +92,7 @@ struct DifferentiateWrapperPass /*augmented*/ nullptr); if (outfn == "") { fn->erase(); + SymbolTable::setSymbolVisibility(newFunc, SymbolTable::Visibility::Private); SymbolTable::setSymbolName(cast(newFunc), (std::string)infn); } else { From b74b7e9e8f67198c2b96724a1d042714ff6b2277 Mon Sep 17 00:00:00 2001 From: William Moses Date: Wed, 14 Feb 2024 20:45:28 -0500 Subject: [PATCH 055/106] MLIR embarassing bugfix (#1720) --- enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp b/enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp index 327c3d254911..e50306b4a32b 100644 --- a/enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp +++ b/enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp @@ -92,7 +92,8 @@ struct DifferentiateWrapperPass /*augmented*/ nullptr); if (outfn == "") { fn->erase(); - SymbolTable::setSymbolVisibility(newFunc, SymbolTable::Visibility::Private); + SymbolTable::setSymbolVisibility(newFunc, + SymbolTable::Visibility::Public); SymbolTable::setSymbolName(cast(newFunc), (std::string)infn); } else { From 742c8d035c7920f0c0ca6f9a72146b8f886a8228 Mon Sep 17 00:00:00 2001 From: "Ivan R. Ivanov" Date: Wed, 14 Feb 2024 22:26:54 -0800 Subject: [PATCH 056/106] Add full module truncation mode (#1717) * Add full module truncation mode * Fix optional header * Forgot test * Apply suggestions from code review Co-authored-by: Tim Gymnich * Change to command line option * Fix comp --------- Co-authored-by: Tim Gymnich --- enzyme/Enzyme/Enzyme.cpp | 86 ++++++++++++++++++- enzyme/Enzyme/EnzymeLogic.cpp | 28 +++--- enzyme/Enzyme/EnzymeLogic.h | 2 +- enzyme/test/Integration/Truncate/simple.cpp | 5 +- .../Integration/Truncate/truncate-all.cpp | 45 ++++++++++ 5 files changed, 147 insertions(+), 19 deletions(-) create mode 100644 enzyme/test/Integration/Truncate/truncate-all.cpp diff --git a/enzyme/Enzyme/Enzyme.cpp b/enzyme/Enzyme/Enzyme.cpp index 99c429efa08e..055b6f394842 100644 --- a/enzyme/Enzyme/Enzyme.cpp +++ b/enzyme/Enzyme/Enzyme.cpp @@ -37,10 +37,9 @@ #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/MapVector.h" +#include #if LLVM_VERSION_MAJOR <= 16 #include "llvm/ADT/Optional.h" -#else -#include #endif #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallSet.h" @@ -113,6 +112,12 @@ llvm::cl::opt EnzymeAttributor("enzyme-attributor", cl::init(false), llvm::cl::opt EnzymeOMPOpt("enzyme-omp-opt", cl::init(false), cl::Hidden, cl::desc("Whether to enable openmp opt")); +llvm::cl::opt EnzymeTruncateAll( + "enzyme-truncate-all", cl::init(""), cl::Hidden, + cl::desc( + "Truncate all floating point operations. " + "E.g. \"64to32\" or \"64to-\".")); + #if LLVM_VERSION_MAJOR >= 14 #define addAttribute addAttributeAtIndex #define getAttribute getAttributeAtIndex @@ -2044,6 +2049,76 @@ class EnzymeBase { return status; } + bool handleFullModuleTrunc(Function &F) { + typedef std::vector> + TruncationsTy; + static TruncationsTy FullModuleTruncs = []() -> TruncationsTy { + StringRef ConfigStr(EnzymeTruncateAll); + auto Invalid = [=]() { + // TODO emit better diagnostic + llvm::errs() << "error: invalid format for truncation config\n"; + abort(); + }; + + // "64" or "11-52" + auto parseFloatRepr = [&]() -> std::optional { + unsigned Tmp = 0; + if (ConfigStr.consumeInteger(10, Tmp)) + return {}; + if (ConfigStr.consume_front("-")) { + unsigned Tmp2 = 0; + if (ConfigStr.consumeInteger(10, Tmp2)) + Invalid(); + return FloatRepresentation(Tmp, Tmp2); + } + return getDefaultFloatRepr(Tmp); + }; + + // Parse "64to32;32to16;5-10to4-9" + TruncationsTy Tmp; + while (true) { + auto From = parseFloatRepr(); + if (!From && !ConfigStr.empty()) + Invalid(); + if (!From) + break; + if (!ConfigStr.consume_front("to")) + Invalid(); + auto To = parseFloatRepr(); + if (!To) + Invalid(); + Tmp.push_back({*From, *To}); + ConfigStr.consume_front(";"); + } + return Tmp; + }(); + + if (FullModuleTruncs.empty()) + return false; + + // TODO sort truncations (64to32, then 32to16 will make everything 16) + for (auto Truncation : FullModuleTruncs) { + IRBuilder<> Builder(F.getContext()); + RequestContext context(&*F.getEntryBlock().begin(), &Builder); + Function *TruncatedFunc = + Logic.CreateTruncateFunc(context, &F, Truncation.first, + Truncation.second, TruncOpFullModuleMode); + + ValueToValueMapTy Mapping; + for (auto &&[Arg, TArg] : llvm::zip(F.args(), TruncatedFunc->args())) + Mapping[&TArg] = &Arg; + + // Move the truncated body into the original function + F.deleteBody(); + F.getBasicBlockList().splice(F.begin(), + TruncatedFunc->getBasicBlockList()); + RemapFunction(F, Mapping, + RF_NoModuleLevelChanges | RF_IgnoreMissingLocals); + TruncatedFunc->deleteBody(); + } + return true; + } + bool lowerEnzymeCalls(Function &F, std::set &done) { if (done.count(&F)) return false; @@ -2052,6 +2127,9 @@ class EnzymeBase { if (F.empty()) return false; + if (handleFullModuleTrunc(F)) + return true; + bool Changed = false; for (BasicBlock &BB : F) @@ -2629,10 +2707,10 @@ class EnzymeBase { HandleBatch(call); } for (auto call : toTruncateFuncMem) { - HandleTruncateFunc(call, TruncMem); + HandleTruncateFunc(call, TruncMemMode); } for (auto call : toTruncateFuncOp) { - HandleTruncateFunc(call, TruncOp); + HandleTruncateFunc(call, TruncOpMode); } for (auto call : toTruncateValue) { HandleTruncateValue(call, true); diff --git a/enzyme/Enzyme/EnzymeLogic.cpp b/enzyme/Enzyme/EnzymeLogic.cpp index ebdbb937733d..731826bb793f 100644 --- a/enzyme/Enzyme/EnzymeLogic.cpp +++ b/enzyme/Enzyme/EnzymeLogic.cpp @@ -4969,7 +4969,7 @@ class TruncateGenerator : public llvm::InstVisitor { fromType = from.getBuiltinType(B.getContext()); toType = to.getType(B.getContext()); - if (mode == TruncMem) + if (mode == TruncMemMode) tmpBlock = B.CreateAlloca(fromType); else tmpBlock = nullptr; @@ -5002,9 +5002,10 @@ class TruncateGenerator : public llvm::InstVisitor { Value *truncate(IRBuilder<> &B, Value *v) { switch (mode) { - case TruncMem: + case TruncMemMode: return floatMemTruncate(B, v, tmpBlock, from, to); - case TruncOp: + case TruncOpMode: + case TruncOpFullModuleMode: return floatValTruncate(B, v, tmpBlock, from, to); default: llvm_unreachable("Unknown trunc mode"); @@ -5013,9 +5014,10 @@ class TruncateGenerator : public llvm::InstVisitor { Value *expand(IRBuilder<> &B, Value *v) { switch (mode) { - case TruncMem: + case TruncMemMode: return floatMemExpand(B, v, tmpBlock, from, to); - case TruncOp: + case TruncOpMode: + case TruncOpFullModuleMode: return floatValExpand(B, v, tmpBlock, from, to); default: llvm_unreachable("Unknown trunc mode"); @@ -5088,7 +5090,7 @@ class TruncateGenerator : public llvm::InstVisitor { } void visitSelectInst(llvm::SelectInst &SI) { switch (mode) { - case TruncMem: { + case TruncMemMode: { auto newI = getNewFromOriginal(&SI); IRBuilder<> B(newI); auto newT = truncate(B, getNewFromOriginal(SI.getTrueValue())); @@ -5101,7 +5103,8 @@ class TruncateGenerator : public llvm::InstVisitor { newI->eraseFromParent(); return; } - case TruncOp: + case TruncOpMode: + case TruncOpFullModuleMode: return; default: llvm_unreachable(""); @@ -5347,9 +5350,11 @@ class TruncateGenerator : public llvm::InstVisitor { if (handleKnownCalls(CI, called, getFuncNameFromCall(&CI), newCall)) return; - RequestContext ctx(&CI, &BuilderZ); - auto val = GetShadow(ctx, getNewFromOriginal(CI.getCalledOperand())); - newCall->setCalledOperand(val); + if (mode != TruncOpFullModuleMode) { + RequestContext ctx(&CI, &BuilderZ); + auto val = GetShadow(ctx, getNewFromOriginal(CI.getCalledOperand())); + newCall->setCalledOperand(val); + } return; } void visitFPTruncInst(FPTruncInst &I) { return; } @@ -5426,7 +5431,7 @@ llvm::Function *EnzymeLogic::CreateTruncateFunc(RequestContext context, FunctionType *FTy = FunctionType::get(NewTy, params, totrunc->isVarArg()); std::string truncName = std::string("__enzyme_done_truncate_") + - (mode == TruncMem ? "mem" : "op") + "_func_" + + (mode == TruncMemMode ? "mem" : "op") + "_func_" + from.to_string() + "_" + to.to_string() + "_" + totrunc->getName().str(); Function *NewF = Function::Create(FTy, totrunc->getLinkage(), truncName, @@ -5463,6 +5468,7 @@ llvm::Function *EnzymeLogic::CreateTruncateFunc(RequestContext context, llvm_unreachable("attempting to truncate function without definition"); } + // TODO This is overloaded an doesnt do what it should do here if (from < to) { std::string s; llvm::raw_string_ostream ss(s); diff --git a/enzyme/Enzyme/EnzymeLogic.h b/enzyme/Enzyme/EnzymeLogic.h index 3923f205037e..1e9bf216b6e9 100644 --- a/enzyme/Enzyme/EnzymeLogic.h +++ b/enzyme/Enzyme/EnzymeLogic.h @@ -286,7 +286,7 @@ getTypeForWidth(llvm::LLVMContext &ctx, unsigned width, bool builtinFloat) { } } -enum TruncateMode { TruncMem, TruncOp }; +enum TruncateMode { TruncMemMode, TruncOpMode, TruncOpFullModuleMode }; struct FloatRepresentation { // |_|__________|_________________| diff --git a/enzyme/test/Integration/Truncate/simple.cpp b/enzyme/test/Integration/Truncate/simple.cpp index 792366d687e0..53cf859c84da 100644 --- a/enzyme/test/Integration/Truncate/simple.cpp +++ b/enzyme/test/Integration/Truncate/simple.cpp @@ -15,14 +15,13 @@ double simple_add(double a, double b) { double intrinsics(double a, double b) { return sqrt(a) * pow(b, 2); } -// TODO +// TODO trunc mem mode double constt(double a, double b) { return 2; } double compute(double *A, double *B, double *C, int n) { for (int i = 0; i < n; i++) { - C[i] = A[i] * 2; - // C[i] A=[i] * 2 + B[i] * sqrt(A[i]) ; + C[i] = A[i] * 2 + B[i] * sqrt(A[i]); } return C[0]; } diff --git a/enzyme/test/Integration/Truncate/truncate-all.cpp b/enzyme/test/Integration/Truncate/truncate-all.cpp new file mode 100644 index 000000000000..ad5df438842f --- /dev/null +++ b/enzyme/test/Integration/Truncate/truncate-all.cpp @@ -0,0 +1,45 @@ +// Baseline +// RUN: if [ %llvmver -ge 12 ]; then [ "$(%clang -O3 %s -S -emit-llvm -o - %newLoadClangEnzyme -S -mllvm --enzyme-truncate-all="" | %lli -)" == "900000000.560000" ] ; fi + +// Truncated +// RUN: if [ %llvmver -ge 12 ]; then [ "$(%clang -O3 %s -S -emit-llvm -o - %newLoadClangEnzyme -S -mllvm --enzyme-truncate-all="64to32" | %lli -)" == "900000000.000000" ] ; fi +// RUN: if [ %llvmver -ge 12 ]; then [ "$(%clang -O3 %s -S -emit-llvm -o - %newLoadClangEnzyme -S -mllvm --enzyme-truncate-all="11-52to8-23" | %lli -)" == "900000000.000000" ] ; fi + +#include + +#include "../test_utils.h" + +#define N 10 + +#define floatty double + + +__attribute__((noinline)) +floatty simple_add(floatty a, floatty b) { + return a + b; +} +__attribute__((noinline)) +floatty intrinsics(floatty a, floatty b) { + return sqrt(a) * pow(b, 2); +} +__attribute__((noinline)) +floatty compute(floatty *A, floatty *B, floatty *C, int n) { + for (int i = 0; i < n; i++) { + C[i] = A[i] / 2 + intrinsics(A[i], simple_add(B[i] * 10000, 0.000001)); + } + return C[0]; +} + +int main() { + floatty A[N]; + floatty B[N]; + floatty C[N]; + + for (int i = 0; i < N; i++) { + A[i] = 1 + i % 5; + B[i] = 1 + i % 3; + } + + compute(A, B, C, N); + printf("%f\n", C[5]); +} From 1b0a46dd751f204e1bbef8cc0641e3a1ae27c74f Mon Sep 17 00:00:00 2001 From: William Moses Date: Thu, 15 Feb 2024 12:44:58 -0500 Subject: [PATCH 057/106] MLIR correctly preserve attributes for shadows (#1721) * MLIR correctly preserve attributes for shadows * fix --- enzyme/Enzyme/MLIR/Implementations/Common.td | 24 ++-- .../Enzyme/MLIR/Interfaces/CloneFunction.cpp | 105 +++++++++++++++++- enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp | 4 +- 3 files changed, 118 insertions(+), 15 deletions(-) diff --git a/enzyme/Enzyme/MLIR/Implementations/Common.td b/enzyme/Enzyme/MLIR/Implementations/Common.td index 1451d99f17ca..70f4c39f99d6 100644 --- a/enzyme/Enzyme/MLIR/Implementations/Common.td +++ b/enzyme/Enzyme/MLIR/Implementations/Common.td @@ -62,17 +62,17 @@ def Op { class ArithInst : Inst; class MathInst : Inst; -def AddF : ArithInst<"arith::AddFOp">; -def SubF : ArithInst<"arith::SubFOp">; -def NegF : ArithInst<"arith::NegFOp">; -def MulF : ArithInst<"arith::MulFOp">; -def DivF : ArithInst<"arith::DivFOp">; -def RemF : ArithInst<"arith::RemFOp">; +def AddF : ArithInst<"AddFOp">; +def SubF : ArithInst<"SubFOp">; +def NegF : ArithInst<"NegFOp">; +def MulF : ArithInst<"MulFOp">; +def DivF : ArithInst<"DivFOp">; +def RemF : ArithInst<"RemFOp">; -def CheckedMulF : ArithInst<"arith::MulFOp">; -def CheckedDivF : ArithInst<"arith::DivFOp">; +def CheckedMulF : ArithInst<"MulFOp">; +def CheckedDivF : ArithInst<"DivFOp">; -def CosF : MathInst<"math::CosOp">; -def SinF : MathInst<"math::SinOp">; -def ExpF : MathInst<"math::ExpOp">; -def SqrtF : MathInst<"math::SqrtOp">; +def CosF : MathInst<"CosOp">; +def SinF : MathInst<"SinOp">; +def ExpF : MathInst<"ExpOp">; +def SqrtF : MathInst<"SqrtOp">; diff --git a/enzyme/Enzyme/MLIR/Interfaces/CloneFunction.cpp b/enzyme/Enzyme/MLIR/Interfaces/CloneFunction.cpp index ba953dbca5df..7e802b6a952e 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/CloneFunction.cpp +++ b/enzyme/Enzyme/MLIR/Interfaces/CloneFunction.cpp @@ -1,3 +1,5 @@ +#include "llvm/ADT/APSInt.h" + #include "CloneFunction.h" using namespace mlir; @@ -203,7 +205,7 @@ FunctionOpInterface CloneFunctionWithReturns( SmallPtrSetImpl &constants, SmallPtrSetImpl &nonconstants, SmallPtrSetImpl &returnvals, ReturnType returnValue, - DIFFE_TYPE ReturnType, Twine name, IRMapping &VMap, + DIFFE_TYPE DReturnType, Twine name, IRMapping &VMap, std::map &OpMap, bool diffeReturnArg, mlir::Type additionalArg) { assert(!F.getFunctionBody().empty()); @@ -211,7 +213,7 @@ FunctionOpInterface CloneFunctionWithReturns( // llvm::ValueToValueMapTy VMap; auto FTy = getFunctionTypeForClone( F.getFunctionType().cast(), mode, width, - additionalArg, constant_args, diffeReturnArg, returnValue, ReturnType); + additionalArg, constant_args, diffeReturnArg, returnValue, DReturnType); /* for (Block &BB : F.getFunctionBody().getBlocks()) { @@ -267,5 +269,104 @@ FunctionOpInterface CloneFunctionWithReturns( } } + std::string ToClone[] = { + "bufferization.writable", + "mhlo.sharding", + "mhlo.layout_mode", + "xla_framework.input_mapping", + "xla_framework.result_mapping", + }; + size_t newxlacnt = 0; + { + size_t oldi = 0; + size_t newi = 0; + while (oldi < F.getNumResults()) { + bool primalReturn = returnValue == ReturnType::ArgsWithReturn || + returnValue == ReturnType::ArgsWithTwoReturns || + (returnValue == ReturnType::TapeAndReturn && + DReturnType == DIFFE_TYPE::CONSTANT) || + returnValue == ReturnType::TapeAndTwoReturns || + returnValue == ReturnType::TwoReturns || + (returnValue == ReturnType::Return && + DReturnType == DIFFE_TYPE::CONSTANT); + if (primalReturn) { + for (auto attrName : ToClone) { + auto attrNameS = StringAttr::get(F->getContext(), attrName); + NewF.removeResultAttr(newi, attrNameS); + if (auto attr = F.getResultAttr(oldi, attrName)) { + if (attrName == "xla_framework.result_mapping") { + auto iattr = cast(attr); + APSInt nc(iattr.getValue()); + nc = newxlacnt; + attr = IntegerAttr::get(F->getContext(), nc); + newxlacnt++; + } + NewF.setResultAttr(newi, attrNameS, attr); + } + } + newi++; + } + if (DReturnType == DIFFE_TYPE::DUP_ARG || + DReturnType == DIFFE_TYPE::DUP_NONEED) { + for (auto attrName : ToClone) { + auto attrNameS = StringAttr::get(F->getContext(), attrName); + NewF.removeResultAttr(newi, attrNameS); + if (auto attr = F.getResultAttr(oldi, attrName)) { + if (attrName == "xla_framework.result_mapping") { + auto iattr = cast(attr); + APSInt nc(iattr.getValue()); + nc = newxlacnt; + attr = IntegerAttr::get(F->getContext(), nc); + newxlacnt++; + } + NewF.setResultAttr(newi, attrNameS, attr); + } + } + newi++; + } + oldi++; + } + } + { + size_t oldi = 0; + size_t newi = 0; + while (oldi < F.getNumArguments()) { + for (auto attrName : ToClone) { + NewF.removeArgAttr(newi, attrName); + if (auto attr = F.getArgAttr(oldi, attrName)) { + if (attrName == "xla_framework.input_mapping") { + auto iattr = cast(attr); + APSInt nc(iattr.getValue()); + nc = newxlacnt; + attr = IntegerAttr::get(F->getContext(), nc); + newxlacnt++; + } + NewF.setArgAttr(newi, attrName, attr); + } + } + + newi++; + if (constant_args[oldi] == DIFFE_TYPE::DUP_ARG || + constant_args[oldi] == DIFFE_TYPE::DUP_NONEED) { + + for (auto attrName : ToClone) { + NewF.removeArgAttr(newi, attrName); + if (auto attr = F.getArgAttr(oldi, attrName)) { + if (attrName == "xla_framework.input_mapping") { + auto iattr = cast(attr); + APSInt nc(iattr.getValue()); + nc = newxlacnt; + attr = IntegerAttr::get(F->getContext(), nc); + newxlacnt++; + } + NewF.setArgAttr(newi, attrName, attr); + } + } + newi++; + } + oldi++; + } + } + return NewF; } diff --git a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp index 7eb015121a3e..b61cb6fa856a 100644 --- a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp +++ b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp @@ -855,7 +855,9 @@ bool handle(const Twine &curIndent, const Twine &argPattern, raw_ostream &os, } else if (opName == "CheckedDiv") { os << "checkedDiv(" << builder << ", "; } else if (intrinsic == MLIRDerivatives) { - os << builder << ".create<" << opName << ">(op.getLoc(), "; + auto dialect = Def->getValueAsString("dialect"); + os << builder << ".create<" << dialect << "::" << opName + << ">(op.getLoc(), "; } else { os << builder << ".Create" << opName << "("; } From f808be8d07a6d35f877a08dedd8eb4ef4caf87ce Mon Sep 17 00:00:00 2001 From: William Moses Date: Thu, 15 Feb 2024 20:51:26 -0500 Subject: [PATCH 058/106] MLIR: handle multi-result functions (#1723) --- .../Enzyme/MLIR/Interfaces/CloneFunction.cpp | 30 ++++---- enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.cpp | 68 +++++++++++-------- enzyme/test/MLIR/ForwardMode/multiout.mlir | 24 +++++++ 3 files changed, 79 insertions(+), 43 deletions(-) create mode 100644 enzyme/test/MLIR/ForwardMode/multiout.mlir diff --git a/enzyme/Enzyme/MLIR/Interfaces/CloneFunction.cpp b/enzyme/Enzyme/MLIR/Interfaces/CloneFunction.cpp index 7e802b6a952e..9b5c007c62ee 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/CloneFunction.cpp +++ b/enzyme/Enzyme/MLIR/Interfaces/CloneFunction.cpp @@ -19,22 +19,26 @@ mlir::FunctionType getFunctionTypeForClone( SmallVector RetTypes; if (returnValue == ReturnType::ArgsWithReturn || returnValue == ReturnType::Return) { - assert(FTy.getNumResults() == 1); - if (ReturnType != DIFFE_TYPE::CONSTANT && - ReturnType != DIFFE_TYPE::OUT_DIFF) { - RetTypes.push_back(getShadowType(FTy.getResult(0), width)); - } else { - RetTypes.push_back(FTy.getResult(0)); + assert(FTy.getNumResults() >= 1); + for (size_t i = 0; i < FTy.getNumResults(); i++) { + if (ReturnType != DIFFE_TYPE::CONSTANT && + ReturnType != DIFFE_TYPE::OUT_DIFF) { + RetTypes.push_back(getShadowType(FTy.getResult(i), width)); + } else { + RetTypes.push_back(FTy.getResult(i)); + } } } else if (returnValue == ReturnType::ArgsWithTwoReturns || returnValue == ReturnType::TwoReturns) { - assert(FTy.getNumResults() == 1); - RetTypes.push_back(FTy.getResult(0)); - if (ReturnType != DIFFE_TYPE::CONSTANT && - ReturnType != DIFFE_TYPE::OUT_DIFF) { - RetTypes.push_back(getShadowType(FTy.getResult(0), width)); - } else { - RetTypes.push_back(FTy.getResult(0)); + assert(FTy.getNumResults() >= 1); + for (size_t i = 0; i < FTy.getNumResults(); i++) { + RetTypes.push_back(FTy.getResult(i)); + if (ReturnType != DIFFE_TYPE::CONSTANT && + ReturnType != DIFFE_TYPE::OUT_DIFF) { + RetTypes.push_back(getShadowType(FTy.getResult(i), width)); + } else { + RetTypes.push_back(FTy.getResult(i)); + } } } diff --git a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.cpp b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.cpp index 1da6904893f8..064b6f181c95 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.cpp +++ b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.cpp @@ -46,44 +46,52 @@ void createTerminator(MGradientUtils *gutils, mlir::Block *oBB, switch (retVal) { case ReturnType::Return: { - auto ret = inst->getOperand(0); - - mlir::Value toret; - if (retType == DIFFE_TYPE::CONSTANT) { - toret = gutils->getNewFromOriginal(ret); - } else if (!isa(ret.getType()) && true /*type analysis*/) { - toret = gutils->invertPointerM(ret, nBuilder); - } else if (!gutils->isConstantValue(ret)) { - toret = gutils->invertPointerM(ret, nBuilder); - } else { - Type retTy = ret.getType().cast().getShadowType(); - toret = retTy.cast().createNullValue(nBuilder, - ret.getLoc()); + for (size_t i = 0; i < inst->getNumOperands(); i++) { + auto ret = inst->getOperand(i); + + mlir::Value toret; + if (retType == DIFFE_TYPE::CONSTANT) { + toret = gutils->getNewFromOriginal(ret); + } else if (!isa(ret.getType()) && + true /*type analysis*/) { + toret = gutils->invertPointerM(ret, nBuilder); + } else if (!gutils->isConstantValue(ret)) { + toret = gutils->invertPointerM(ret, nBuilder); + } else { + Type retTy = + ret.getType().cast().getShadowType(); + toret = retTy.cast().createNullValue( + nBuilder, ret.getLoc()); + } + retargs.push_back(toret); } - retargs.push_back(toret); break; } case ReturnType::TwoReturns: { if (retType == DIFFE_TYPE::CONSTANT) assert(false && "Invalid return type"); - auto ret = inst->getOperand(0); - - retargs.push_back(gutils->getNewFromOriginal(ret)); - - mlir::Value toret; - if (retType == DIFFE_TYPE::CONSTANT) { - toret = gutils->getNewFromOriginal(ret); - } else if (!isa(ret.getType()) && true /*type analysis*/) { - toret = gutils->invertPointerM(ret, nBuilder); - } else if (!gutils->isConstantValue(ret)) { - toret = gutils->invertPointerM(ret, nBuilder); - } else { - Type retTy = ret.getType().cast().getShadowType(); - toret = retTy.cast().createNullValue(nBuilder, - ret.getLoc()); + for (size_t i = 0; i < inst->getNumOperands(); i++) { + auto ret = inst->getOperand(i); + + retargs.push_back(gutils->getNewFromOriginal(ret)); + + mlir::Value toret; + if (retType == DIFFE_TYPE::CONSTANT) { + toret = gutils->getNewFromOriginal(ret); + } else if (!isa(ret.getType()) && + true /*type analysis*/) { + toret = gutils->invertPointerM(ret, nBuilder); + } else if (!gutils->isConstantValue(ret)) { + toret = gutils->invertPointerM(ret, nBuilder); + } else { + Type retTy = + ret.getType().cast().getShadowType(); + toret = retTy.cast().createNullValue( + nBuilder, ret.getLoc()); + } + retargs.push_back(toret); } - retargs.push_back(toret); break; } case ReturnType::Void: { diff --git a/enzyme/test/MLIR/ForwardMode/multiout.mlir b/enzyme/test/MLIR/ForwardMode/multiout.mlir new file mode 100644 index 000000000000..e42322a7f74f --- /dev/null +++ b/enzyme/test/MLIR/ForwardMode/multiout.mlir @@ -0,0 +1,24 @@ +// RUN: %eopt --enzyme %s | FileCheck %s + +module { + func.func @square(%x : f64) -> (f64, f64) { + %y = arith.mulf %x, %x : f64 + return %y, %x : f64, f64 + } + func.func @dsq(%x : f64, %dx : f64) -> (f64, f64) { + %r:2 = enzyme.fwddiff @square(%x, %dx) { activity=[#enzyme] } : (f64, f64) -> (f64, f64) + return %r#0, %r#1 : f64, f64 + } +} + +// CHECK: func.func @dsq(%[[arg0:.+]]: f64, %[[arg1:.+]]: f64) -> (f64, f64) { +// CHECK-NEXT: %[[i0:[0-9]+]]:2 = call @fwddiffesquare(%[[arg0]], %[[arg1]]) : (f64, f64) -> (f64, f64) +// CHECK-NEXT: return %[[i0]]#0, %[[i0]]#1 : f64, f64 +// CHECK-NEXT: } +// CHECK: func.func private @fwddiffesquare(%[[arg0:.+]]: f64, %[[arg1:.+]]: f64) -> (f64, f64) { +// CHECK-NEXT: %[[i0:.+]] = arith.mulf %[[arg1]], %[[arg0]] : f64 +// CHECK-NEXT: %[[i1:.+]] = arith.mulf %[[arg1]], %[[arg0]] : f64 +// CHECK-NEXT: %[[i2:.+]] = arith.addf %[[i0]], %[[i1]] : f64 +// CHECK-NEXT: %[[i3:.+]] = arith.mulf %[[arg0]], %[[arg0]] : f64 +// CHECK-NEXT: return %[[i2]], %[[arg1]] : f64, f64 +// CHECK-NEXT: } From f463384c7f3ae601db25acdb9213e03f7a4daaba Mon Sep 17 00:00:00 2001 From: William Moses Date: Thu, 15 Feb 2024 21:17:15 -0500 Subject: [PATCH 059/106] MLIR better error for incorrect arg activity count (#1724) --- enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.cpp | 3 +++ enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp | 8 +++++++- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.cpp b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.cpp index 064b6f181c95..edd21fbbf1ea 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.cpp +++ b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.cpp @@ -126,6 +126,9 @@ FunctionOpInterface mlir::enzyme::MEnzymeLogic::CreateForwardDiff( llvm::errs() << fn << "\n"; llvm_unreachable("Differentiating empty function"); } + assert(fn.getFunctionBody().front().getNumArguments() == constants.size()); + assert(fn.getFunctionBody().front().getNumArguments() == + volatile_args.size()); MForwardCacheKey tup = { fn, retType, constants, diff --git a/enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp b/enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp index e50306b4a32b..1b37e03383a8 100644 --- a/enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp +++ b/enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp @@ -71,6 +71,13 @@ struct DifferentiateWrapperPass } } + if (constants.size() != fn.getFunctionBody().front().getNumArguments()) { + fn->emitError() + << "Incorrect number of arg activity states for function, found " + << split; + return; + } + DIFFE_TYPE retType = retTy.getValue(); MTypeAnalysis TA; auto type_args = TA.getAnalyzedTypeInfo(fn); @@ -108,7 +115,6 @@ struct DifferentiateWrapperPass namespace mlir { namespace enzyme { std::unique_ptr createDifferentiateWrapperPass() { - new DifferentiateWrapperPass(); return std::make_unique(); } } // namespace enzyme From 2c753c97fcb41623e9aca972edfc08202b23e04f Mon Sep 17 00:00:00 2001 From: William Moses Date: Thu, 15 Feb 2024 22:56:23 -0500 Subject: [PATCH 060/106] MLIR fix activity analysis assertion error (#1725) --- enzyme/Enzyme/MLIR/Analysis/ActivityAnalysis.cpp | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/enzyme/Enzyme/MLIR/Analysis/ActivityAnalysis.cpp b/enzyme/Enzyme/MLIR/Analysis/ActivityAnalysis.cpp index 349213ac2fbf..06b3257cd3a8 100644 --- a/enzyme/Enzyme/MLIR/Analysis/ActivityAnalysis.cpp +++ b/enzyme/Enzyme/MLIR/Analysis/ActivityAnalysis.cpp @@ -2598,12 +2598,13 @@ bool mlir::enzyme::ActivityAnalyzer::isOperationInactiveFromOrigin( // return true; // } Value callVal = call.getCallableForCallee().dyn_cast(); - if (isConstantValue(TR, callVal)) { - // if (EnzymePrintActivity) - // llvm::errs() << "constant(" << (int)directions << ") up-constfn " - // << *inst << " - " << *callVal << "\n"; - return true; - } + if (callVal) + if (isConstantValue(TR, callVal)) { + // if (EnzymePrintActivity) + // llvm::errs() << "constant(" << (int)directions << ") up-constfn " + // << *inst << " - " << *callVal << "\n"; + return true; + } } if (auto gep = dyn_cast(op)) { From bf0261421bdbb78d61c0117aef4ca7dbd1f30575 Mon Sep 17 00:00:00 2001 From: William Moses Date: Thu, 15 Feb 2024 23:58:48 -0500 Subject: [PATCH 061/106] MLIR constantfp (#1726) --- enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp | 29 ++++++++++++++------ 1 file changed, 21 insertions(+), 8 deletions(-) diff --git a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp index b61cb6fa856a..fd2ab67ee725 100644 --- a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp +++ b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp @@ -449,17 +449,30 @@ bool handle(const Twine &curIndent, const Twine &argPattern, raw_ostream &os, PrintFatalError(pattern->getLoc(), Twine("'value' not defined in ") + resultTree->getAsString()); - os << "ConstantFP::get("; - if (resultRoot->getArgName(0)) { + if (intrinsic == MLIRDerivatives) { + os << builder << ".create<" + << cast(Def->getValueInit("dialect"))->getValue() + << "::" << cast(Def->getValueInit("opName"))->getValue() + << ">(op.getLoc(), "; auto name = resultRoot->getArgName(0)->getAsUnquotedString(); auto [ord, isVec] = nameToOrdinal.lookup(name, pattern, resultTree); assert(!isVec); - os << ord; - } else - PrintFatalError(pattern->getLoc(), - Twine("unknown named operand in constantfp") + - resultTree->getAsString()); - os << "->getType(), \"" << value->getValue() << "\")"; + os << ord << ".getType(), getTensorAttr(" << ord << ".getType(), "; + os << "\"" << value->getValue() << "\"))"; + } else { + + os << "ConstantFP::get("; + if (resultRoot->getArgName(0)) { + auto name = resultRoot->getArgName(0)->getAsUnquotedString(); + auto [ord, isVec] = nameToOrdinal.lookup(name, pattern, resultTree); + assert(!isVec); + os << ord; + } else + PrintFatalError(pattern->getLoc(), + Twine("unknown named operand in constantfp") + + resultTree->getAsString()); + os << "->getType(), \"" << value->getValue() << "\")"; + } return false; } else if (opName == "Zero" || Def->isSubClassOf("Zero")) { if (resultRoot->getNumArgs() != 1) From 4ccab29dc691cb43d250a7c5ca612c3ff9cd23e3 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Fri, 16 Feb 2024 01:04:31 -0500 Subject: [PATCH 062/106] MLIR improve error handling --- enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.cpp | 7 +++++- .../Enzyme/MLIR/Interfaces/GradientUtils.cpp | 3 ++- enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp | 24 +++++++++++++++---- 3 files changed, 27 insertions(+), 7 deletions(-) diff --git a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.cpp b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.cpp index edd21fbbf1ea..3025b5963fcf 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.cpp +++ b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.cpp @@ -183,6 +183,7 @@ FunctionOpInterface mlir::enzyme::MEnzymeLogic::CreateForwardDiff( unnecessaryInstructions, gutils, TLI); */ + bool valid = true; for (Block &oBB : gutils->oldFunc.getFunctionBody().getBlocks()) { // Don't create derivatives for code that results in termination if (guaranteedUnreachable.find(&oBB) != guaranteedUnreachable.end()) { @@ -205,7 +206,8 @@ FunctionOpInterface mlir::enzyme::MEnzymeLogic::CreateForwardDiff( auto last = oBB.empty() ? oBB.end() : std::prev(oBB.end()); for (auto it = first; it != last; ++it) { // TODO: propagate errors. - (void)gutils->visitChild(&*it); + auto res = gutils->visitChild(&*it); + valid &= res.succeeded(); } createTerminator(gutils, &oBB, retType, returnValue); @@ -232,6 +234,9 @@ FunctionOpInterface mlir::enzyme::MEnzymeLogic::CreateForwardDiff( auto nf = gutils->newFunc; delete gutils; + if (!valid) + return nullptr; + // if (PostOpt) // PPC.optimizeIntermediate(nf); // if (EnzymePrint) { diff --git a/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp b/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp index 467a8f59ec6e..45473c75921f 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp +++ b/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp @@ -318,5 +318,6 @@ LogicalResult MGradientUtils::visitChild(Operation *op) { return iface.createForwardModeTangent(builder, this); } } - return op->emitError() << "could not compute the adjoint for this operation"; + return op->emitError() << "could not compute the adjoint for this operation " + << *op; } diff --git a/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp b/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp index efe4b7d53f76..1e2c55640280 100644 --- a/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp +++ b/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp @@ -35,7 +35,7 @@ struct DifferentiatePass : public DifferentiatePassBase { void runOnOperation() override; template - void HandleAutoDiff(SymbolTableCollection &symbolTable, T CI) { + LogicalResult HandleAutoDiff(SymbolTableCollection &symbolTable, T CI) { std::vector constants; SmallVector args; @@ -83,16 +83,20 @@ struct DifferentiatePass : public DifferentiatePassBase { /*should return*/ false, mode, freeMemory, width, /*addedType*/ nullptr, type_args, volatile_args, /*augmented*/ nullptr); + if (!newFunc) + return failure(); OpBuilder builder(CI); auto dCI = builder.create(CI.getLoc(), newFunc.getName(), newFunc.getResultTypes(), args); CI.replaceAllUsesWith(dCI); CI->erase(); + return success(); } template - void HandleAutoDiffReverse(SymbolTableCollection &symbolTable, T CI) { + LogicalResult HandleAutoDiffReverse(SymbolTableCollection &symbolTable, + T CI) { std::vector constants; SmallVector args; @@ -144,12 +148,15 @@ struct DifferentiatePass : public DifferentiatePassBase { /*should return*/ false, mode, freeMemory, width, /*addedType*/ nullptr, type_args, volatile_args, /*augmented*/ nullptr, symbolTable); + if (!newFunc) + return failure(); OpBuilder builder(CI); auto dCI = builder.create(CI.getLoc(), newFunc.getName(), newFunc.getResultTypes(), args); CI.replaceAllUsesWith(dCI); CI->erase(); + return success(); } void lowerEnzymeCalls(SymbolTableCollection &symbolTable, @@ -167,7 +174,11 @@ struct DifferentiatePass : public DifferentiatePassBase { for (auto T : toLower) { if (auto F = dyn_cast(T)) { - HandleAutoDiff(symbolTable, F); + auto res = HandleAutoDiff(symbolTable, F); + if (!res.succeeded()) { + signalPassFailure(); + return; + } } else { llvm_unreachable("Illegal type"); } @@ -187,7 +198,11 @@ struct DifferentiatePass : public DifferentiatePassBase { for (auto T : toLower) { if (auto F = dyn_cast(T)) { - HandleAutoDiffReverse(symbolTable, F); + auto res = HandleAutoDiffReverse(symbolTable, F); + if (!res.succeeded()) { + signalPassFailure(); + return; + } } else { llvm_unreachable("Illegal type"); } @@ -201,7 +216,6 @@ struct DifferentiatePass : public DifferentiatePassBase { namespace mlir { namespace enzyme { std::unique_ptr createDifferentiatePass() { - new DifferentiatePass(); return std::make_unique(); } } // namespace enzyme From dbfa740fa8c170f828a8396267eafff169779b42 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Fri, 16 Feb 2024 11:28:38 -0500 Subject: [PATCH 063/106] MLIR improve tblgen --- enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp | 4 ++++ enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp | 18 ++++++++++++++---- 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp b/enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp index 1b37e03383a8..b8d91f2c82e8 100644 --- a/enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp +++ b/enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp @@ -97,6 +97,10 @@ struct DifferentiateWrapperPass width, /*addedType*/ nullptr, type_args, volatile_args, /*augmented*/ nullptr); + if (!newFunc) { + signalPassFailure(); + return; + } if (outfn == "") { fn->erase(); SymbolTable::setSymbolVisibility(newFunc, diff --git a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp index fd2ab67ee725..1e1ac3697095 100644 --- a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp +++ b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp @@ -454,9 +454,15 @@ bool handle(const Twine &curIndent, const Twine &argPattern, raw_ostream &os, << cast(Def->getValueInit("dialect"))->getValue() << "::" << cast(Def->getValueInit("opName"))->getValue() << ">(op.getLoc(), "; - auto name = resultRoot->getArgName(0)->getAsUnquotedString(); - auto [ord, isVec] = nameToOrdinal.lookup(name, pattern, resultTree); - assert(!isVec); + std::string ord; + if (resultRoot->getNumArgs() == 0) { + ord = "op->getResult(0)"; + } else { + auto name = resultRoot->getArgName(0)->getAsUnquotedString(); + auto [ord1, isVec] = nameToOrdinal.lookup(name, pattern, resultTree); + assert(!isVec); + ord = ord1; + } os << ord << ".getType(), getTensorAttr(" << ord << ".getType(), "; os << "\"" << value->getValue() << "\"))"; } else { @@ -1636,7 +1642,11 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os, } } else { - os << " Value *res = "; + if (intrinsic == MLIRDerivatives) { + os << " mlir::Value res = nullptr;\n"; + } else { + os << " Value *res = "; + } ArrayRef retidx{}; bool vectorValued = handle(" ", "fwdnsrarg", os, pattern, duals, "Builder2", From d158f8350741d7a07e26512ecb319c6986e21e1e Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Fri, 16 Feb 2024 19:52:54 +0000 Subject: [PATCH 064/106] [bazel] export common derivative definitions --- enzyme/BUILD | 7 +++++++ enzyme/Enzyme/MLIR/Implementations/Common.td | 5 +++++ 2 files changed, 12 insertions(+) diff --git a/enzyme/BUILD b/enzyme/BUILD index 9bf6fadabe9a..7c78c75a4051 100644 --- a/enzyme/BUILD +++ b/enzyme/BUILD @@ -389,6 +389,13 @@ gentbl_cc_library( deps = [":EnzymeDialectTdFiles"], ) +td_library( + name = "ImplementationsCommonTdFiles", + srcs = [ + "Enzyme/MLIR/Implementations/Common.td", + ], +) + gentbl( name = "affine-derivatives", tbl_outs = [( diff --git a/enzyme/Enzyme/MLIR/Implementations/Common.td b/enzyme/Enzyme/MLIR/Implementations/Common.td index 70f4c39f99d6..f558d1225a76 100644 --- a/enzyme/Enzyme/MLIR/Implementations/Common.td +++ b/enzyme/Enzyme/MLIR/Implementations/Common.td @@ -1,3 +1,6 @@ +#ifndef ENZYME_MLIR_IMPLEMENTATIONS_COMMON +#define ENZYME_MLIR_IMPLEMENTATIONS_COMMON + class InactiveOp { string dialect = dialect_; string opName = opName_; @@ -76,3 +79,5 @@ def CosF : MathInst<"CosOp">; def SinF : MathInst<"SinOp">; def ExpF : MathInst<"ExpOp">; def SqrtF : MathInst<"SqrtOp">; + +#endif // ENZYME_MLIR_IMPLEMENTATIONS_COMMON From 2eb83870cbd8b26fc969ea99e691609becb1fb91 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Fri, 16 Feb 2024 14:59:15 -0500 Subject: [PATCH 065/106] MLIR more error messages for interface internals --- .../CoreDialectsAutoDiffImplementations.cpp | 25 +++++++++++++------ 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp index 3bf48fc16b81..c18a7bd0e921 100644 --- a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp @@ -136,7 +136,7 @@ LogicalResult mlir::enzyme::detail::memoryIdentityForwardHandler( } } } - orig->emitWarning() + orig->emitError() << "Unsupported constant arg to memory identity forward " "handler(opidx=" << operand.getOperandNumber() << ", op=" << operand.get() << ")\n"; @@ -175,8 +175,9 @@ LogicalResult mlir::enzyme::detail::allocationForwardHandler( if (auto iface = dyn_cast(shadowRes.getType())) { return iface.zeroInPlace(builder, orig->getLoc(), shadowRes); } else { - orig->emitWarning() << "memref.alloc element type does not implement " - "AutoDiffTypeInterface"; + orig->emitError() << "Type " << shadowRes.getType() + << " does not implement " + "AutoDiffTypeInterface"; return failure(); } } @@ -235,16 +236,22 @@ LogicalResult mlir::enzyme::detail::controlFlowForwardHandler( if (gutils->isConstantValue(result)) continue; auto typeIface = dyn_cast(result.getType()); - if (!typeIface) + if (!typeIface) { + op->emitError() << " AutoDiffTypeInterface not implemented for " + << result.getType() << "\n"; return failure(); + } newOpResultTypes.push_back(typeIface.getShadowType()); } // For all operands that are forwarded to the body, if they are active, also // add the shadow as operand. auto regionBranchOp = dyn_cast(op); - if (!regionBranchOp) + if (!regionBranchOp) { + op->emitError() << " RegionBranchOpInterface not implemented for " << *op + << "\n"; return failure(); + } SmallVector successors; // TODO: we may need to record, for every successor, which of its inputs @@ -283,8 +290,11 @@ LogicalResult mlir::enzyme::detail::controlFlowForwardHandler( // yielded by terminators, and only those values. auto iface = dyn_cast(op); - if (!iface) + if (!iface) { + op->emitError() << " ControlFlowAutoDiffOpInterface not implemented for " + << *op << "\n"; return failure(); + } Operation *replacement = iface.createWithShadows( builder, gutils, op, newOperands, newOpResultTypes); for (auto &&[region, replacementRegion] : @@ -314,8 +324,9 @@ LogicalResult mlir::enzyme::detail::controlFlowForwardHandler( for (auto &origRegion : op->getRegions()) { for (auto &origBlock : origRegion) { for (Operation &o : origBlock) { - if (failed(gutils->visitChild(&o))) + if (failed(gutils->visitChild(&o))) { return failure(); + } } } } From 26ef37ea96d4f94a24409d5e59329f1255c0ac95 Mon Sep 17 00:00:00 2001 From: Romain Moyard Date: Fri, 16 Feb 2024 17:57:24 -0500 Subject: [PATCH 066/106] Add smax, smin, umin, umax to type analysis (#1634) * Add s/u/min/max * Add opaque pointer in test * Update test and llvm version * Add s/u/min/max * Add opaque pointer in test * Update test * Pointer type case * Update test * Update test * Lost line * Break missing * Update logic * Update function * Update * Update test * Update test * gt llvm 11 * Add version check in test * Update test --- enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp | 29 ++++++++++++++++++++- enzyme/test/TypeAnalysis/smax.ll | 26 ++++++++++++++++++ enzyme/test/TypeAnalysis/smax0.ll | 23 ++++++++++++++++ 3 files changed, 77 insertions(+), 1 deletion(-) create mode 100644 enzyme/test/TypeAnalysis/smax.ll create mode 100644 enzyme/test/TypeAnalysis/smax0.ll diff --git a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp index 851db780723d..04d1822d4e23 100644 --- a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp +++ b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp @@ -3892,7 +3892,34 @@ void TypeAnalyzer::visitIntrinsicInst(llvm::IntrinsicInst &I) { .Only(-1, &I), &I); return; - +#if LLVM_VERSION_MAJOR >= 12 + case Intrinsic::smax: + case Intrinsic::smin: + case Intrinsic::umax: + case Intrinsic::umin: + if (direction & UP) { + auto returnType = getAnalysis(&I)[{-1}]; + if (returnType == BaseType::Integer || returnType == BaseType::Pointer) { + updateAnalysis(I.getOperand(0), TypeTree(returnType).Only(-1, &I), &I); + updateAnalysis(I.getOperand(1), TypeTree(returnType).Only(-1, &I), &I); + } + } + if (direction & DOWN) { + auto opType0 = getAnalysis(I.getOperand(0))[{-1}]; + auto opType1 = getAnalysis(I.getOperand(1))[{-1}]; + if (opType0 == opType1 && + (opType0 == BaseType::Integer || opType0 == BaseType::Pointer)) { + updateAnalysis(&I, TypeTree(opType0).Only(-1, &I), &I); + } else if (opType0 == BaseType::Integer && + opType1 == BaseType::Anything) { + updateAnalysis(&I, TypeTree(BaseType::Integer).Only(-1, &I), &I); + } else if (opType1 == BaseType::Integer && + opType0 == BaseType::Anything) { + updateAnalysis(&I, TypeTree(BaseType::Integer).Only(-1, &I), &I); + } + } + return; +#endif case Intrinsic::umul_with_overflow: case Intrinsic::smul_with_overflow: case Intrinsic::ssub_with_overflow: diff --git a/enzyme/test/TypeAnalysis/smax.ll b/enzyme/test/TypeAnalysis/smax.ll new file mode 100644 index 000000000000..b050ccc9ef18 --- /dev/null +++ b/enzyme/test/TypeAnalysis/smax.ll @@ -0,0 +1,26 @@ +; RUN: if [ %llvmver -lt 16 ] && [ %llvmver -gt 11 ]; then %opt < %s %loadEnzyme -print-type-analysis -type-analysis-func=smax -o /dev/null | FileCheck %s; fi +; RUN: if [ %llvmver -gt 11 ]; then %opt < %s %newLoadEnzyme -passes="print-type-analysis" -type-analysis-func=smax -S -o /dev/null | FileCheck %s; fi + +define i32 @smax(i32 %a, i32 %b) { +entry: + %0 = call i32 @llvm.smax.i32(i32 %a, i32 %b) + %1 = call i32 @getint() + %2 = call i32 @getint() + %3 = call i32 @llvm.smax.i32(i32 %1, i32 %2) + ret i32 %3 +} + +declare i32 @llvm.smax.i32(i32, i32) + +declare i32 @getint() + + +; CHECK: smax - {[-1]:Integer} |{[-1]:Integer}:{} {[-1]:Integer}:{} +; CHECK-NEXT: i32 %a: {[-1]:Integer} +; CHECK-NEXT: i32 %b: {[-1]:Integer} +; CHECK-NEXT: entry +; CHECK-NEXT: %0 = call i32 @llvm.smax.i32(i32 %a, i32 %b): {[-1]:Integer} +; CHECK-NEXT: %1 = call i32 @getint(): {[-1]:Integer} +; CHECK-NEXT: %2 = call i32 @getint(): {[-1]:Integer} +; CHECK-NEXT: %3 = call i32 @llvm.smax.i32(i32 %1, i32 %2): {[-1]:Integer} +; CHECK-NEXT: ret i32 %3: {} diff --git a/enzyme/test/TypeAnalysis/smax0.ll b/enzyme/test/TypeAnalysis/smax0.ll new file mode 100644 index 000000000000..de79bcde70bb --- /dev/null +++ b/enzyme/test/TypeAnalysis/smax0.ll @@ -0,0 +1,23 @@ +; RUN: if [ %llvmver -lt 16 ] && [ %llvmver -gt 11 ]; then %opt < %s %loadEnzyme -print-type-analysis -type-analysis-func=smax0 -o /dev/null | FileCheck %s; fi +; RUN: if [ %llvmver -gt 11 ]; then %opt < %s %newLoadEnzyme -passes="print-type-analysis" -type-analysis-func=smax0 -S -o /dev/null | FileCheck %s; fi + +define i32 @smax0(i32 %a, i32 %b) { +entry: + %0 = call i32 @llvm.smax.i32(i32 %a, i32 0) + %1 = call i32 @getint() + %2 = call i32 @llvm.smax.i32(i32 %1, i32 0) + ret i32 %2 +} + +declare i32 @llvm.smax.i32(i32, i32) + +declare i32 @getint() + +; CHECK: smax0 - {[-1]:Integer} |{[-1]:Integer}:{} {[-1]:Integer}:{} +; CHECK-NEXT: i32 %a: {[-1]:Integer} +; CHECK-NEXT: i32 %b: {[-1]:Integer} +; CHECK-NEXT: entry +; CHECK-NEXT: %0 = call i32 @llvm.smax.i32(i32 %a, i32 0): {[-1]:Integer} +; CHECK-NEXT: %1 = call i32 @getint(): {[-1]:Integer} +; CHECK-NEXT: %2 = call i32 @llvm.smax.i32(i32 %1, i32 0): {[-1]:Integer} +; CHECK-NEXT: ret i32 %2: {} From 12030cd074a6a167807de8eff733ee9b67e2a42f Mon Sep 17 00:00:00 2001 From: William Moses Date: Fri, 16 Feb 2024 18:43:48 -0500 Subject: [PATCH 067/106] Update opaque pointer test (#1732) --- enzyme/test/Enzyme/ReverseMode/mlirmincut.ll | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/enzyme/test/Enzyme/ReverseMode/mlirmincut.ll b/enzyme/test/Enzyme/ReverseMode/mlirmincut.ll index 6f214ddb2cdc..dc0e9599fe22 100644 --- a/enzyme/test/Enzyme/ReverseMode/mlirmincut.ll +++ b/enzyme/test/Enzyme/ReverseMode/mlirmincut.ll @@ -1,4 +1,4 @@ -; RUN: if [ %llvmver -eq 15 ]; then %opt < %s %newLoadEnzyme -enzyme-preopt=false -passes="enzyme,function(mem2reg,instsimplify,%simplifycfg)" -S | FileCheck %s; fi +; RUN: if [ %llvmver -ge 15 ]; then %opt < %s %OPnewLoadEnzyme -enzyme-preopt=false -passes="enzyme,function(mem2reg,instsimplify,%simplifycfg)" -S | FileCheck %s; fi declare void @__enzyme_autodiff0(...) local_unnamed_addr From 616a88cfa03c36f61afa325f03582935b271d307 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sat, 17 Feb 2024 21:32:51 -0500 Subject: [PATCH 068/106] MLIR add more general control flow handler interface (#1731) * MLIR add more general control flow handler interface * fixup act * more activ improvements --- .../Enzyme/MLIR/Analysis/ActivityAnalysis.cpp | 1029 +++++++++-------- .../Enzyme/MLIR/Analysis/ActivityAnalysis.h | 10 +- .../BuiltinAutoDiffTypeInterfaceImpl.cpp | 8 +- .../CoreDialectsAutoDiffImplementations.cpp | 80 +- .../CoreDialectsAutoDiffImplementations.h | 9 + .../MLIR/Interfaces/EnzymeLogicReverse.cpp | 6 + .../Enzyme/MLIR/Interfaces/GradientUtils.cpp | 11 +- enzyme/Enzyme/MLIR/Passes/Passes.td | 7 + .../MLIR/Passes/PrintActivityAnalysis.cpp | 85 +- enzyme/test/MLIR/ActivityAnalysis/region.mlir | 27 + 10 files changed, 720 insertions(+), 552 deletions(-) create mode 100644 enzyme/test/MLIR/ActivityAnalysis/region.mlir diff --git a/enzyme/Enzyme/MLIR/Analysis/ActivityAnalysis.cpp b/enzyme/Enzyme/MLIR/Analysis/ActivityAnalysis.cpp index 06b3257cd3a8..30fbfc8ef345 100644 --- a/enzyme/Enzyme/MLIR/Analysis/ActivityAnalysis.cpp +++ b/enzyme/Enzyme/MLIR/Analysis/ActivityAnalysis.cpp @@ -248,6 +248,8 @@ static Operation *getFunctionFromCall(CallOpInterface iface) { return SymbolTable::lookupNearestSymbolFrom(iface.getOperation(), symbol); } +constexpr bool EnzymePrintActivity = true; + /// Is the use of value val as an argument of call CI known to be inactive /// This tool can only be used when in DOWN mode bool mlir::enzyme::ActivityAnalyzer::isFunctionArgumentConstant( @@ -465,7 +467,7 @@ bool mlir::enzyme::ActivityAnalyzer::isConstantOperation(MTypeResults const &TR, // The return instruction doesn't impact activity (handled specifically // during adjoint generation) - if (isa(I)) + if (I->hasTrait()) return true; if (auto ifaceOp = dyn_cast(I)) { @@ -485,9 +487,9 @@ bool mlir::enzyme::ActivityAnalyzer::isConstantOperation(MTypeResults const &TR, } if (notForAnalysis.count(I->getBlock())) { - // if (EnzymePrintActivity) - // llvm::errs() << " constant instruction as dominates unreachable " << *I - // << "\n"; + if (EnzymePrintActivity) + llvm::errs() << " constant instruction as dominates unreachable " << *I + << "\n"; InsertConstantOperation(TR, I); return true; } @@ -495,14 +497,14 @@ bool mlir::enzyme::ActivityAnalyzer::isConstantOperation(MTypeResults const &TR, if (auto CI = dyn_cast(I)) { // TODO(PR #904): This needs to be put into the enzyme dialect if (CI->hasAttr("enzyme_active")) { - // if (EnzymePrintActivity) - // llvm::errs() << "forced active " << *I << "\n"; + if (EnzymePrintActivity) + llvm::errs() << "forced active " << *I << "\n"; ActiveOperations.insert(I); return false; } if (CI->hasAttr("enzyme_inactive")) { - // if (EnzymePrintActivity) - // llvm::errs() << "forced inactive " << *I << "\n"; + if (EnzymePrintActivity) + llvm::errs() << "forced inactive " << *I << "\n"; InsertConstantOperation(TR, I); return true; } @@ -510,14 +512,14 @@ bool mlir::enzyme::ActivityAnalyzer::isConstantOperation(MTypeResults const &TR, if (called) { if (called->hasAttr("enzyme_active")) { - // if (EnzymePrintActivity) - // llvm::errs() << "forced active " << *I << "\n"; + if (EnzymePrintActivity) + llvm::errs() << "forced active " << *I << "\n"; ActiveOperations.insert(I); return false; } if (called->hasAttr("enzyme_inactive")) { - // if (EnzymePrintActivity) - // llvm::errs() << "forced inactive " << *I << "\n"; + if (EnzymePrintActivity) + llvm::errs() << "forced inactive " << *I << "\n"; InsertConstantOperation(TR, I); return true; } @@ -683,11 +685,10 @@ bool mlir::enzyme::ActivityAnalyzer::isConstantOperation(MTypeResults const &TR, // If all returned values constant otherwise, the operation is inactive if (llvm::all_of(I->getResults(), [&](Value v) { return isConstantValue(TR, v); })) { - // if (EnzymePrintActivity) - // llvm::errs() << " constant instruction from known constant - // non-writing " - // "instruction " - // << *I << "\n"; + if (EnzymePrintActivity) + llvm::errs() << " constant instruction from known constant non-writing " + "instruction " + << *I << "\n"; InsertConstantOperation(TR, I); return true; } @@ -710,9 +711,9 @@ bool mlir::enzyme::ActivityAnalyzer::isConstantOperation(MTypeResults const &TR, if (llvm::all_of(I->getResults(), [&](Value val) { return isValueInactiveFromUsers(TR, val, UseActivity::None); })) { - // if (EnzymePrintActivity) - // llvm::errs() << " constant instruction[" << (int)directions - // << "] from users instruction " << *I << "\n"; + if (EnzymePrintActivity) + llvm::errs() << " constant instruction[" << (int)directions + << "] from users instruction " << *I << "\n"; InsertConstantOperation(TR, I); return true; } @@ -724,9 +725,9 @@ bool mlir::enzyme::ActivityAnalyzer::isConstantOperation(MTypeResults const &TR, return DownHypothesis->isValueInactiveFromUsers( TR, val, UseActivity::None); })) { - // if (EnzymePrintActivity) - // llvm::errs() << " constant instruction[" << (int)directions - // << "] from users instruction " << *I << "\n"; + if (EnzymePrintActivity) + llvm::errs() << " constant instruction[" << (int)directions + << "] from users instruction " << *I << "\n"; InsertConstantOperation(TR, I); insertConstantsFrom(TR, *DownHypothesis); return true; @@ -748,65 +749,42 @@ bool mlir::enzyme::ActivityAnalyzer::isConstantOperation(MTypeResults const &TR, new mlir::enzyme::ActivityAnalyzer(*this, UP)); UpHypothesis->ConstantOperations.insert(I); assert(directions & UP); - if (UpHypothesis->isOperationInactiveFromOrigin(TR, I)) { - // if (EnzymePrintActivity) - // llvm::errs() << " constant instruction from origin " - // "instruction " - // << *I << "\n"; + SmallPtrSet toredo; + if (UpHypothesis->isOperationInactiveFromOrigin(TR, I, std::nullopt, + &toredo)) { + if (EnzymePrintActivity) + llvm::errs() << " constant instruction from origin " + "instruction " + << *I << "\n"; InsertConstantOperation(TR, I); insertConstantsFrom(TR, *UpHypothesis); if (DownHypothesis) insertConstantsFrom(TR, *DownHypothesis); return true; } else if (directions == (UP | DOWN)) { - // TODO: what does this mean for interfaces? - if (isa< - // clang-format off - LLVM::LoadOp, - LLVM::StoreOp, - // Integer binary ops. - LLVM::AddOp, - LLVM::SubOp, - LLVM::MulOp, - LLVM::UDivOp, - LLVM::SDivOp, - LLVM::URemOp, - LLVM::SRemOp, - LLVM::AndOp, - LLVM::OrOp, - LLVM::XOrOp, - LLVM::ShlOp, - LLVM::LShrOp, - LLVM::AShrOp, - // Float binary ops. - LLVM::FAddOp, - LLVM::FSubOp, - LLVM::FMulOp, - LLVM::FDivOp, - LLVM::FRemOp, - LLVM::FNegOp - // clang-format on - >(I)) { - for (Value operand : I->getOperands()) { - if (!UpHypothesis->isConstantValue(TR, operand)) { - ReEvaluateOpIfInactiveValue[operand].insert(I); - } - } + for (Value operand : toredo) { + ReEvaluateOpIfInactiveValue[operand].insert(I); } } } // Otherwise we must fall back and assume this instruction to be active. ActiveOperations.insert(I); - // if (EnzymePrintActivity) - // llvm::errs() << "couldnt decide fallback as nonconstant instruction(" - // << (int)directions << "):" << *I << "\n"; + if (EnzymePrintActivity) + llvm::errs() << "couldnt decide fallback as nonconstant instruction(" + << (int)directions << "):" << *I << "\n"; if (noActiveWrite && (directions == (UP | DOWN))) for (Value result : I->getResults()) ReEvaluateOpIfInactiveValue[result].insert(I); return false; } +static bool isFunctionReturn(Operation *op) { + if (!op->hasTrait()) + return false; + return dyn_cast(op->getParentOp()); +} + static bool isValuePotentiallyUsedAsPointer(Value val) { std::deque todo = {val}; SmallPtrSet seen; @@ -817,7 +795,39 @@ static bool isValuePotentiallyUsedAsPointer(Value val) { continue; seen.insert(cur); for (Operation *user : cur.getUsers()) { - if (isa(user)) + if (isa(user->getParentOp())) + if (auto termIface = + dyn_cast(user)) { + SmallVector successors; + termIface.getSuccessorRegions( + SmallVector(termIface->getNumOperands(), Attribute()), + successors); + + auto parentOp = termIface->getParentOp(); + for (auto &successor : successors) { + OperandRange operandRange = + termIface.getSuccessorOperands(successor); + ValueRange targetValues = successor.isParent() + ? parentOp->getResults() + : successor.getSuccessorInputs(); + assert(operandRange.size() == targetValues.size()); + for (auto &&[prev, post] : llvm::zip(operandRange, targetValues)) { + if (prev == cur) { + todo.push_back(post); + } + } + } + continue; + } + if (auto iface = dyn_cast(user)) { + for (auto &op : user->getOpOperands()) + if (op.get() == cur) + if (auto blk = + iface.getSuccessorBlockArgument(op.getOperandNumber())) + todo.push_back(*blk); + continue; + } + if (isFunctionReturn(user)) return true; // The operation is known not to read or write memory. if (isa(user) && @@ -828,10 +838,9 @@ static bool isValuePotentiallyUsedAsPointer(Value val) { } continue; } - // if (EnzymePrintActivity) - // llvm::errs() << " VALUE potentially used as pointer " << *val << " by - // " - // << *u << "\n"; + if (EnzymePrintActivity) + llvm::errs() << " VALUE potentially used as pointer " << val << " by " + << *user << "\n"; return true; } } @@ -889,7 +898,134 @@ static FunctionOpInterface getFunctionIfArgument(Value value) { return dyn_cast(block->getParentOp()); } -// TODO: move the extraction based on dataflow here. +// For a given instruction, determine whether it is a terminator which +// controls dataflow out, and if so return all users either in results +// or blockarguments +static std::optional> +getPotentialTerminatorUsers(Operation *op, Value parent) { + auto block = op->getBlock(); + + if (block->getTerminator() != op) + return {}; + if (isFunctionReturn(op)) + return {}; + + SmallVector results; + + if (isa(op->getParentOp())) + if (auto termIface = dyn_cast(op)) { + SmallVector successors; + termIface.getSuccessorRegions( + SmallVector(termIface->getNumOperands(), Attribute()), + successors); + + auto parentOp = termIface->getParentOp(); + SmallVector results; + for (auto &successor : successors) { + OperandRange operandRange = termIface.getSuccessorOperands(successor); + ValueRange targetValues = successor.isParent() + ? parentOp->getResults() + : successor.getSuccessorInputs(); + assert(operandRange.size() == targetValues.size()); + for (auto &&[prev, post] : llvm::zip(operandRange, targetValues)) { + if (prev == parent) { + results.push_back(post); + } + } + } + return std::move(results); + } + if (auto iface = dyn_cast(op)) { + for (auto &operand : op->getOpOperands()) + if (operand.get() == parent) + if (auto blk = + iface.getSuccessorBlockArgument(operand.getOperandNumber())) { + results.push_back(*blk); + return std::move(results); + } + } + + // assume all terminator operands potentially flow into all op results + for (auto res : op->getParentOp()->getResults()) + results.push_back(res); + + // assume all terminator operands potentially flow into all blockArgs in + // region + for (auto &blk : *block->getParent()) + for (auto arg : blk.getArguments()) + results.push_back(arg); + + // assume all terminator operands potentially flow into all other region + // entries + for (auto ® : op->getParentOp()->getRegions()) + for (auto arg : reg.front().getArguments()) + results.push_back(arg); + + return std::move(results); +} + +// For a result of an op, find all values which could flow into this result +static SmallVector getPotentialIncomingValues(OpResult res) { + Operation *owner = res.getOwner(); + SmallVector potentialSources; + + auto resultNo = res.getResultNumber(); + + if (auto iface = dyn_cast(owner)) { + SmallVector successors; + iface.getSuccessorRegions(RegionBranchPoint::parent(), successors); + for (auto &succ : successors) { + if (!succ.isParent()) + continue; + auto successorOperands = + llvm::to_vector(iface.getEntrySuccessorOperands(succ)); + + if (successorOperands.size() != owner->getNumResults()) { + llvm::errs() << *owner << "\n"; + } + assert(successorOperands.size() == owner->getNumResults() && + "expected all results to be populated with incoming operands"); + + potentialSources.push_back(successorOperands[resultNo]); + } + } else { + // assume all inputs potentially flow into all op results + for (auto operand : owner->getOperands()) { + potentialSources.push_back(operand); + } + } + + for (Region ®ion : owner->getRegions()) { + for (Block &block : region) { + // TODO: MLIR blocks without terminator? + if (auto iface = dyn_cast( + block.getTerminator())) { + // TODO: the interface may also tell us which regions are allowed to + // yield parent op results, and which only branch to other regions. + auto successorOperands = llvm::to_vector( + iface.getSuccessorOperands(RegionBranchPoint::parent())); + // TODO: understand/document the assumption of how operands flow. + + if (successorOperands.size() != owner->getNumResults()) { + llvm::errs() << *owner << "\n"; + } + assert(successorOperands.size() == owner->getNumResults() && + "expected all results to be populated with yielded " + "terminator operands"); + potentialSources.push_back(successorOperands[resultNo]); + } else { + // assume all terminator operands potentially flow into op results + for (Value v : block.getTerminator()->getOperands()) + potentialSources.push_back(v); + } + } + } + + return potentialSources; +} + +// For a blockargument, find all non-operand values which could flow into +// this result static SmallVector getPotentialIncomingValues(BlockArgument arg) { SetVector potentialSources; @@ -986,6 +1122,10 @@ static SmallVector getPotentialIncomingValues(BlockArgument arg) { potentialSources.insert(v); } } + + // and also any operand to the parent + for (auto op : parent->getOperands()) + potentialSources.insert(op); } return potentialSources.takeVector(); @@ -1144,8 +1284,16 @@ bool mlir::enzyme::ActivityAnalyzer::isConstantValue(MTypeResults const &TR, } if (auto arg = Val.dyn_cast()) { - auto funcIface = dyn_cast_or_null( - arg.getParentBlock()->getParentOp()); + // All arguments must be marked constant/nonconstant ahead of time + if (auto funcIface = dyn_cast_or_null( + arg.getParentBlock()->getParentOp())) + if (funcIface && arg.getOwner()->isEntryBlock() && + !funcIface.getArgAttr(arg.getArgNumber(), + LLVM::LLVMDialect::getByValAttrName())) { + llvm::errs() << funcIface << "\n"; + llvm::errs() << Val << "\n"; + assert(0 && "must've put arguments in constant/nonconstant"); + } // if (!funcIface || !arg.getOwner()->isEntryBlock()) { // TODO: we want a more advanced analysis based on MLIR interfaces here // For now, conservatively assume all block arguments are active @@ -1174,25 +1322,15 @@ bool mlir::enzyme::ActivityAnalyzer::isConstantValue(MTypeResults const &TR, // } // } // } - - // All arguments must be marked constant/nonconstant ahead of time - if (funcIface && arg.getOwner()->isEntryBlock() && - !funcIface.getArgAttr(arg.getArgNumber(), - LLVM::LLVMDialect::getByValAttrName())) { - llvm::errs() << funcIface << "\n"; - llvm::errs() << Val << "\n"; - assert(0 && "must've put arguments in constant/nonconstant"); - } } // This value is certainly an integer (and only and integer, not a pointer or // float). Therefore its value is constant if (TR.intType(1, Val, /*errIfNotFound*/ false).isIntegral()) { - // if (EnzymePrintActivity) - // llvm::errs() << " Value const as integral " << (int)directions << " " - // << *Val << " " - // << TR.intType(1, Val, /*errIfNotFound*/ false).str() << - // "\n"; + if (EnzymePrintActivity) + llvm::errs() << " Value const as integral " << (int)directions << " " + << Val << " " + << TR.intType(1, Val, /*errIfNotFound*/ false).str() << "\n"; InsertConstantValue(TR, Val); return true; } @@ -1368,28 +1506,28 @@ bool mlir::enzyme::ActivityAnalyzer::isConstantValue(MTypeResults const &TR, if (auto CI = Val.getDefiningOp()) { if (CI->hasAttr("enzyme_active")) { - // if (EnzymePrintActivity) - // llvm::errs() << "forced active val " << *Val << "\n"; + if (EnzymePrintActivity) + llvm::errs() << "forced active val " << Val << "\n"; ActiveValues.insert(Val); return false; } if (CI->hasAttr("enzyme_inactive")) { - // if (EnzymePrintActivity) - // llvm::errs() << "forced inactive val " << *Val << "\n"; + if (EnzymePrintActivity) + llvm::errs() << "forced inactive val " << Val << "\n"; InsertConstantValue(TR, Val); return true; } Operation *called = getFunctionFromCall(CI); if (called) { if (called->hasAttr("enzyme_active")) { - // if (EnzymePrintActivity) - // llvm::errs() << "forced active val " << *Val << "\n"; + if (EnzymePrintActivity) + llvm::errs() << "forced active val " << Val << "\n"; ActiveValues.insert(Val); return false; } if (called->hasAttr("enzyme_inactive")) { - // if (EnzymePrintActivity) - // llvm::errs() << "forced inactive val " << *Val << "\n"; + if (EnzymePrintActivity) + llvm::errs() << "forced inactive val " << Val << "\n"; InsertConstantValue(TR, Val); return true; } @@ -1440,14 +1578,14 @@ bool mlir::enzyme::ActivityAnalyzer::isConstantValue(MTypeResults const &TR, LLVM::LLVMDialect::getByValAttrName())) { bool res = isConstantValue(TR, TmpOrig); if (res) { - // if (EnzymePrintActivity) - // llvm::errs() << " arg const from orig val=" << *Val - // << " orig=" << *TmpOrig << "\n"; + if (EnzymePrintActivity) + llvm::errs() << " arg const from orig val=" << Val + << " orig=" << TmpOrig << "\n"; InsertConstantValue(TR, Val); } else { - // if (EnzymePrintActivity) - // llvm::errs() << " arg active from orig val=" << *Val - // << " orig=" << *TmpOrig << "\n"; + if (EnzymePrintActivity) + llvm::errs() << " arg active from orig val=" << Val + << " orig=" << TmpOrig << "\n"; ActiveValues.insert(Val); } return res; @@ -1722,10 +1860,10 @@ bool mlir::enzyme::ActivityAnalyzer::isConstantValue(MTypeResults const &TR, // argument if (TmpOrig != Val) { if (isConstantValue(TR, TmpOrig)) { - // if (EnzymePrintActivity) - // llvm::errs() << " Potential Pointer(" << (int)directions << ") " - // << *Val << " inactive from inactive origin " - // << *TmpOrig << "\n"; + if (EnzymePrintActivity) + llvm::errs() << " Potential Pointer(" << (int)directions << ") " + << Val << " inactive from inactive origin " << TmpOrig + << "\n"; InsertConstantValue(TR, Val); return true; } @@ -1737,11 +1875,17 @@ bool mlir::enzyme::ActivityAnalyzer::isConstantValue(MTypeResults const &TR, if (!op || (!mayReadFromMemory(op) && !mayAllocateMemory(op))) { if (directions == UP && !Val.isa()) { if (isValueInactiveFromOrigin(TR, Val)) { + if (EnzymePrintActivity) + llvm::errs() << " Non-function value inactive from origin(" + << (int)directions << ") " << Val << "\n"; InsertConstantValue(TR, Val); return true; } } else { if (UpHypothesis->isValueInactiveFromOrigin(TR, Val)) { + if (EnzymePrintActivity) + llvm::errs() << " Non-function value_v2 inactive from origin(" + << (int)directions << ") " << Val << "\n"; InsertConstantValue(TR, Val); insertConstantsFrom(TR, *UpHypothesis); return true; @@ -1755,9 +1899,9 @@ bool mlir::enzyme::ActivityAnalyzer::isConstantValue(MTypeResults const &TR, // can be loaded/stored cannot be assesed and therefore we default to assume // it to be active if (directions != (UP | DOWN)) { - // if (EnzymePrintActivity) - // llvm::errs() << " " << *Val << "\n"; + if (EnzymePrintActivity) + llvm::errs() << " " << Val << "\n"; ActiveValues.insert(Val); return false; } @@ -1785,13 +1929,12 @@ bool mlir::enzyme::ActivityAnalyzer::isConstantValue(MTypeResults const &TR, } if (UpHypothesis->isValueInactiveFromOrigin(TR, Val)) { Hypothesis->DeducingPointers.insert(Val); - // if (EnzymePrintActivity) - // llvm::errs() << " constant instruction hypothesis: " << *VI << - // "\n"; + if (EnzymePrintActivity) + llvm::errs() << " constant instruction hypothesis: " << Val << "\n"; } else { - // if (EnzymePrintActivity) - // llvm::errs() << " cannot show constant instruction hypothesis: " - // << *VI << "\n"; + if (EnzymePrintActivity) + llvm::errs() << " cannot show constant instruction hypothesis: " + << Val << "\n"; } } @@ -1920,8 +2063,8 @@ bool mlir::enzyme::ActivityAnalyzer::isConstantValue(MTypeResults const &TR, // If we haven't already shown a potentially active load // check if this loads the given value and is active if (!potentiallyActiveLoad && isRefSet(modRef)) { - // if (EnzymePrintActivity) - // llvm::errs() << "potential active load: " << *I << "\n"; + if (EnzymePrintActivity) + llvm::errs() << "potential active load: " << *op << "\n"; if (isa(op)) { // TODO: this assumption should be built into the MLIR interface // verifier, or alternatively we should relax it. @@ -1945,12 +2088,11 @@ bool mlir::enzyme::ActivityAnalyzer::isConstantValue(MTypeResults const &TR, for (Operation *user : V.getUsers()) { if (mayWriteToMemory(user)) { if (!Hypothesis->isConstantOperation(TR, user)) { - // if (EnzymePrintActivity) - // llvm::errs() - // << "potential active store via " - // "pointer in load: " - // << *I << " of " << *Val << " via " << *U << - // "\n"; + if (EnzymePrintActivity) + llvm::errs() << "potential active store via " + "pointer in load: " + << *op << " of " << Val << " via " + << *user << "\n"; potentiallyActiveStore = true; return true; } @@ -2031,10 +2173,10 @@ bool mlir::enzyme::ActivityAnalyzer::isConstantValue(MTypeResults const &TR, !Hypothesis->isConstantValue(TR, V) && TR.query(V)[{-1}].isPossiblePointer(); })) { - // if (EnzymePrintActivity) - // llvm::errs() << "potential active store via pointer in " - // "unknown inst: " - // << *I << " of " << *Val << "\n"; + if (EnzymePrintActivity) + llvm::errs() << "potential active store via pointer in " + "unknown inst: " + << *op << " of " << Val << "\n"; potentiallyActiveStore = true; } } @@ -2042,25 +2184,25 @@ bool mlir::enzyme::ActivityAnalyzer::isConstantValue(MTypeResults const &TR, } } if ((!potentiallyActiveStore || !potentialStore) && isModSet(modRef)) { - // if (EnzymePrintActivity) - // llvm::errs() << "potential active store: " << *I << " Val=" << *Val - // << "\n"; + if (EnzymePrintActivity) + llvm::errs() << "potential active store: " << *op << " Val=" << Val + << "\n"; if (auto SI = dyn_cast(op)) { bool cop = !Hypothesis->isConstantValue(TR, SI.getValue()); - // if (EnzymePrintActivity) - // llvm::errs() << " -- store potential activity: " << (int)cop - // << " - " << *SI << " of " - // << " Val=" << *Val << "\n"; + if (EnzymePrintActivity) + llvm::errs() << " -- store potential activity: " << (int)cop + << " - " << *SI << " of " + << " Val=" << Val << "\n"; potentialStore = true; if (cop) potentiallyActiveStore = true; } else if (auto SI = dyn_cast(op)) { // FIXME: this is a copy-pasta form above to work with MLIR memrefs. bool cop = !Hypothesis->isConstantValue(TR, SI.getValueToStore()); - // if (EnzymePrintActivity) - // llvm::errs() << " -- store potential activity: " << (int)cop - // << " - " << *SI << " of " - // << " Val=" << *Val << "\n"; + if (EnzymePrintActivity) + llvm::errs() << " -- store potential activity: " << (int)cop + << " - " << *SI << " of " + << " Val=" << Val << "\n"; potentialStore = true; if (cop) potentiallyActiveStore = true; @@ -2077,11 +2219,10 @@ bool mlir::enzyme::ActivityAnalyzer::isConstantValue(MTypeResults const &TR, // TODO: note that this can be optimized (especially for function // calls) auto cop = !Hypothesis->isConstantOperation(TR, op); - // if (EnzymePrintActivity) - // llvm::errs() << " -- unknown store potential activity: " << - // (int)cop - // << " - " << *I << " of " - // << " Val=" << *Val << "\n"; + if (EnzymePrintActivity) + llvm::errs() << " -- unknown store potential activity: " << (int)cop + << " - " << *op << " of " + << " Val=" << Val << "\n"; potentialStore = true; if (cop) potentiallyActiveStore = true; @@ -2146,11 +2287,11 @@ bool mlir::enzyme::ActivityAnalyzer::isConstantValue(MTypeResults const &TR, } activeLoadAndStore:; - // if (EnzymePrintActivity) - // llvm::errs() << " " << *Val - // << " potentiallyActiveLoad=" << potentiallyActiveLoad - // << " potentiallyActiveStore=" << potentiallyActiveStore - // << " potentialStore=" << potentialStore << "\n"; + if (EnzymePrintActivity) + llvm::errs() << " " << Val + << " potentiallyActiveLoad=" << potentiallyActiveLoad + << " potentiallyActiveStore=" << potentiallyActiveStore + << " potentialStore=" << potentialStore << "\n"; if (potentiallyActiveLoad && potentiallyActiveStore) { insertAllFrom(TR, *Hypothesis, Val); // TODO have insertall dependence on this @@ -2212,15 +2353,14 @@ bool mlir::enzyme::ActivityAnalyzer::isConstantValue(MTypeResults const &TR, new mlir::enzyme::ActivityAnalyzer(*DownHypothesis, DOWN)); DownHypothesis2->ConstantValues.insert(TmpOrig); if (DownHypothesis2->isValueActivelyStoredOrReturned(TR, TmpOrig)) { - // if (EnzymePrintActivity) - // llvm::errs() << " active from ivasor: " << *TmpOrig << "\n"; + if (EnzymePrintActivity) + llvm::errs() << " active from ivasor: " << TmpOrig << "\n"; ActiveDown = true; } } else { // unknown origin that could've been stored/returned/etc - // if (EnzymePrintActivity) - // llvm::errs() << " active from unknown origin: " << *TmpOrig << - // "\n"; + if (EnzymePrintActivity) + llvm::errs() << " active from unknown origin: " << TmpOrig << "\n"; ActiveDown = true; } } @@ -2257,13 +2397,12 @@ bool mlir::enzyme::ActivityAnalyzer::isConstantValue(MTypeResults const &TR, // If we go to an active return and only load it, however, that doesnt // transfer derivatives and we can say this memory is inactive - // if (EnzymePrintActivity) - // llvm::errs() << " @@MEMSEARCH" << (int)directions << ">" << *Val - // << " potentiallyActiveLoad=" << potentiallyActiveLoad - // << " potentialStore=" << potentialStore - // << " ActiveUp=" << ActiveUp << " ActiveDown=" << - // ActiveDown - // << " ActiveMemory=" << ActiveMemory << "\n"; + if (EnzymePrintActivity) + llvm::errs() << " @@MEMSEARCH" << (int)directions << ">" << Val + << " potentiallyActiveLoad=" << potentiallyActiveLoad + << " potentialStore=" << potentialStore + << " ActiveUp=" << ActiveUp << " ActiveDown=" << ActiveDown + << " ActiveMemory=" << ActiveMemory << "\n"; if (ActiveMemory) { ActiveValues.insert(Val); @@ -2293,39 +2432,21 @@ bool mlir::enzyme::ActivityAnalyzer::isConstantValue(MTypeResults const &TR, // this value is inactive, we are inactive Since we won't look at uses to // prove, we can inductively assume this is inactive if (directions & UP) { - if (directions == UP && !Val.isa()) { - if (isValueInactiveFromOrigin(TR, Val)) { - InsertConstantValue(TR, Val); - return true; - } else if (Operation *op = Val.getDefiningOp()) { - if (directions == (UP | DOWN)) { - for (Value operand : op->getOperands()) { - if (!UpHypothesis->isConstantValue(TR, operand)) { - for (Value result : op->getResults()) { - ReEvaluateValueIfInactiveValue[operand].insert(result); - } - } - } - } - } + UpHypothesis = std::shared_ptr( + new mlir::enzyme::ActivityAnalyzer(*this, UP)); + UpHypothesis->ConstantValues.insert(Val); + SmallPtrSet toredo; + if (UpHypothesis->isValueInactiveFromOrigin(TR, Val, &toredo)) { + insertConstantsFrom(TR, *UpHypothesis); + InsertConstantValue(TR, Val); + if (EnzymePrintActivity) + llvm::errs() << " Value constant from origin [" << (int)directions + << "]" << Val << "\n"; + return true; } else { - UpHypothesis = std::shared_ptr( - new mlir::enzyme::ActivityAnalyzer(*this, UP)); - UpHypothesis->ConstantValues.insert(Val); - if (UpHypothesis->isValueInactiveFromOrigin(TR, Val)) { - insertConstantsFrom(TR, *UpHypothesis); - InsertConstantValue(TR, Val); - return true; - } else if (Operation *op = Val.getDefiningOp()) { - if (directions == (UP | DOWN)) { - for (Value operand : op->getOperands()) { - if (!UpHypothesis->isConstantValue(TR, operand)) { - for (Value result : op->getResults()) { - ReEvaluateValueIfInactiveValue[operand].insert(result); - } - } - } - } + for (Value result : toredo) { + if (result != Val) + ReEvaluateValueIfInactiveValue[result].insert(Val); } } } @@ -2335,164 +2456,54 @@ bool mlir::enzyme::ActivityAnalyzer::isConstantValue(MTypeResults const &TR, // If all users are inactive, this is therefore inactive. // Since we won't look at origins to prove, we can inductively assume this // is inactive - - // As an optimization if we are going down already - // and we won't use ourselves (done by PHI's), we - // dont need to inductively assume we're true - // and can instead use this object! - if (directions == DOWN && !Val.isa()) { - if (isValueInactiveFromUsers(TR, Val, UseActivity::None)) { - if (UpHypothesis) - insertConstantsFrom(TR, *UpHypothesis); - InsertConstantValue(TR, Val); - return true; - } - } else { - auto DownHypothesis = std::shared_ptr( - new mlir::enzyme::ActivityAnalyzer(*this, DOWN)); - DownHypothesis->ConstantValues.insert(Val); - if (DownHypothesis->isValueInactiveFromUsers(TR, Val, - UseActivity::None)) { - insertConstantsFrom(TR, *DownHypothesis); - if (UpHypothesis) - insertConstantsFrom(TR, *UpHypothesis); - InsertConstantValue(TR, Val); - return true; - } + auto DownHypothesis = std::shared_ptr( + new mlir::enzyme::ActivityAnalyzer(*this, DOWN)); + DownHypothesis->ConstantValues.insert(Val); + if (DownHypothesis->isValueInactiveFromUsers(TR, Val, UseActivity::None)) { + insertConstantsFrom(TR, *DownHypothesis); + if (UpHypothesis) + insertConstantsFrom(TR, *UpHypothesis); + InsertConstantValue(TR, Val); + return true; } } - // if (EnzymePrintActivity) - // llvm::errs() << " Value nonconstant (couldn't disprove)[" << - // (int)directions - // << "]" << *Val << "\n"; + if (EnzymePrintActivity) + llvm::errs() << " Value nonconstant (couldn't disprove)[" << (int)directions + << "]" << Val << "\n"; ActiveValues.insert(Val); return false; } /// Is the value guaranteed to be inactive because of how it's produced. bool mlir::enzyme::ActivityAnalyzer::isValueInactiveFromOrigin( - MTypeResults const &TR, Value val) { + MTypeResults const &TR, Value val, SmallPtrSetImpl *inactArg) { // Must be an analyzer only searching up assert(directions == UP); - // TODO: use getPotentialIncomingValues here to avoid duplciation. - if (auto arg = val.dyn_cast()) { - if (arg.getOwner()->isEntryBlock()) { - Operation *parent = arg.getOwner()->getParentOp(); - Region *parentRegion = arg.getOwner()->getParent(); - SetVector potentialSources; - // Use region interface to find the values flowing into the entry block. - if (auto iface = dyn_cast(parent)) { - auto isRegionSucessorOf = [arg](RegionBranchOpInterface iface, - Region *region, - RegionBranchPoint predecessor, - SetVector &potentialSources) { - SmallVector successors; - iface.getSuccessorRegions(predecessor, successors); - for (const RegionSuccessor &successor : successors) { - if (successor.getSuccessor() != region) - continue; - - unsigned operandOffset = static_cast(-1); - for (const auto &en : - llvm::enumerate(successor.getSuccessorInputs())) { - if (en.value() != arg) - continue; - operandOffset = en.index(); - } - assert(operandOffset != static_cast(-1) && - "could not locate the position of the argument in the " - "successor input list"); - - // Find the values that are forwarded to entry block arguments of - // the current region. - if (predecessor.isParent()) { - // XXX: this assumes a contiguous slice of operands is mapped 1-1 - // without swaps to a contiguous slice of entry block arguments. - assert(iface.getEntrySuccessorOperands(region).size() == - successor.getSuccessorInputs().size()); - potentialSources.insert( - iface.getEntrySuccessorOperands(region)[operandOffset]); - } else { - // Find all block terminators in the predecessor region that - // may be branching to this region, and get the operands they - // forward. - for (Block &block : *predecessor.getRegionOrNull()) { - // TODO: MLIR block without terminator - if (auto terminator = - dyn_cast( - block.getTerminator())) { - // XXX: this assumes a contiguous slice of operands is mapped - // 1-1 without swaps to a contiguous slice of entry block - // arguments. - assert(terminator.getSuccessorOperands(region).size() == - successor.getSuccessorInputs().size()); - potentialSources.insert( - terminator.getSuccessorOperands(region)[operandOffset]); - } else { - for (Value v : block.getTerminator()->getOperands()) - potentialSources.insert(v); - } - } - } - } - }; - - // Find all possible source regions for the current region. - isRegionSucessorOf(iface, parentRegion, RegionBranchPoint::parent(), - potentialSources); - for (Region ®ion : parent->getRegions()) - isRegionSucessorOf(iface, parentRegion, region, potentialSources); - - } else { - // Conservatively assume any op operand and any terminator operand of - // any region can flow into any block argument. - for (Region ®ion : parent->getRegions()) { - for (Block &block : region) { - // TODO: MLIR blocks without terminator? - for (Value v : block.getTerminator()->getOperands()) - potentialSources.insert(v); - } - } - } - - return llvm::all_of(potentialSources, [&](Value value) { - return isConstantValue(TR, value); - }); - } - - // Look at values flowing into block arguments. - for (Block *predecessor : arg.getOwner()->getPredecessors()) { - Operation *terminator = predecessor->getTerminator(); - if (auto iface = dyn_cast(terminator)) { - for (const auto &en : llvm::enumerate(predecessor->getSuccessors())) { - if (en.value() != arg.getOwner()) - continue; - - Value inflow = iface.getSuccessorOperands(en.index()) - .getForwardedOperands()[arg.getArgNumber()]; - if (!isConstantValue(TR, inflow)) - return false; - } - } else { - for (Value operand : terminator->getOperands()) { - if (!isConstantValue(TR, operand)) - return false; + for (auto v : getPotentialIncomingValues(arg)) { + if (!isConstantValue(TR, v)) { + if (EnzymePrintActivity) { + llvm::errs() << " blockarg: " << arg + << " may be active due to inflow from " << v << "\n"; } + if (inactArg) + inactArg->insert(v); + return false; } } - return true; } return isOperationInactiveFromOrigin(TR, val.getDefiningOp(), - val.cast().getResultNumber()); + val.cast().getResultNumber(), + inactArg); } bool mlir::enzyme::ActivityAnalyzer::isOperationInactiveFromOrigin( - MTypeResults const &TR, Operation *op, std::optional resultNo) { + MTypeResults const &TR, Operation *op, std::optional resultNo, + SmallPtrSetImpl *inactArg) { // Must be an analyzer only searching up assert(directions == UP); @@ -2510,19 +2521,22 @@ bool mlir::enzyme::ActivityAnalyzer::isOperationInactiveFromOrigin( } } - // if (EnzymePrintActivity) - // llvm::errs() << " < UPSEARCH" << (int)directions << ">" << *inst << - // "\n"; + if (EnzymePrintActivity) + llvm::errs() << " < UPSEARCH" << (int)directions << ">" << *op << "\n"; if (auto store = dyn_cast(op)) { if (isConstantValue(TR, store.getValue()) || isConstantValue(TR, store.getAddr())) { - // if (EnzymePrintActivity) - // llvm::errs() << " constant instruction as store operand is inactive - // " - // << *inst << "\n"; + if (EnzymePrintActivity) + llvm::errs() << " constant instruction as store operand is inactive" + << *op << "\n"; return true; } + if (inactArg) { + inactArg->insert(store.getValue()); + inactArg->insert(store.getAddr()); + } + return false; } if (isa(op)) { @@ -2530,11 +2544,15 @@ bool mlir::enzyme::ActivityAnalyzer::isOperationInactiveFromOrigin( // values and thus the store is inactive if (isConstantValue(TR, op->getOperand(0)) || isConstantValue(TR, op->getOperand(1))) { - // if (EnzymePrintActivity) - // llvm::errs() << " constant instruction as memtransfer " << *inst - // << "\n"; + if (EnzymePrintActivity) + llvm::errs() << " constant instruction as memtransfer " << *op << "\n"; return true; } + if (inactArg) { + inactArg->insert(op->getOperand(0)); + inactArg->insert(op->getOperand(1)); + } + return false; } if (auto call = dyn_cast(op)) { @@ -2576,9 +2594,9 @@ bool mlir::enzyme::ActivityAnalyzer::isOperationInactiveFromOrigin( if (KnownInactiveFunctions.count(funcName.str()) || MPIInactiveCommAllocators.find(funcName.str()) != MPIInactiveCommAllocators.end()) { - // if (EnzymePrintActivity) - // llvm::errs() << "constant(" << (int)directions - // << ") up-knowninactivecall " << *inst << "\n"; + if (EnzymePrintActivity) + llvm::errs() << "constant(" << (int)directions + << ") up-knowninactivecall " << *op << "\n"; return true; } @@ -2600,9 +2618,9 @@ bool mlir::enzyme::ActivityAnalyzer::isOperationInactiveFromOrigin( Value callVal = call.getCallableForCallee().dyn_cast(); if (callVal) if (isConstantValue(TR, callVal)) { - // if (EnzymePrintActivity) - // llvm::errs() << "constant(" << (int)directions << ") up-constfn " - // << *inst << " - " << *callVal << "\n"; + if (EnzymePrintActivity) + llvm::errs() << "constant(" << (int)directions << ") up-constfn " + << *op << " - " << callVal << "\n"; return true; } } @@ -2610,12 +2628,14 @@ bool mlir::enzyme::ActivityAnalyzer::isOperationInactiveFromOrigin( if (auto gep = dyn_cast(op)) { // A gep's only args that could make it active is the pointer operand if (isConstantValue(TR, gep.getBase())) { - // if (EnzymePrintActivity) - // llvm::errs() << "constant(" << (int)directions << ") up-gep " << - // *inst - // << "\n"; + if (EnzymePrintActivity) + llvm::errs() << "constant(" << (int)directions << ") up-gep " << *op + << "\n"; return true; } + if (inactArg) { + inactArg->insert(gep.getBase()); + } return false; } @@ -2676,97 +2696,79 @@ bool mlir::enzyme::ActivityAnalyzer::isOperationInactiveFromOrigin( if (isConstantValue(TR, si.getTrueValue()) && isConstantValue(TR, si.getFalseValue())) { - // if (EnzymePrintActivity) - // llvm::errs() << "constant(" << (int)directions << ") up-sel:" << - // *inst - // << "\n"; + if (EnzymePrintActivity) + llvm::errs() << "constant(" << (int)directions << ") up-sel:" << *op + << "\n"; return true; } + if (inactArg) { + inactArg->insert(si.getTrueValue()); + inactArg->insert(si.getFalseValue()); + } return false; } - { - bool seenuse = false; - //! TODO does not consider reading from global memory that is active and not - //! an argument + if (!resultNo) { for (Value a : op->getOperands()) { bool hypval = isConstantValue(TR, a); if (!hypval) { - // if (EnzymePrintActivity) - // llvm::errs() << "nonconstant(" << (int)directions << ") up-inst " - // << *inst << " op " << *a << "\n"; - seenuse = true; - break; - } - } - if (!resultNo) { - // Conservatively check all top-level operations nested in the region, - // there is recursion there. - for (Region ®ion : op->getRegions()) { - for (Block &block : region) { - // XXX: We think that we don't need to check block arguments here - // because values flow into them either from operands of the parent op - // or from the op itself. - if (llvm::any_of(block, [&](Operation &nested) { - // No need to check the results, even if they may be active, - // because in absence of resultNo, we are checking for the - // entire op being inactive not individual values. - // - // // The loop _operation_ is inactive, but the result is, just - // // like the GEP inside it. - // %r = scf.for %i.. { - // // The GEP operation is not active, but the result is. - // %active_r = llvm.gep ... %active_operand - // scf.yield %active_r - // } - return !isConstantOperation(TR, &nested); - })) { - seenuse = true; - break; - } + if (EnzymePrintActivity) + llvm::errs() << "nonconstant(" << (int)directions << ") up-inst " + << *op << " op " << a << "\n"; + if (inactArg) { + inactArg->insert(a); } - if (seenuse) - break; + return false; } - } else { - SetVector potentialSources; - for (Region ®ion : op->getRegions()) { - for (Block &block : region) { - // TODO: MLIR blocks without terminator? - if (auto iface = dyn_cast( - block.getTerminator())) { - // TODO: the interface may also tell us which regions are allowed to - // yield parent op results, and which only branch to other regions. - auto successorOperands = llvm::to_vector( - iface.getSuccessorOperands(RegionBranchPoint::parent())); - // TODO: understand/document the assumption of how operands flow. - assert(successorOperands.size() == op->getNumResults() && - "expected all results to be populated with yielded " - "terminator operands"); - potentialSources.insert(successorOperands[*resultNo]); - } else { - // assume all terminator operands potentially flow into op results - for (Value v : block.getTerminator()->getOperands()) - potentialSources.insert(v); + } + // Conservatively check all top-level operations nested in the region, + // there is recursion there. + for (Region ®ion : op->getRegions()) { + for (Block &block : region) { + // XXX: We think that we don't need to check block arguments here + // because values flow into them either from operands of the parent op + // or from the op itself. + for (Operation &nested : block) { + // No need to check the results, even if they may be active, + // because in absence of resultNo, we are checking for the + // entire op being inactive not individual values. + // + // // The loop _operation_ is inactive, but the result is, just + // // like the GEP inside it. + // %r = scf.for %i.. { + // // The GEP operation is not active, but the result is. + // %active_r = llvm.gep ... %active_operand + // scf.yield %active_r + // } + if (!isConstantOperation(TR, &nested)) { + // TODO set inactArg here, except with constant operand. + // assert(!inactArg); + if (EnzymePrintActivity) + llvm::errs() << "nonconstant(" << (int)directions + << ") up-inst-op " << *op << " sub-op " << nested + << "\n"; + return false; } } } - if (llvm::any_of(potentialSources, [&](Value value) { - return !isConstantValue(TR, value); - })) { - seenuse = true; - } } - - if (!seenuse) { - // if (EnzymePrintActivity) - // llvm::errs() << "constant(" << (int)directions << ") up-inst:" << - // *inst - // << "\n"; - return true; + } else { + for (auto value : getPotentialIncomingValues(op->getResult(*resultNo))) { + if (!isConstantValue(TR, value)) { + if (EnzymePrintActivity) + llvm::errs() << "nonconstant(" << (int)directions << ") up-inst " + << *op << " value " << value << "\n"; + if (inactArg) + inactArg->insert(value); + return false; + } } - return false; } + + if (EnzymePrintActivity) + llvm::errs() << "constant(" << (int)directions << ") up-inst:" << *op + << "\n"; + return true; } /// Is the value free of any active uses @@ -2778,9 +2780,9 @@ bool mlir::enzyme::ActivityAnalyzer::isValueInactiveFromUsers( // To ensure we can call down - // if (EnzymePrintActivity) - // llvm::errs() << " " << *val - // << " UA=" << (int)PUA << "\n"; + if (EnzymePrintActivity) + llvm::errs() << " " << val + << " UA=" << (int)PUA << "\n"; bool seenuse = false; // user, predecessor @@ -2834,9 +2836,8 @@ bool mlir::enzyme::ActivityAnalyzer::isValueInactiveFromUsers( } } - // if (EnzymePrintActivity) - // llvm::errs() << " considering use of " << *val << " - " << *a - // << "\n"; + if (EnzymePrintActivity) + llvm::errs() << " considering use of " << val << " - " << *a << "\n"; // Only ignore stores to the operand, not storing the operand // somewhere @@ -2915,25 +2916,23 @@ bool mlir::enzyme::ActivityAnalyzer::isValueInactiveFromUsers( // vtodo.push_back(TmpOrig_2); // continue; // } - // if (EnzymePrintActivity) - // llvm::errs() << " -- cannot continuing indirect store from - // " - // << *val << " due to " << *TmpOrig << "\n"; + if (EnzymePrintActivity) + llvm::errs() << " -- cannot continuing indirect store from" + << val << " due to " << TmpOrig << "\n"; shouldContinue = false; break; } if (shouldContinue) { - // if (EnzymePrintActivity) - // llvm::errs() << " -- continuing indirect store from " << - // *val - // << " into:\n"; + if (EnzymePrintActivity) + llvm::errs() << " -- continuing indirect store from " << val + << " into:\n"; done.insert(std::make_tuple(SI.getOperation(), SI.getValue(), UA)); for (Value TmpOrig : newAllocaSet) { for (Operation *a : TmpOrig.getUsers()) { todo.push_back(std::make_tuple(a, TmpOrig, UA)); - // if (EnzymePrintActivity) - // llvm::errs() << " ** " << *a << "\n"; + if (EnzymePrintActivity) + llvm::errs() << " ** " << *a << "\n"; } AllocaSet.insert(TmpOrig); shouldContinue = true; @@ -2993,10 +2992,9 @@ bool mlir::enzyme::ActivityAnalyzer::isValueInactiveFromUsers( break; } if (shouldContinue) { - // if (EnzymePrintActivity) - // llvm::errs() << " -- continuing indirect store2 from " << - // *val - // << " via " << *TmpOrig << "\n"; + if (EnzymePrintActivity) + llvm::errs() << " -- continuing indirect store2 from " << val + << " via " << TmpOrig << "\n"; continue; } } @@ -3034,10 +3032,9 @@ bool mlir::enzyme::ActivityAnalyzer::isValueInactiveFromUsers( // } if (isa(a)) { - // if (EnzymePrintActivity) - // llvm::errs() << "found constant(" << (int)directions - // << ") allocainst use:" << *val << " user " << *a << - // "\n"; + if (EnzymePrintActivity) + llvm::errs() << "found constant(" << (int)directions + << ") allocainst use:" << val << " user " << *a << "\n"; continue; } @@ -3066,12 +3063,21 @@ bool mlir::enzyme::ActivityAnalyzer::isValueInactiveFromUsers( // continue; // This use is only active if specified - if (isa(a)) { - if (ActiveReturns == DIFFE_TYPE::CONSTANT && - UA != UseActivity::AllStores) { + if (UA != UseActivity::AllStores) { + if (auto termUsers = getPotentialTerminatorUsers(a, parent)) { + for (auto post : *termUsers) { + for (Operation *postUser : post.getUsers()) { + todo.push_back(std::make_tuple(postUser, post, UA)); + } + } continue; - } else { - return false; + } + if (isFunctionReturn(a)) { + if (ActiveReturns == DIFFE_TYPE::CONSTANT) { + continue; + } else { + return false; + } } } @@ -3166,10 +3172,10 @@ bool mlir::enzyme::ActivityAnalyzer::isValueInactiveFromUsers( if (auto call = dyn_cast(a)) { bool ConstantArg = isFunctionArgumentConstant(call, parent); if (ConstantArg && UA != UseActivity::AllStores) { - // if (EnzymePrintActivity) { - // llvm::errs() << "Value found constant callinst use:" << *val - // << " user " << *call << "\n"; - // } + if (EnzymePrintActivity) { + llvm::errs() << "Value found constant callinst use:" << val + << " user " << *call << "\n"; + } continue; } @@ -3190,10 +3196,8 @@ bool mlir::enzyme::ActivityAnalyzer::isValueInactiveFromUsers( if (operand.getDefiningOp()) { bool legal = true; - for (unsigned i = 0; i < call.getArgOperands().size() + 1; ++i) { - // FIXME: this is based on an assumption that the callee operand - // precedes arg operands. - Value a = call->getOperand(i); + for (unsigned i = 0; i < call.getArgOperands().size(); ++i) { + Value a = call.getArgOperands()[i]; // FIXME: yet another ingrained assumption that integers cannot be // active. @@ -3267,11 +3271,10 @@ bool mlir::enzyme::ActivityAnalyzer::isValueInactiveFromUsers( if (Operation *I = a) { if (notForAnalysis.count(I->getBlock())) { // TODO(PR #904): replace the "EnzymePrintActivity" flag with LLVM_DEBUG - // if (EnzymePrintActivity) { - // llvm::errs() << "Value found constant unreachable inst use:" << - // *val - // << " user " << *I << "\n"; - // } + if (EnzymePrintActivity) { + llvm::errs() << "Value found constant unreachable inst use:" << val + << " user " << *I << "\n"; + } continue; } if (UA != UseActivity::AllStores && ConstantOperations.count(I)) { @@ -3280,11 +3283,10 @@ bool mlir::enzyme::ActivityAnalyzer::isValueInactiveFromUsers( .isa() || ConstantValues.count(val); })) { - // if (EnzymePrintActivity) { - // llvm::errs() << "Value found constant inst use:" << *val << " - // user " - // << *I << "\n"; - // } + if (EnzymePrintActivity) { + llvm::errs() << "Value found constant inst use:" << val << " user " + << *I << "\n"; + } continue; } UseActivity NU = UA; @@ -3359,17 +3361,16 @@ bool mlir::enzyme::ActivityAnalyzer::isValueInactiveFromUsers( *FoundInst = I; } - // if (EnzymePrintActivity) - // llvm::errs() << "Value nonconstant inst (uses):" << *val << " user " << - // *a - // << "\n"; + if (EnzymePrintActivity) + llvm::errs() << "Value nonconstant inst (uses):" << val << " user " << *a + << "\n"; seenuse = true; break; } - // if (EnzymePrintActivity) - // llvm::errs() << " " << *val << "\n"; + if (EnzymePrintActivity) + llvm::errs() << " " << val << "\n"; return !seenuse; } @@ -3386,10 +3387,10 @@ bool mlir::enzyme::ActivityAnalyzer::isValueActivelyStoredOrReturned( return StoredOrReturnedCache[key]; } - // if (EnzymePrintActivity) - // llvm::errs() << " " << *val - // << "\n"; + if (EnzymePrintActivity) + llvm::errs() << " " << val + << "\n"; StoredOrReturnedCache[key] = false; @@ -3415,14 +3416,21 @@ bool mlir::enzyme::ActivityAnalyzer::isValueActivelyStoredOrReturned( continue; } - if (isa(a)) { + if (auto termUsers = getPotentialTerminatorUsers(a, val)) { + for (auto post : *termUsers) + if (isValueActivelyStoredOrReturned(TR, post, outside)) { + return StoredOrReturnedCache[key] = true; + } + return false; + } + if (isFunctionReturn(a)) { if (ActiveReturns == DIFFE_TYPE::CONSTANT) continue; - // if (EnzymePrintActivity) - // llvm::errs() << " " - // << " active from-ret>" << *val << "\n"; + if (EnzymePrintActivity) + llvm::errs() << " " + << " active from-ret>" << val << "\n"; StoredOrReturnedCache[key] = true; return true; } @@ -3445,11 +3453,11 @@ bool mlir::enzyme::ActivityAnalyzer::isValueActivelyStoredOrReturned( // Storing into active value, return true if (!isConstantValue(TR, SI.getValue())) { StoredOrReturnedCache[key] = true; - // if (EnzymePrintActivity) - // llvm::errs() << " " << *val - // << " store into=" << *SI << "\n"; + if (EnzymePrintActivity) + llvm::errs() << " " << val + << " store into=" << *SI << "\n"; return true; } } @@ -3458,11 +3466,11 @@ bool mlir::enzyme::ActivityAnalyzer::isValueActivelyStoredOrReturned( // Storing into active memory, return true if (!isConstantValue(TR, SI.getAddr())) { StoredOrReturnedCache[key] = true; - // if (EnzymePrintActivity) - // llvm::errs() << " " << *val << " store=" << *SI - // << "\n"; + if (EnzymePrintActivity) + llvm::errs() << " " << val << " store=" << *SI + << "\n"; return true; } continue; @@ -3513,18 +3521,17 @@ bool mlir::enzyme::ActivityAnalyzer::isValueActivelyStoredOrReturned( // it is written to active memory // TODO handle more memory instructions above to be less conservative - // if (EnzymePrintActivity) - // llvm::errs() << " " << *val << " - use=" << *a - // << "\n"; + if (EnzymePrintActivity) + llvm::errs() << " " << val << " - use=" << *a << "\n"; return StoredOrReturnedCache[key] = true; } - // if (EnzymePrintActivity) - // llvm::errs() << " " - // << *val << "\n"; + if (EnzymePrintActivity) + llvm::errs() << " " + << val << "\n"; return false; } @@ -3540,9 +3547,9 @@ void mlir::enzyme::ActivityAnalyzer::InsertConstantOperation( if (!ActiveValues.count(toeval)) continue; ActiveValues.erase(toeval); - // if (EnzymePrintActivity) - // llvm::errs() << " re-evaluating activity of val " << *toeval - // << " due to inst " << *I << "\n"; + if (EnzymePrintActivity) + llvm::errs() << " re-evaluating activity of val " << toeval + << " due to inst " << *I << "\n"; isConstantValue(TR, toeval); } } @@ -3558,9 +3565,9 @@ void mlir::enzyme::ActivityAnalyzer::InsertConstantValue(MTypeResults const &TR, if (!ActiveValues.count(toeval)) continue; ActiveValues.erase(toeval); - // if (EnzymePrintActivity) - // llvm::errs() << " re-evaluating activity of val " << *toeval - // << " due to value " << *V << "\n"; + if (EnzymePrintActivity) + llvm::errs() << " re-evaluating activity of val " << toeval + << " due to value " << V << "\n"; isConstantValue(TR, toeval); } } @@ -3572,9 +3579,9 @@ void mlir::enzyme::ActivityAnalyzer::InsertConstantValue(MTypeResults const &TR, if (!ActiveOperations.count(toeval)) continue; ActiveOperations.erase(toeval); - // if (EnzymePrintActivity) - // llvm::errs() << " re-evaluating activity of inst " << *toeval - // << " due to value " << *V << "\n"; + if (EnzymePrintActivity) + llvm::errs() << " re-evaluating activity of inst " << *toeval + << " due to value " << V << "\n"; isConstantOperation(TR, toeval); } } diff --git a/enzyme/Enzyme/MLIR/Analysis/ActivityAnalysis.h b/enzyme/Enzyme/MLIR/Analysis/ActivityAnalysis.h index bff41155fc3b..c73df2025883 100644 --- a/enzyme/Enzyme/MLIR/Analysis/ActivityAnalysis.h +++ b/enzyme/Enzyme/MLIR/Analysis/ActivityAnalysis.h @@ -143,12 +143,18 @@ class ActivityAnalyzer { bool isFunctionArgumentConstant(mlir::CallOpInterface CI, Value val); /// Is the value guaranteed to be inactive because of how it's produced. - bool isValueInactiveFromOrigin(MTypeResults const &TR, Value val); + /// If active and inactArg is non-null, store any values which may allow this + /// to succeed in the future + bool isValueInactiveFromOrigin( + MTypeResults const &TR, Value val, + llvm::SmallPtrSetImpl *inactArg = nullptr); + /// Is the operation guaranteed to be inactive because of how its operands are /// produced. bool isOperationInactiveFromOrigin( MTypeResults const &TR, Operation *op, - std::optional resultNo = std::nullopt); + std::optional resultNo = std::nullopt, + llvm::SmallPtrSetImpl *inactArg = nullptr); public: enum class UseActivity { diff --git a/enzyme/Enzyme/MLIR/Implementations/BuiltinAutoDiffTypeInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/BuiltinAutoDiffTypeInterfaceImpl.cpp index 613e0ef2ad7e..a15b3283e10d 100644 --- a/enzyme/Enzyme/MLIR/Implementations/BuiltinAutoDiffTypeInterfaceImpl.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/BuiltinAutoDiffTypeInterfaceImpl.cpp @@ -57,7 +57,13 @@ class TensorTypeInterface public: Value createNullValue(Type self, OpBuilder &builder, Location loc) const { auto tenType = self.cast(); - auto attr = DenseElementsAttr::get(tenType, 0); + auto ET = tenType.getElementType(); + size_t num = 1; + for (auto sz : tenType.getShape()) + num *= sz; + APFloat apvalue(ET.cast().getFloatSemantics(), 0); + SmallVector supportedValues(num, apvalue); + auto attr = DenseElementsAttr::get(tenType, supportedValues); return builder.create(loc, tenType, attr); } diff --git a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp index c18a7bd0e921..43b4757bdf53 100644 --- a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp @@ -217,32 +217,11 @@ void mlir::enzyme::detail::regionTerminatorForwardHandler( // Assuming shadows following the originals are fine. // TODO: consider extending to have a ShadowableTerminatorOpInterface Operation *replTerminator = gutils->getNewFromOriginal(origTerminator); - Operation *newTerminator = builder.clone(*replTerminator); - newTerminator->setOperands(newOperands); - gutils->erase(replTerminator); + replTerminator->setOperands(newOperands); } LogicalResult mlir::enzyme::detail::controlFlowForwardHandler( Operation *op, OpBuilder &builder, MGradientUtils *gutils) { - // For all active results, add shadow types. - // For now, assuming all results are relevant. - Operation *newOp = gutils->getNewFromOriginal(op); - SmallVector newOpResultTypes; - newOpResultTypes.reserve(op->getNumResults() * 2); - for (Value result : op->getResults()) { - // TODO only if used (can we DCE the primal after having done the - // derivative). - newOpResultTypes.push_back(result.getType()); - if (gutils->isConstantValue(result)) - continue; - auto typeIface = dyn_cast(result.getType()); - if (!typeIface) { - op->emitError() << " AutoDiffTypeInterface not implemented for " - << result.getType() << "\n"; - return failure(); - } - newOpResultTypes.push_back(typeIface.getShadowType()); - } // For all operands that are forwarded to the body, if they are active, also // add the shadow as operand. @@ -253,28 +232,72 @@ LogicalResult mlir::enzyme::detail::controlFlowForwardHandler( return failure(); } - SmallVector successors; // TODO: we may need to record, for every successor, which of its inputs // need a shadow to recreate the body correctly. llvm::SmallDenseSet operandPositionsToShadow; + llvm::SmallDenseSet resultPositionsToShadow; + + SmallVector entrySuccessors; regionBranchOp.getEntrySuccessorRegions( - SmallVector(op->getNumOperands(), Attribute()), successors); - for (const RegionSuccessor &successor : successors) { - if (!successor.isParent() && successor.getSuccessor()->empty()) - continue; + SmallVector(op->getNumOperands(), Attribute()), + entrySuccessors); + + for (const RegionSuccessor &successor : entrySuccessors) { OperandRange operandRange = regionBranchOp.getEntrySuccessorOperands(successor); + ValueRange targetValues = successor.isParent() + ? op->getResults() + : successor.getSuccessorInputs(); + // Need to know which of the arguments are being forwarded to from // operands. for (auto &&[i, regionValue, operand] : - llvm::enumerate(successor.getSuccessorInputs(), operandRange)) { + llvm::enumerate(targetValues, operandRange)) { if (gutils->isConstantValue(regionValue)) continue; operandPositionsToShadow.insert(operandRange.getBeginOperandIndex() + i); + if (successor.isParent()) + resultPositionsToShadow.insert(i); } } + + for (auto res : op->getResults()) + if (!gutils->isConstantValue(res)) + resultPositionsToShadow.insert(res.getResultNumber()); + + return controlFlowForwardHandler( + op, builder, gutils, operandPositionsToShadow, resultPositionsToShadow); +} + +LogicalResult mlir::enzyme::detail::controlFlowForwardHandler( + Operation *op, OpBuilder &builder, MGradientUtils *gutils, + const llvm::SmallDenseSet &operandPositionsToShadow, + const llvm::SmallDenseSet &resultPositionsToShadow) { + // For all active results, add shadow types. + // For now, assuming all results are relevant. + Operation *newOp = gutils->getNewFromOriginal(op); + SmallVector newOpResultTypes; + newOpResultTypes.reserve(op->getNumResults() * 2); + for (auto result : op->getResults()) { + // TODO only if used (can we DCE the primal after having done the + // derivative). + newOpResultTypes.push_back(result.getType()); + if (!gutils->isConstantValue(result)) { + assert(resultPositionsToShadow.count(result.getResultNumber())); + } + if (!resultPositionsToShadow.count(result.getResultNumber())) + continue; + auto typeIface = dyn_cast(result.getType()); + if (!typeIface) { + op->emitError() << " AutoDiffTypeInterface not implemented for " + << result.getType() << "\n"; + return failure(); + } + newOpResultTypes.push_back(typeIface.getShadowType()); + } + SmallVector newOperands; newOperands.reserve(op->getNumOperands() + operandPositionsToShadow.size()); for (OpOperand &operand : op->getOpOperands()) { @@ -297,6 +320,7 @@ LogicalResult mlir::enzyme::detail::controlFlowForwardHandler( } Operation *replacement = iface.createWithShadows( builder, gutils, op, newOperands, newOpResultTypes); + assert(replacement->getNumResults() == newOpResultTypes.size()); for (auto &&[region, replacementRegion] : llvm::zip(newOp->getRegions(), replacement->getRegions())) { replacementRegion.takeBody(region); diff --git a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h index 974b888599d5..a7ec6b179986 100644 --- a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h +++ b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h @@ -15,10 +15,13 @@ #include "Interfaces/AutoDiffOpInterface.h" #include "mlir/Support/LogicalResult.h" +#include "llvm/ADT/DenseSet.h" + namespace mlir { class DialectRegistry; class Operation; class OpBuilder; +class RegionSuccessor; namespace enzyme { class MGradientUtils; @@ -27,9 +30,15 @@ class MGradientUtilsReverse; namespace detail { // Non-template implementation of // AutoDiffUsingControlFlow::createForwardModeTangent. + LogicalResult controlFlowForwardHandler(Operation *op, OpBuilder &builder, MGradientUtils *gutils); +LogicalResult controlFlowForwardHandler( + Operation *op, OpBuilder &builder, MGradientUtils *gutils, + const llvm::SmallDenseSet &operandPositionsToShadow, + const llvm::SmallDenseSet &resultPositionsToShadow); + // Implements forward-mode differentiation of branching operations. // Assumes that successive shadows are legal void branchingForwardHandler(Operation *op, OpBuilder &builder, diff --git a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp index b9cbfe3e6913..39de939bc27a 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp +++ b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp @@ -165,6 +165,11 @@ Create reverse mode adjoint for an operation. */ void MEnzymeLogic::visitChild(Operation *op, OpBuilder &builder, MGradientUtilsReverse *gutils) { + if (llvm::all_of(op->getResults(), + [gutils](Value v) { return gutils->isConstantValue(v); }) && + gutils->isConstantInstruction(op)) { + return; + } if (auto ifaceOp = dyn_cast(op)) { SmallVector caches = ifaceOp.cacheValues(gutils); ifaceOp.createReverseModeAdjoint(builder, gutils, caches); @@ -175,6 +180,7 @@ void MEnzymeLogic::visitChild(Operation *op, OpBuilder &builder, gutils->clearValue(result, builder); } } + op->emitError() << "could not compute the adjoint for this operation " << *op; } void MEnzymeLogic::visitChildren(Block *oBB, Block *reverseBB, diff --git a/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp b/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp index 45473c75921f..d83e2da59a17 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp +++ b/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp @@ -115,14 +115,14 @@ mlir::enzyme::MGradientUtils::getNewFromOriginal(mlir::Block *originst) const { Operation * mlir::enzyme::MGradientUtils::getNewFromOriginal(Operation *originst) const { + assert(originst); auto found = originalToNewFnOps.find(originst); if (found == originalToNewFnOps.end()) { llvm::errs() << oldFunc << "\n"; llvm::errs() << newFunc << "\n"; for (auto &pair : originalToNewFnOps) { llvm::errs() << " map[" << pair.first << "] = " << pair.second << "\n"; - // llvm::errs() << " map[" << pair.first << "] = " << pair.second << " - // -- " << *pair.first << " " << *pair.second << "\n"; + llvm::errs() << " map[" << *pair.first << "] = " << *pair.second << "\n"; } llvm::errs() << originst << " - " << *originst << "\n"; llvm_unreachable("Could not get new op from original"); @@ -154,7 +154,12 @@ mlir::Value mlir::enzyme::MGradientUtils::invertPointerM(mlir::Value v, if (isConstantValue(v)) { if (auto iface = v.getType().dyn_cast()) { OpBuilder::InsertionGuard guard(Builder2); - Builder2.setInsertionPoint(getNewFromOriginal(v.getDefiningOp())); + if (auto op = v.getDefiningOp()) + Builder2.setInsertionPoint(getNewFromOriginal(op)); + else { + auto ba = cast(v); + Builder2.setInsertionPointToStart(getNewFromOriginal(ba.getOwner())); + } Value dv = iface.createNullValue(Builder2, v.getLoc()); invertedPointers.map(v, dv); return dv; diff --git a/enzyme/Enzyme/MLIR/Passes/Passes.td b/enzyme/Enzyme/MLIR/Passes/Passes.td index 432b0938f763..d2a04fc37412 100644 --- a/enzyme/Enzyme/MLIR/Passes/Passes.td +++ b/enzyme/Enzyme/MLIR/Passes/Passes.td @@ -101,6 +101,13 @@ def PrintActivityAnalysisPass : Pass<"print-activity-analysis"> { /*default=*/"false", /*description=*/"Annotate every operation and value with its activity" >, + Option< + /*C++ variable name=*/"dataflow", + /*CLI argument=*/"dataflow", + /*type=*/"bool", + /*default=*/"true", + /*description=*/"Whether to use the new Dataflow activity analysis" + >, Option< /*C++ variable name=*/"inactiveArgs", /*CLI argument=*/"inactive-args", diff --git a/enzyme/Enzyme/MLIR/Passes/PrintActivityAnalysis.cpp b/enzyme/Enzyme/MLIR/Passes/PrintActivityAnalysis.cpp index fb3ef000a0c1..ac88d0fc86fa 100644 --- a/enzyme/Enzyme/MLIR/Passes/PrintActivityAnalysis.cpp +++ b/enzyme/Enzyme/MLIR/Passes/PrintActivityAnalysis.cpp @@ -10,8 +10,10 @@ // analysis. // //===----------------------------------------------------------------------===// +#include "Analysis/ActivityAnalysis.h" #include "Analysis/DataFlowActivityAnalysis.h" #include "Dialect/Ops.h" +#include "Interfaces/EnzymeLogic.h" #include "Passes/PassDetails.h" #include "Passes/Passes.h" @@ -116,10 +118,79 @@ struct PrintActivityAnalysisPass } } + void runActivityAnalysis(bool dataflow, FunctionOpInterface callee, + ArrayRef argActivities, + ArrayRef resultActivities, + bool print, bool verbose, bool annotate) { + if (dataflow) { + enzyme::runDataFlowActivityAnalysis(callee, argActivities, + /*print=*/true, verbose, annotate); + } else { + + SmallPtrSet blocksNotForAnalysis; + + mlir::enzyme::MTypeResults TR; // TODO + SmallPtrSet constant_values; + SmallPtrSet activevals_; + for (auto &&[arg, act] : + llvm::zip(callee.getFunctionBody().getArguments(), argActivities)) { + if (act == enzyme::Activity::enzyme_const) + constant_values.insert(arg); + else + activevals_.insert(arg); + } + auto ReturnActivity = DIFFE_TYPE::CONSTANT; + for (auto act : resultActivities) + if (act != enzyme::Activity::enzyme_const) + ReturnActivity = DIFFE_TYPE::DUP_ARG; + + enzyme::ActivityAnalyzer activityAnalyzer( + blocksNotForAnalysis, constant_values, activevals_, ReturnActivity); + + callee.walk([&](Operation *op) { + + }); + MLIRContext *ctx = callee.getContext(); + callee.walk([&](Operation *op) { + if (print) + llvm::outs() << " Operation: " << *op << "\n"; + for (auto ® : op->getRegions()) { + for (auto &blk : reg.getBlocks()) { + for (auto &arg : blk.getArguments()) { + bool icv = activityAnalyzer.isConstantValue(TR, arg); + if (annotate) + op->setAttr("enzyme.arg_icv" + + std::to_string(arg.getArgNumber()), + BoolAttr::get(ctx, icv)); + if (print) + llvm::outs() << " arg: " << arg << " icv=" << icv << "\n"; + } + } + } + + bool ici = activityAnalyzer.isConstantOperation(TR, op); + if (annotate) + op->setAttr("enzyme.ici", BoolAttr::get(ctx, ici)); + if (print) + llvm::outs() << " op ici=" << ici << "\n"; + + for (auto res : op->getResults()) { + bool icv = activityAnalyzer.isConstantValue(TR, res); + if (annotate) + op->setAttr("enzyme.res_icv" + + std::to_string(res.getResultNumber()), + BoolAttr::get(ctx, icv)); + if (print) + llvm::outs() << " res: " << res << " icv=" << icv << "\n"; + } + }); + } + } + void runOnOperation() override { auto moduleOp = cast(getOperation()); - if (annotate) { + if (annotate && dataflow) { // Infer the activity attributes from the __enzyme_autodiff call Operation *autodiff_decl = moduleOp.lookupSymbol("__enzyme_autodiff"); if (!autodiff_decl) @@ -148,8 +219,8 @@ struct PrintActivityAnalysisPass // supplied annotation. First argument is the callee inferArgActivitiesFromEnzymeAutodiff(callee, autodiff_call, argActivities, resultActivities); - enzyme::runDataFlowActivityAnalysis(callee, argActivities, - /*print=*/true, verbose, annotate); + runActivityAnalysis(dataflow, callee, argActivities, resultActivities, + /*print=*/true, verbose, annotate); } return; } @@ -163,8 +234,8 @@ struct PrintActivityAnalysisPass resultActivities{callee.getNumResults()}; initializeArgAndResActivities(callee, argActivities, resultActivities); - enzyme::runDataFlowActivityAnalysis(callee, argActivities, - /*print=*/true, verbose, annotate); + runActivityAnalysis(dataflow, callee, argActivities, resultActivities, + /*print=*/true, verbose, annotate); }); return; } @@ -186,8 +257,8 @@ struct PrintActivityAnalysisPass resultActivities{callee.getNumResults()}; initializeArgAndResActivities(callee, argActivities, resultActivities); - enzyme::runDataFlowActivityAnalysis(callee, argActivities, - /*print=*/true, verbose, annotate); + runActivityAnalysis(dataflow, callee, argActivities, resultActivities, + /*print=*/true, verbose, annotate); } } }; diff --git a/enzyme/test/MLIR/ActivityAnalysis/region.mlir b/enzyme/test/MLIR/ActivityAnalysis/region.mlir new file mode 100644 index 000000000000..4526e1325661 --- /dev/null +++ b/enzyme/test/MLIR/ActivityAnalysis/region.mlir @@ -0,0 +1,27 @@ +// RUN: %eopt --pass-pipeline="builtin.module(print-activity-analysis{dataflow=false annotate=true})" %s --split-input-file 2>&1 | FileCheck %s + +// A function that contains active and inactive region dataflow + +func.func @region(%x: f64) -> (f64, f64) { + %f0 = arith.constant 0.0 : f64 + %c0 = arith.constant 0 : index + %c10 = arith.constant 10 : index + %c1 = arith.constant 1 : index + %r0:2 = scf.for %arg12 = %c0 to %c10 step %c1 iter_args(%arg13 = %f0, %arg14 = %f0) -> (f64, f64) { + %m = arith.addf %arg13, %x : f64 + scf.yield %m, %arg14 : f64, f64 + } + return %r0#0, %r0#1 : f64, f64 +} + +// CHECK: func.func @region(%arg0: f64) -> (f64, f64) attributes {enzyme.arg_icv0 = false, enzyme.ici = false} { +// CHECK-NEXT: %cst = arith.constant {enzyme.ici = true, enzyme.res_icv0 = true} 0.000000e+00 : f64 +// CHECK-NEXT: %c0 = arith.constant {enzyme.ici = true, enzyme.res_icv0 = true} 0 : index +// CHECK-NEXT: %c10 = arith.constant {enzyme.ici = true, enzyme.res_icv0 = true} 10 : index +// CHECK-NEXT: %c1 = arith.constant {enzyme.ici = true, enzyme.res_icv0 = true} 1 : index +// CHECK-NEXT: %0:2 = scf.for %arg1 = %c0 to %c10 step %c1 iter_args(%arg2 = %cst, %arg3 = %cst) -> (f64, f64) { +// CHECK-NEXT: %1 = arith.addf %arg2, %arg0 {enzyme.ici = false, enzyme.res_icv0 = false} : f64 +// CHECK-NEXT: scf.yield {enzyme.ici = true} %1, %arg3 : f64, f64 +// CHECK-NEXT: } {enzyme.arg_icv0 = true, enzyme.arg_icv1 = false, enzyme.arg_icv2 = true, enzyme.ici = false, enzyme.res_icv0 = false, enzyme.res_icv1 = true} +// CHECK-NEXT: return {enzyme.ici = true} %0#0, %0#1 : f64, f64 +// CHECK-NEXT: } From 94aa6edcd5f1a828f770a2ad1c1c50be60d8f8a1 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Sat, 17 Feb 2024 21:36:56 -0500 Subject: [PATCH 069/106] Disable MLIR print activity --- enzyme/Enzyme/MLIR/Analysis/ActivityAnalysis.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/enzyme/Enzyme/MLIR/Analysis/ActivityAnalysis.cpp b/enzyme/Enzyme/MLIR/Analysis/ActivityAnalysis.cpp index 30fbfc8ef345..1415559b3f6b 100644 --- a/enzyme/Enzyme/MLIR/Analysis/ActivityAnalysis.cpp +++ b/enzyme/Enzyme/MLIR/Analysis/ActivityAnalysis.cpp @@ -248,7 +248,7 @@ static Operation *getFunctionFromCall(CallOpInterface iface) { return SymbolTable::lookupNearestSymbolFrom(iface.getOperation(), symbol); } -constexpr bool EnzymePrintActivity = true; +constexpr bool EnzymePrintActivity = false; /// Is the use of value val as an argument of call CI known to be inactive /// This tool can only be used when in DOWN mode From c1f59a9838a73fb15f278fdcefa2f78ac7112a6b Mon Sep 17 00:00:00 2001 From: William Moses Date: Sat, 17 Feb 2024 22:41:19 -0500 Subject: [PATCH 070/106] MLIR fix forward terminator interface (#1733) --- .../CoreDialectsAutoDiffImplementations.cpp | 46 +++++++++++-------- .../Enzyme/MLIR/Interfaces/GradientUtils.cpp | 6 +-- 2 files changed, 29 insertions(+), 23 deletions(-) diff --git a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp index 43b4757bdf53..025946ef9651 100644 --- a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp @@ -186,29 +186,39 @@ LogicalResult mlir::enzyme::detail::allocationForwardHandler( void mlir::enzyme::detail::regionTerminatorForwardHandler( Operation *origTerminator, OpBuilder &builder, MGradientUtils *gutils) { - auto termIface = cast(origTerminator); - auto parentOp = termIface->getParentOp(); - - SmallVector successors; - termIface.getSuccessorRegions( - SmallVector(termIface->getNumOperands(), Attribute()), - successors); + auto parentOp = origTerminator->getParentOp(); llvm::SmallDenseSet operandsToShadow; - for (auto &successor : successors) { - OperandRange operandRange = termIface.getSuccessorOperands(successor); - ValueRange targetValues = successor.isParent() - ? parentOp->getResults() - : successor.getSuccessorInputs(); - assert(operandRange.size() == targetValues.size()); - for (auto &&[i, target] : llvm::enumerate(targetValues)) { - if (!gutils->isConstantValue(target)) - operandsToShadow.insert(operandRange.getBeginOperandIndex() + i); + if (auto termIface = + dyn_cast(origTerminator)) { + SmallVector successors; + termIface.getSuccessorRegions( + SmallVector(origTerminator->getNumOperands(), Attribute()), + successors); + + for (auto &successor : successors) { + OperandRange operandRange = termIface.getSuccessorOperands(successor); + ValueRange targetValues = successor.isParent() + ? parentOp->getResults() + : successor.getSuccessorInputs(); + assert(operandRange.size() == targetValues.size()); + for (auto &&[i, target] : llvm::enumerate(targetValues)) { + if (!gutils->isConstantValue(target)) + operandsToShadow.insert(operandRange.getBeginOperandIndex() + i); + } + } + } else { + assert(parentOp->getNumResults() == origTerminator->getNumOperands()); + for (auto res : parentOp->getResults()) { + if (!gutils->isConstantValue(res)) + operandsToShadow.insert(res.getResultNumber()); } } + SmallVector newOperands; - newOperands.reserve(termIface->getNumOperands() + operandsToShadow.size()); - for (OpOperand &operand : termIface->getOpOperands()) { + newOperands.reserve(origTerminator->getNumOperands() + + operandsToShadow.size()); + for (OpOperand &operand : origTerminator->getOpOperands()) { newOperands.push_back(gutils->getNewFromOriginal(operand.get())); if (operandsToShadow.contains(operand.getOperandNumber())) newOperands.push_back(gutils->invertPointerM(operand.get(), builder)); diff --git a/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp b/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp index d83e2da59a17..4598394b9e44 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp +++ b/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp @@ -306,11 +306,7 @@ void mlir::enzyme::MGradientUtils::forceAugmentedReturns() { LogicalResult MGradientUtils::visitChild(Operation *op) { if (mode == DerivativeMode::ForwardMode) { - // In absence of a proper activity analysis, approximate it by treating any - // side effect-free operation producing constants as inactive. - // if (auto iface = dyn_cast(op)) { - if (!isa(op) && - !isa(op) && + if ((op->getBlock()->getTerminator() != op) && llvm::all_of(op->getResults(), [this](Value v) { return isConstantValue(v); }) && /*iface.hasNoEffect()*/ activityAnalyzer->isConstantOperation(TR, op)) { From b675db757a1df1646b9518042859cef41e477602 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sat, 17 Feb 2024 22:43:36 -0500 Subject: [PATCH 071/106] MLIR add constantfp to common.td (#1734) --- enzyme/Enzyme/MLIR/Implementations/Common.td | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/enzyme/Enzyme/MLIR/Implementations/Common.td b/enzyme/Enzyme/MLIR/Implementations/Common.td index f558d1225a76..10f9c1532014 100644 --- a/enzyme/Enzyme/MLIR/Implementations/Common.td +++ b/enzyme/Enzyme/MLIR/Implementations/Common.td @@ -62,6 +62,12 @@ class Inst : Operation : Operation { + string value = val; + string dialect = dialect_; + string opName = op_; +} + class ArithInst : Inst; class MathInst : Inst; From d7c94ddd6ac7bb026a741aed030734ba8299bbca Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 18 Feb 2024 16:08:08 -0500 Subject: [PATCH 072/106] MLIR handle dual tablegen (#1735) * MLIR handle dual tablegen * fix * extend constantfp --- .../MLIR/Implementations/ArithDerivatives.td | 4 +-- enzyme/Enzyme/MLIR/Implementations/Common.td | 19 ++++++++-- .../CoreDialectsAutoDiffImplementations.cpp | 17 +++++++++ .../CoreDialectsAutoDiffImplementations.h | 7 ++++ enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp | 35 ++++++++++++------- 5 files changed, 65 insertions(+), 17 deletions(-) diff --git a/enzyme/Enzyme/MLIR/Implementations/ArithDerivatives.td b/enzyme/Enzyme/MLIR/Implementations/ArithDerivatives.td index 3ed038fa4cb6..3d53793be3af 100644 --- a/enzyme/Enzyme/MLIR/Implementations/ArithDerivatives.td +++ b/enzyme/Enzyme/MLIR/Implementations/ArithDerivatives.td @@ -28,6 +28,6 @@ def : MLIRDerivative<"arith", "DivFOp", (Op $x, $y), [ (CheckedDivF (DiffeRet), $y), (NegF (MulF (CheckedDivF (DiffeRet), $y), (DivF $x, $y))) - ] - // (CheckedDiv (FSub (SelectIfActive $x, (FMul (Shadow $x), $y), (Zero $x)), (SelectIfActive $y, (FMul (Shadow $y), $x), (Zero $y))), (FMul $y, $y)) + ], + (CheckedDivF (SubF (SelectIfActive $x, (MulF (Shadow $x), $y), (ConstantFP<"0","arith", "ConstantOp"> $x)), (SelectIfActive $y, (MulF (Shadow $y), $x), (ConstantFP<"0","arith","ConstantOp"> $y))), (MulF $y, $y)) >; diff --git a/enzyme/Enzyme/MLIR/Implementations/Common.td b/enzyme/Enzyme/MLIR/Implementations/Common.td index 10f9c1532014..fe33089f15e6 100644 --- a/enzyme/Enzyme/MLIR/Implementations/Common.td +++ b/enzyme/Enzyme/MLIR/Implementations/Common.td @@ -36,11 +36,18 @@ class RegionTerminatorOp { string opName = opName_; } -class MLIRDerivative resultOps> { +class ForwardFromSummedReverseInternal { + int unused = unused_; +} +def ForwardFromSummedReverse : ForwardFromSummedReverseInternal<0>; + + +class MLIRDerivative resultOps, dag forwardOps=(ForwardFromSummedReverse)> { string dialect = dialect_; string opName = opName_; dag PatternToMatch = patternToMatch; list ArgDerivatives = resultOps; + dag ArgDuals = forwardOps; } class Operation { @@ -54,6 +61,9 @@ class DiffeRetIndex indices_> { } def DiffeRet : DiffeRetIndex<[-1]>; +def Shadow : Operation { +} + class Inst : Operation { string name = mnemonic; string dialect = dialect_; @@ -62,10 +72,15 @@ class Inst : Operation : Operation { +def SelectIfActive : Operation { + +} + +class ConstantFP : Operation { string value = val; string dialect = dialect_; string opName = op_; + string type = type_; } class ArithInst : Inst; diff --git a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp index 025946ef9651..f9b2d61b0dce 100644 --- a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp @@ -19,6 +19,23 @@ using namespace mlir; using namespace mlir::enzyme; +mlir::TypedAttr mlir::enzyme::getConstantAttr(mlir::Type type, + llvm::StringRef value) { + using namespace mlir; + if (auto T = dyn_cast(type)) { + size_t num = 1; + for (auto sz : T.getShape()) + num *= sz; + APFloat apvalue(T.getElementType().cast().getFloatSemantics(), + value); + SmallVector supportedValues(num, apvalue); + return DenseFPElementsAttr::get(type.cast(), supportedValues); + } + auto T = cast(type); + APFloat apvalue(T.getFloatSemantics(), value); + return FloatAttr::get(T, apvalue); +} + void mlir::enzyme::detail::branchingForwardHandler(Operation *inst, OpBuilder &builder, MGradientUtils *gutils) { diff --git a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h index a7ec6b179986..7ee0be2adb70 100644 --- a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h +++ b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h @@ -12,6 +12,9 @@ // //===----------------------------------------------------------------------===// +#ifndef ENZYMEMLIR_CORE_IMPL_H_ +#define ENZYMEMLIR_CORE_IMPL_H_ + #include "Interfaces/AutoDiffOpInterface.h" #include "mlir/Support/LogicalResult.h" @@ -198,5 +201,9 @@ void registerLinalgDialectAutoDiffInterface(DialectRegistry ®istry); void registerMathDialectAutoDiffInterface(DialectRegistry ®istry); void registerCoreDialectAutodiffInterfaces(DialectRegistry ®istry); + +mlir::TypedAttr getConstantAttr(mlir::Type type, llvm::StringRef value); } // namespace enzyme } // namespace mlir + +#endif diff --git a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp index 1e1ac3697095..35c937f79807 100644 --- a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp +++ b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp @@ -385,7 +385,10 @@ bool handle(const Twine &curIndent, const Twine &argPattern, raw_ostream &os, os << "({\n"; os << curIndent << INDENT << "// Computing SelectIfActive\n"; - os << curIndent << INDENT << "Value *imVal = nullptr;\n"; + if (intrinsic == MLIRDerivatives) + os << curIndent << INDENT << "mlir::Value imVal = nullptr;\n"; + else + os << curIndent << INDENT << "llvm::Value *imVal = nullptr;\n"; os << curIndent << INDENT << "if (!gutils->isConstantValue("; @@ -415,7 +418,7 @@ bool handle(const Twine &curIndent, const Twine &argPattern, raw_ostream &os, retidx, origName, newFromOriginal, intrinsic); os << ";\n"; - if (!vector) { + if (!vector && intrinsic != MLIRDerivatives) { os << curIndent << INDENT << INDENT << "llvm::Value* vec_imVal = gutils->getWidth() == 1 ? imVal : " "UndefValue::get(gutils->getShadowType(imVal" @@ -440,16 +443,15 @@ bool handle(const Twine &curIndent, const Twine &argPattern, raw_ostream &os, os << curIndent << "})"; return true; } else if (opName == "ConstantFP" || Def->isSubClassOf("ConstantFP")) { - if (resultRoot->getNumArgs() != 1) - PrintFatalError(pattern->getLoc(), - "only single op constantfp supported"); - auto value = dyn_cast(Def->getValueInit("value")); if (!value) PrintFatalError(pattern->getLoc(), Twine("'value' not defined in ") + resultTree->getAsString()); if (intrinsic == MLIRDerivatives) { + if (resultRoot->getNumArgs() > 1) + PrintFatalError(pattern->getLoc(), + "only zero or single op constantfp supported"); os << builder << ".create<" << cast(Def->getValueInit("dialect"))->getValue() << "::" << cast(Def->getValueInit("opName"))->getValue() @@ -463,9 +465,17 @@ bool handle(const Twine &curIndent, const Twine &argPattern, raw_ostream &os, assert(!isVec); ord = ord1; } - os << ord << ".getType(), getTensorAttr(" << ord << ".getType(), "; + os << ord << ".getType(), "; + auto typeCast = + dyn_cast(Def->getValueInit("type"))->getValue(); + if (typeCast != "") + os << "(" << typeCast << ")"; + os << "mlir::enzyme::getConstantAttr(" << ord << ".getType(), "; os << "\"" << value->getValue() << "\"))"; } else { + if (resultRoot->getNumArgs() != 1) + PrintFatalError(pattern->getLoc(), + "only single op constantfp supported"); os << "ConstantFP::get("; if (resultRoot->getArgName(0)) { @@ -1269,9 +1279,8 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os, for (Record *pattern : patterns) { DagInit *tree = pattern->getValueAsDag("PatternToMatch"); - DagInit *duals = nullptr; - if (intrinsic != MLIRDerivatives) - duals = pattern->getValueAsDag("ArgDuals"); + DagInit *duals = pattern->getValueAsDag("ArgDuals"); + assert(duals); // Emit RewritePattern for Pattern. ListInit *argOps = pattern->getValueAsListInit("ArgDerivatives"); @@ -1514,8 +1523,7 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os, } // TODO - if (!duals || - duals->getOperator()->getAsString() == + if (duals->getOperator()->getAsString() == "ForwardFromSummedReverseInternal" || cast(duals->getOperator()) ->getDef() @@ -1649,7 +1657,8 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os, } ArrayRef retidx{}; bool vectorValued = - handle(" ", "fwdnsrarg", os, pattern, duals, "Builder2", + handle(" ", "fwdnsrarg", os, pattern, duals, + (intrinsic == MLIRDerivatives) ? "builder" : "Builder2", nameToOrdinal, /*lookup*/ false, retidx, origName, /*newFromOriginal*/ true, intrinsic); (void)vectorValued; From 1e1c0eb1c9b4ae3fa6b0acc2394e305b3fc4e042 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 18 Feb 2024 18:00:16 -0500 Subject: [PATCH 073/106] MLIR fix custom fwd tblgen (#1736) --- enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp index 35c937f79807..84dce5f2a08c 100644 --- a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp +++ b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp @@ -1651,7 +1651,7 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os, } else { if (intrinsic == MLIRDerivatives) { - os << " mlir::Value res = nullptr;\n"; + os << " mlir::Value res = "; } else { os << " Value *res = "; } From 2dcf5c5bfbc93805964a4a519665591bdee38046 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Sun, 18 Feb 2024 19:00:36 -0500 Subject: [PATCH 074/106] MLIR support globalexpr --- enzyme/Enzyme/MLIR/Implementations/Common.td | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/enzyme/Enzyme/MLIR/Implementations/Common.td b/enzyme/Enzyme/MLIR/Implementations/Common.td index fe33089f15e6..83b89f7cf0a4 100644 --- a/enzyme/Enzyme/MLIR/Implementations/Common.td +++ b/enzyme/Enzyme/MLIR/Implementations/Common.td @@ -64,6 +64,10 @@ def DiffeRet : DiffeRetIndex<[-1]>; def Shadow : Operation { } +class GlobalExpr : Operation{ + string value = val; +} + class Inst : Operation { string name = mnemonic; string dialect = dialect_; From f2a9361619ae27c744e32734f0eae89a72a69b4a Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 18 Feb 2024 19:57:23 -0500 Subject: [PATCH 075/106] Handle new debug format conversion (#1737) --- enzyme/Enzyme/FunctionUtils.cpp | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/enzyme/Enzyme/FunctionUtils.cpp b/enzyme/Enzyme/FunctionUtils.cpp index 160ccc740f8a..fdba018bd600 100644 --- a/enzyme/Enzyme/FunctionUtils.cpp +++ b/enzyme/Enzyme/FunctionUtils.cpp @@ -754,6 +754,16 @@ void PreProcessCache::AlwaysInline(Function *NewF) { for (auto CI : ToInline) { InlineFunctionInfo IFI; +#if LLVM_VERSION_MAJOR >= 18 + auto F = CI->getCalledFunction(); + if (CI->getParent()->IsNewDbgInfoFormat != F->IsNewDbgInfoFormat) { + if (CI->getParent()->IsNewDbgInfoFormat) { + F->convertToNewDbgValues(); + } else { + F->convertFromNewDbgValues(); + } + } +#endif InlineFunction(*CI, IFI); } } From f3fca8d1443fb688087f2b09f18299082d668ae5 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 18 Feb 2024 22:18:45 -0500 Subject: [PATCH 076/106] TypeAnalysis cleanup null result (#1738) --- enzyme/Enzyme/AdjointGenerator.h | 12 ++--- enzyme/Enzyme/CApi.cpp | 2 +- enzyme/Enzyme/PreserveNVVM.cpp | 8 +-- enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp | 57 +++++++++------------ enzyme/Enzyme/TypeAnalysis/TypeAnalysis.h | 5 +- 5 files changed, 39 insertions(+), 45 deletions(-) diff --git a/enzyme/Enzyme/AdjointGenerator.h b/enzyme/Enzyme/AdjointGenerator.h index ae2e72aceb83..36b337329a2e 100644 --- a/enzyme/Enzyme/AdjointGenerator.h +++ b/enzyme/Enzyme/AdjointGenerator.h @@ -100,7 +100,7 @@ class AdjointGenerator : public llvm::InstVisitor { using namespace llvm; assert(TR.getFunction() == gutils->oldFunc); - for (auto &pair : TR.analyzer.analysis) { + for (auto &pair : TR.analyzer->analysis) { if (auto in = dyn_cast(pair.first)) { if (in->getParent()->getParent() != gutils->oldFunc) { llvm::errs() << "inf: " << *in->getParent()->getParent() << "\n"; @@ -4152,7 +4152,7 @@ class AdjointGenerator : public llvm::InstVisitor { if (called) { subdata = &gutils->Logic.CreateAugmentedPrimal( RequestContext(&call, &BuilderZ), cast(called), - subretType, argsInverted, TR.analyzer.interprocedural, + subretType, argsInverted, TR.analyzer->interprocedural, /*return is used*/ false, /*shadowReturnUsed*/ false, nextTypeInfo, overwritten_args, false, gutils->getWidth(), @@ -4368,7 +4368,7 @@ class AdjointGenerator : public llvm::InstVisitor { : nullptr, .forceAnonymousTape = false, .typeInfo = nextTypeInfo}, - TR.analyzer.interprocedural, subdata, + TR.analyzer->interprocedural, subdata, /*omp*/ true); if (subdata->returns.find(AugmentedStruct::Tape) != @@ -4896,7 +4896,7 @@ class AdjointGenerator : public llvm::InstVisitor { if (called) { newcalled = gutils->Logic.CreateForwardDiff( RequestContext(&call, &BuilderZ), cast(called), - subretType, argsInverted, TR.analyzer.interprocedural, + subretType, argsInverted, TR.analyzer->interprocedural, /*returnValue*/ subretused, Mode, ((DiffeGradientUtils *)gutils)->FreeMemory, gutils->getWidth(), tape ? tape->getType() : nullptr, nextTypeInfo, overwritten_args, @@ -5311,7 +5311,7 @@ class AdjointGenerator : public llvm::InstVisitor { Mode == DerivativeMode::ReverseModeCombined) { subdata = &gutils->Logic.CreateAugmentedPrimal( RequestContext(&call, &BuilderZ), cast(called), - subretType, argsInverted, TR.analyzer.interprocedural, + subretType, argsInverted, TR.analyzer->interprocedural, /*return is used*/ subretused, shadowReturnUsed, nextTypeInfo, overwritten_args, false, gutils->getWidth(), gutils->AtomicAdd); if (Mode == DerivativeMode::ReverseModePrimal) { @@ -5689,7 +5689,7 @@ class AdjointGenerator : public llvm::InstVisitor { .additionalType = tape ? tape->getType() : nullptr, .forceAnonymousTape = false, .typeInfo = nextTypeInfo}, - TR.analyzer.interprocedural, subdata); + TR.analyzer->interprocedural, subdata); if (!newcalled) return; FT = cast(newcalled)->getFunctionType(); diff --git a/enzyme/Enzyme/CApi.cpp b/enzyme/Enzyme/CApi.cpp index 15bc6c795aa2..006965496f43 100644 --- a/enzyme/Enzyme/CApi.cpp +++ b/enzyme/Enzyme/CApi.cpp @@ -297,7 +297,7 @@ void FreeTypeAnalysis(EnzymeTypeAnalysisRef TAR) { void *EnzymeAnalyzeTypes(EnzymeTypeAnalysisRef TAR, CFnTypeInfo CTI, LLVMValueRef F) { FnTypeInfo FTI(eunwrap(CTI, cast(unwrap(F)))); - return (void *)&((TypeAnalysis *)TAR)->analyzeFunction(FTI).analyzer; + return (void *)((TypeAnalysis *)TAR)->analyzeFunction(FTI).analyzer; } void *EnzymeGradientUtilsTypeAnalyzer(GradientUtils *G) { diff --git a/enzyme/Enzyme/PreserveNVVM.cpp b/enzyme/Enzyme/PreserveNVVM.cpp index a7df39596f0a..84b8b5b9540a 100644 --- a/enzyme/Enzyme/PreserveNVVM.cpp +++ b/enzyme/Enzyme/PreserveNVVM.cpp @@ -60,12 +60,12 @@ using namespace llvm; bool preserveLinkage(bool Begin, Function &F, bool Inlining = true) { if (Begin && !F.hasFnAttribute("prev_fixup")) { F.addFnAttr("prev_fixup"); + if (F.hasFnAttribute(Attribute::AlwaysInline)) + F.addFnAttr("prev_always_inline"); + if (F.hasFnAttribute(Attribute::NoInline)) + F.addFnAttr("prev_no_inline"); if (Inlining) { - if (F.hasFnAttribute(Attribute::AlwaysInline)) - F.addFnAttr("prev_always_inline"); F.removeFnAttr(Attribute::AlwaysInline); - if (F.hasFnAttribute(Attribute::NoInline)) - F.addFnAttr("prev_no_inline"); F.addFnAttr(Attribute::NoInline); } F.addFnAttr("prev_linkage", std::to_string(F.getLinkage())); diff --git a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp index 04d1822d4e23..3204b5667086 100644 --- a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp +++ b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp @@ -5745,20 +5745,10 @@ TypeResults TypeAnalysis::analyzeFunction(const FnTypeInfo &fn) { return TypeResults(analysis); } -#ifdef __clang__ -#pragma clang diagnostic push -#pragma clang diagnostic ignored "-Wnull-dereference" -#else -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wnull-dereference" -#endif + if (fn.Function->empty()) - return TypeResults(*(TypeAnalyzer *)nullptr); -#ifdef __clang__ -#pragma clang diagnostic pop -#else -#pragma GCC diagnostic pop -#endif + return TypeResults(nullptr); + auto res = analyzedFunctions.emplace(fn, new TypeAnalyzer(fn, *this)); auto &analysis = *res.first->second; @@ -5808,30 +5798,31 @@ TypeResults TypeAnalysis::analyzeFunction(const FnTypeInfo &fn) { return TypeResults(analysis); } -TypeResults::TypeResults(TypeAnalyzer &analyzer) : analyzer(analyzer) {} +TypeResults::TypeResults(TypeAnalyzer &analyzer) : analyzer(&analyzer) {} +TypeResults::TypeResults(std::nullptr_t) : analyzer(nullptr) {} FnTypeInfo TypeResults::getAnalyzedTypeInfo() const { - FnTypeInfo res(analyzer.fntypeinfo.Function); - for (auto &arg : analyzer.fntypeinfo.Function->args()) { + FnTypeInfo res(analyzer->fntypeinfo.Function); + for (auto &arg : analyzer->fntypeinfo.Function->args()) { res.Arguments.insert(std::pair(&arg, query(&arg))); } res.Return = getReturnAnalysis(); - res.KnownValues = analyzer.fntypeinfo.KnownValues; + res.KnownValues = analyzer->fntypeinfo.KnownValues; return res; } FnTypeInfo TypeResults::getCallInfo(CallBase &CI, Function &fn) const { - return analyzer.getCallInfo(CI, fn); + return analyzer->getCallInfo(CI, fn); } TypeTree TypeResults::query(Value *val) const { if (auto inst = dyn_cast(val)) { - assert(inst->getParent()->getParent() == analyzer.fntypeinfo.Function); + assert(inst->getParent()->getParent() == analyzer->fntypeinfo.Function); } if (auto arg = dyn_cast(val)) { - assert(arg->getParent() == analyzer.fntypeinfo.Function); + assert(arg->getParent() == analyzer->fntypeinfo.Function); } - return analyzer.getAnalysis(val); + return analyzer->getAnalysis(val); } bool TypeResults::anyFloat(Value *val) const { @@ -5843,7 +5834,7 @@ bool TypeResults::anyFloat(Value *val) const { return dt.isFloat(); size_t ObjSize = 1; - auto &dl = analyzer.fntypeinfo.Function->getParent()->getDataLayout(); + auto &dl = analyzer->fntypeinfo.Function->getParent()->getDataLayout(); if (val->getType()->isSized()) ObjSize = (dl.getTypeSizeInBits(val->getType()) + 7) / 8; @@ -5871,7 +5862,7 @@ bool TypeResults::anyPointer(Value *val) const { return dt == BaseType::Pointer; size_t ObjSize = 1; - auto &dl = analyzer.fntypeinfo.Function->getParent()->getDataLayout(); + auto &dl = analyzer->fntypeinfo.Function->getParent()->getDataLayout(); if (val->getType()->isSized()) ObjSize = (dl.getTypeSizeInBits(val->getType()) + 7) / 8; @@ -5890,7 +5881,7 @@ bool TypeResults::anyPointer(Value *val) const { return false; } -void TypeResults::dump(llvm::raw_ostream &ss) const { analyzer.dump(ss); } +void TypeResults::dump(llvm::raw_ostream &ss) const { analyzer->dump(ss); } ConcreteType TypeResults::intType(size_t num, Value *val, bool errIfNotFound, bool pointerIntSame) const { @@ -5913,7 +5904,7 @@ ConcreteType TypeResults::intType(size_t num, Value *val, bool errIfNotFound, if (auto inst = dyn_cast(val)) { llvm::errs() << *inst->getParent()->getParent()->getParent() << "\n"; llvm::errs() << *inst->getParent()->getParent() << "\n"; - for (auto &pair : analyzer.analysis) { + for (auto &pair : analyzer->analysis) { llvm::errs() << "val: " << *pair.first << " - " << pair.second.str() << "\n"; } @@ -5948,7 +5939,7 @@ ConcreteType TypeResults::firstPointer(size_t num, Value *val, Instruction *I, assert(val->getType()); auto q = query(val).Data0(); if (!(val->getType()->isPointerTy() || q[{}] == BaseType::Pointer)) { - llvm::errs() << *analyzer.fntypeinfo.Function << "\n"; + llvm::errs() << *analyzer->fntypeinfo.Function << "\n"; dump(); llvm::errs() << "val: " << *val << "\n"; } @@ -5973,7 +5964,7 @@ ConcreteType TypeResults::firstPointer(size_t num, Value *val, Instruction *I, } if (errIfNotFound && (!dt.isKnown() || dt == BaseType::Anything)) { - auto &res = analyzer; + auto &res = *analyzer; if (auto inst = dyn_cast(val)) { llvm::errs() << *inst->getParent()->getParent()->getParent() << "\n"; llvm::errs() << *inst->getParent()->getParent() << "\n"; @@ -6006,15 +5997,15 @@ ConcreteType TypeResults::firstPointer(size_t num, Value *val, Instruction *I, << "\n"; } } - llvm::errs() << "fn: " << *analyzer.fntypeinfo.Function << "\n"; + llvm::errs() << "fn: " << *analyzer->fntypeinfo.Function << "\n"; dump(); llvm::errs() << "could not deduce type of integer " << *val << " num:" << num << " q:" << q.str() << " \n"; llvm::DiagnosticLocation loc = - analyzer.fntypeinfo.Function->getSubprogram(); + analyzer->fntypeinfo.Function->getSubprogram(); Instruction *codeLoc = - &*analyzer.fntypeinfo.Function->getEntryBlock().begin(); + &*analyzer->fntypeinfo.Function->getEntryBlock().begin(); if (auto inst = dyn_cast(val)) { loc = inst->getDebugLoc(); codeLoc = inst; @@ -6117,15 +6108,15 @@ TypeTree defaultTypeTreeForLLVM(llvm::Type *ET, llvm::Instruction *I, } Function *TypeResults::getFunction() const { - return analyzer.fntypeinfo.Function; + return analyzer->fntypeinfo.Function; } TypeTree TypeResults::getReturnAnalysis() const { - return analyzer.getReturnAnalysis(); + return analyzer->getReturnAnalysis(); } std::set TypeResults::knownIntegralValues(Value *val) const { - return analyzer.knownIntegralValues(val); + return analyzer->knownIntegralValues(val); } std::set TypeAnalyzer::knownIntegralValues(Value *val) { diff --git a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.h b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.h index 6e036454143a..19c8e2d4d19f 100644 --- a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.h +++ b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.h @@ -156,9 +156,10 @@ class TypeAnalysis; /// on a given function class TypeResults { public: - TypeAnalyzer &analyzer; + TypeAnalyzer *analyzer; public: + TypeResults(std::nullptr_t); TypeResults(TypeAnalyzer &analyzer); ConcreteType intType(size_t num, llvm::Value *val, bool errIfNotFound = true, bool pointerIntSame = false) const; @@ -258,6 +259,8 @@ class TypeAnalyzer : public llvm::InstVisitor { FnTypeInfo getCallInfo(llvm::CallBase &CI, llvm::Function &fn); + TypeAnalyzer(TypeAnalysis &TA); + TypeAnalyzer(const FnTypeInfo &fn, TypeAnalysis &TA, uint8_t direction = BOTH); From 026f8a4d26dad002f65e0d92dc9d71103a269290 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 18 Feb 2024 23:04:34 -0500 Subject: [PATCH 077/106] Cleanup gutils to avoid null reference (#1739) --- enzyme/Enzyme/AdjointGenerator.h | 4 +- enzyme/Enzyme/CallDerivatives.cpp | 2 +- enzyme/Enzyme/EnzymeLogic.cpp | 26 ++--- enzyme/Enzyme/GradientUtils.cpp | 154 +++++++++++++++--------------- enzyme/Enzyme/GradientUtils.h | 10 +- 5 files changed, 98 insertions(+), 98 deletions(-) diff --git a/enzyme/Enzyme/AdjointGenerator.h b/enzyme/Enzyme/AdjointGenerator.h index 36b337329a2e..6a0af0d2b78b 100644 --- a/enzyme/Enzyme/AdjointGenerator.h +++ b/enzyme/Enzyme/AdjointGenerator.h @@ -1542,7 +1542,7 @@ class AdjointGenerator : public llvm::InstVisitor { lc) && gutils->getNewFromOriginal(P0->getParent()) == lc.header) { SmallVector Latches; - gutils->OrigLI.getLoopFor(P0->getParent())->getLoopLatches(Latches); + gutils->OrigLI->getLoopFor(P0->getParent())->getLoopLatches(Latches); bool allIncoming = true; for (auto Latch : Latches) { if (&SI != P0->getIncomingValueForBlock(Latch)) { @@ -2206,7 +2206,7 @@ class AdjointGenerator : public llvm::InstVisitor { lc) && gutils->getNewFromOriginal(P0->getParent()) == lc.header) { SmallVector Latches; - gutils->OrigLI.getLoopFor(P0->getParent())->getLoopLatches(Latches); + gutils->OrigLI->getLoopFor(P0->getParent())->getLoopLatches(Latches); bool allIncoming = true; for (auto Latch : Latches) { if (&BO != P0->getIncomingValueForBlock(Latch)) { diff --git a/enzyme/Enzyme/CallDerivatives.cpp b/enzyme/Enzyme/CallDerivatives.cpp index 155967e24a33..8c77a51f954d 100644 --- a/enzyme/Enzyme/CallDerivatives.cpp +++ b/enzyme/Enzyme/CallDerivatives.cpp @@ -3323,7 +3323,7 @@ bool AdjointGenerator::handleKnownCallDerivatives( // rematerialization is loop level. This is because one can have a // loop level cache, but a function level allocation (e.g. for stack // allocas). If we deleted it here, we would have no allocation! - auto AllocationLoop = gutils->OrigLI.getLoopFor(call.getParent()); + auto AllocationLoop = gutils->OrigLI->getLoopFor(call.getParent()); // An allocation within a loop, must definitionally be a loop level // allocation (but not always the other way around. if (AllocationLoop) diff --git a/enzyme/Enzyme/EnzymeLogic.cpp b/enzyme/Enzyme/EnzymeLogic.cpp index 731826bb793f..120722dc7fb6 100644 --- a/enzyme/Enzyme/EnzymeLogic.cpp +++ b/enzyme/Enzyme/EnzymeLogic.cpp @@ -981,7 +981,7 @@ void calculateUnusedValuesInFunction( if (newMemory) { bool foundStore = false; allInstructionsBetween( - gutils->OrigLI, cast(at), + *gutils->OrigLI, cast(at), const_cast(mti), [&](Instruction *I) -> bool { if (!I->mayWriteToMemory()) @@ -994,7 +994,7 @@ void calculateUnusedValuesInFunction( } if (writesToMemoryReadBy( - gutils->OrigAA, TLI, + *gutils->OrigAA, TLI, /*maybeReader*/ const_cast(mti), /*maybeWriter*/ I)) { foundStore = true; @@ -1143,7 +1143,7 @@ void calculateUnusedStoresInFunction( if (newMemory) { bool foundStore = false; allInstructionsBetween( - gutils->OrigLI, cast(at), + *gutils->OrigLI, cast(at), const_cast(mti), [&](Instruction *I) -> bool { if (!I->mayWriteToMemory()) return /*earlyBreak*/ false; @@ -1152,7 +1152,7 @@ void calculateUnusedStoresInFunction( // if (I == &MTI) return; if (writesToMemoryReadBy( - gutils->OrigAA, TLI, + *gutils->OrigAA, TLI, /*maybeReader*/ const_cast(mti), /*maybeWriter*/ I)) { foundStore = true; @@ -1552,7 +1552,7 @@ bool legalCombinedForwardReverse( auto consider = [&](Instruction *user) { if (!user->mayReadFromMemory()) return false; - if (writesToMemoryReadBy(gutils->OrigAA, gutils->TLI, + if (writesToMemoryReadBy(*gutils->OrigAA, gutils->TLI, /*maybeReader*/ user, /*maybeWriter*/ inst)) { @@ -1585,7 +1585,7 @@ bool legalCombinedForwardReverse( if (!post->mayWriteToMemory()) return false; - if (writesToMemoryReadBy(gutils->OrigAA, gutils->TLI, + if (writesToMemoryReadBy(*gutils->OrigAA, gutils->TLI, /*maybeReader*/ inst, /*maybeWriter*/ post)) { if (EnzymePrintPerf) { @@ -2398,9 +2398,9 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal( CacheAnalysis CA(gutils->allocationsWithGuaranteedFree, gutils->rematerializableAllocations, gutils->TR, - gutils->OrigAA, gutils->oldFunc, + *gutils->OrigAA, gutils->oldFunc, PPC.FAM.getResult(*gutils->oldFunc), - gutils->OrigLI, gutils->OrigDT, TLI, guaranteedUnreachable, + *gutils->OrigLI, *gutils->OrigDT, TLI, guaranteedUnreachable, _overwritten_argsPP, DerivativeMode::ReverseModePrimal, omp); const std::map> overwritten_args_map = CA.compute_overwritten_args_for_callsites(); @@ -3346,7 +3346,7 @@ void createInvertedTerminator(DiffeGradientUtils *gutils, gutils->getNewFromOriginal(orig->getParent()) == loopContext.header && loopContext.exitBlocks.size() == 1) { SmallVector Latches; - gutils->OrigLI.getLoopFor(orig->getParent())->getLoopLatches(Latches); + gutils->OrigLI->getLoopFor(orig->getParent())->getLoopLatches(Latches); bool allIncoming = true; for (auto Latch : Latches) { if (activeUses[0] != orig->getIncomingValueForBlock(Latch)) { @@ -4080,9 +4080,9 @@ Function *EnzymeLogic::CreatePrimalAndGradient( gutils->computeGuaranteedFrees(); CacheAnalysis CA(gutils->allocationsWithGuaranteedFree, gutils->rematerializableAllocations, gutils->TR, - gutils->OrigAA, gutils->oldFunc, + *gutils->OrigAA, gutils->oldFunc, PPC.FAM.getResult(*gutils->oldFunc), - gutils->OrigLI, gutils->OrigDT, TLI, guaranteedUnreachable, + *gutils->OrigLI, *gutils->OrigDT, TLI, guaranteedUnreachable, _overwritten_argsPP, key.mode, omp); const std::map> overwritten_args_map = (augmenteddata) ? augmenteddata->overwritten_args_map @@ -4734,10 +4734,10 @@ Function *EnzymeLogic::CreateForwardDiff( gutils->computeGuaranteedFrees(); CacheAnalysis CA( gutils->allocationsWithGuaranteedFree, - gutils->rematerializableAllocations, gutils->TR, gutils->OrigAA, + gutils->rematerializableAllocations, gutils->TR, *gutils->OrigAA, gutils->oldFunc, PPC.FAM.getResult(*gutils->oldFunc), - gutils->OrigLI, gutils->OrigDT, TLI, guaranteedUnreachable, + *gutils->OrigLI, *gutils->OrigDT, TLI, guaranteedUnreachable, _overwritten_argsPP, mode, omp); const std::map> overwritten_args_map = CA.compute_overwritten_args_for_callsites(); diff --git a/enzyme/Enzyme/GradientUtils.cpp b/enzyme/Enzyme/GradientUtils.cpp index d1dedbc70353..9025b890a2f3 100644 --- a/enzyme/Enzyme/GradientUtils.cpp +++ b/enzyme/Enzyme/GradientUtils.cpp @@ -169,19 +169,19 @@ GradientUtils::GradientUtils( : CacheUtility(TLI_, newFunc_), Logic(Logic), mode(mode), oldFunc(oldFunc_), invertedPointers(), OrigDT(oldFunc_->empty() - ? *((DominatorTree *)nullptr) - : Logic.PPC.FAM.getResult( + ? ((DominatorTree *)nullptr) + : &Logic.PPC.FAM.getResult( *oldFunc_)), OrigPDT(oldFunc_->empty() - ? *((PostDominatorTree *)nullptr) - : Logic.PPC.FAM.getResult( + ? ((PostDominatorTree *)nullptr) + : &Logic.PPC.FAM.getResult( *oldFunc_)), OrigLI(oldFunc_->empty() - ? *((LoopInfo *)nullptr) - : Logic.PPC.FAM.getResult(*oldFunc_)), + ? ((LoopInfo *)nullptr) + : &Logic.PPC.FAM.getResult(*oldFunc_)), OrigSE(oldFunc_->empty() - ? *((ScalarEvolution *)nullptr) - : Logic.PPC.FAM.getResult( + ? ((ScalarEvolution *)nullptr) + : &Logic.PPC.FAM.getResult( *oldFunc_)), notForAnalysis(getGuaranteedUnreachable(oldFunc_)), ATA(oldFunc_->empty() @@ -191,8 +191,8 @@ GradientUtils::GradientUtils( notForAnalysis, TLI_, constantvalues_, activevals_, ReturnActivity)), overwritten_args_map_ptr(nullptr), tid(nullptr), numThreads(nullptr), - OrigAA(oldFunc_->empty() ? *((AAResults *)nullptr) - : Logic.PPC.getAAResultsFromFunction(oldFunc_)), + OrigAA(oldFunc_->empty() ? ((AAResults *)nullptr) + : &Logic.PPC.getAAResultsFromFunction(oldFunc_)), TA(TA_), TR(TR_), omp(omp), width(width), ArgDiffeTypes(ArgDiffeTypes_) { if (oldFunc_->empty()) return; @@ -242,7 +242,7 @@ GradientUtils::GradientUtils( for (BasicBlock &BB : *oldFunc) { bool legal = true; for (auto BRet : ReturningBlocks) { - if (!(BRet == &BB || OrigDT.dominates(&BB, BRet))) { + if (!(BRet == &BB || OrigDT->dominates(&BB, BRet))) { legal = false; break; } @@ -508,7 +508,7 @@ Value *GradientUtils::getOrInsertConditionalIndex(Value *val, LoopContext &lc, bool GradientUtils::assumeDynamicLoopOfSizeOne(Loop *L) const { if (!EnzymeInactiveDynamic) return false; - auto OL = OrigLI.getLoopFor(isOriginal(L->getHeader())); + auto OL = OrigLI->getLoopFor(isOriginal(L->getHeader())); assert(OL); for (auto OB : OL->getBlocks()) { for (auto &OI : *OB) { @@ -3788,7 +3788,7 @@ bool GradientUtils::legalRecompute(const Value *val, struct { Function *func; const LoopInfo &FLI; - } options[2] = {{newFunc, LI}, {oldFunc, OrigLI}}; + } options[2] = {{newFunc, LI}, {oldFunc, *OrigLI}}; for (const auto &tup : options) { if (parent->getParent() == tup.func) { for (auto &val : phi->incoming_values()) { @@ -3928,7 +3928,7 @@ bool GradientUtils::legalRecompute(const Value *val, const_cast(orig), [&](Instruction *I) -> bool { if (I->mayWriteToMemory() && writesToMemoryReadBy( - OrigAA, TLI, + *OrigAA, TLI, /*maybeReader*/ const_cast(orig), /*maybeWriter*/ I)) { failed = true; @@ -3951,7 +3951,7 @@ bool GradientUtils::legalRecompute(const Value *val, } origStart = origStart->getNextNode(); } while (true); - if (OrigDT.dominates(origStart, const_cast(orig))) { + if (OrigDT->dominates(origStart, const_cast(orig))) { bool failed = false; allInstructionsBetween( @@ -3959,7 +3959,7 @@ bool GradientUtils::legalRecompute(const Value *val, const_cast(orig), [&](Instruction *I) -> bool { if (I->mayWriteToMemory() && writesToMemoryReadBy( - OrigAA, TLI, + *OrigAA, TLI, /*maybeReader*/ const_cast(orig), /*maybeWriter*/ I)) { failed = true; @@ -5290,7 +5290,7 @@ Value *GradientUtils::invertPointerM(Value *const oval, IRBuilder<> &BuilderM, if (F && isMemFreeLibMFunction(F->getName())) { continue; } - if (llvm::isModOrRefSet(OrigAA.getModRefInfo(CI, Loc))) { + if (llvm::isModOrRefSet(OrigAA->getModRefInfo(CI, Loc))) { seen = true; llvm::errs() << " cannot shadow-inline global " << *oval << " due to " << *CI << "\n"; @@ -6336,9 +6336,9 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM, // + or because the loop nests share no ancestry bool loopLegal = true; - for (Loop *idx = OrigLI.getLoopFor(orig); idx != nullptr; + for (Loop *idx = OrigLI->getLoopFor(orig); idx != nullptr; idx = idx->getParentLoop()) { - for (Loop *fdx = OrigLI.getLoopFor(forwardBlock); fdx != nullptr; + for (Loop *fdx = OrigLI->getLoopFor(forwardBlock); fdx != nullptr; fdx = fdx->getParentLoop()) { if (idx == fdx) { loopLegal = false; @@ -6542,9 +6542,9 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM, // << "\n"; allInstructionsBetween( - OrigLI, orig2, origInst, [&](Instruction *I) -> bool { + *OrigLI, orig2, origInst, [&](Instruction *I) -> bool { if (I->mayWriteToMemory() && - writesToMemoryReadBy(OrigAA, TLI, + writesToMemoryReadBy(*OrigAA, TLI, /*maybeReader*/ origInst, /*maybeWriter*/ I)) { failed = true; @@ -6558,12 +6558,12 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM, if (auto ar1 = dyn_cast(scev1)) { if (auto ar2 = dyn_cast(scev2)) { - if (ar1->getStart() != OrigSE.getCouldNotCompute() && + if (ar1->getStart() != OrigSE->getCouldNotCompute() && ar1->getStart() == ar2->getStart() && - ar1->getStepRecurrence(OrigSE) != - OrigSE.getCouldNotCompute() && - ar1->getStepRecurrence(OrigSE) == - ar2->getStepRecurrence(OrigSE)) { + ar1->getStepRecurrence(*OrigSE) != + OrigSE->getCouldNotCompute() && + ar1->getStepRecurrence(*OrigSE) == + ar2->getStepRecurrence(*OrigSE)) { LoopContext l1; getContext(ar1->getLoop()->getHeader(), l1); @@ -6591,7 +6591,7 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM, } } - auto scev1 = OrigSE.getSCEV(origInst->getPointerOperand()); + auto scev1 = OrigSE->getSCEV(origInst->getPointerOperand()); auto Arch = llvm::Triple(newFunc->getParent()->getTargetTriple()).getArch(); @@ -6599,7 +6599,7 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM, Arch == Triple::amdgcn ? (int)AMDGPU::HSAMD::AddressSpaceQualifier::Local : 3; - if (EnzymeSharedForward && scev1 != OrigSE.getCouldNotCompute() && + if (EnzymeSharedForward && scev1 != OrigSE->getCouldNotCompute() && cast(orig_liobj->getType())->getAddressSpace() == SharedAddrSpace) { Value *resultValue = nullptr; @@ -6608,7 +6608,7 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM, assert(pair.first->getType() == pair.second->getType()); newavail[pair.first] = pair.second; } - allDomPredecessorsOf(origInst, OrigDT, [&](Instruction *pred) { + allDomPredecessorsOf(origInst, *OrigDT, [&](Instruction *pred) { if (auto SI = dyn_cast(pred)) { // auto NewSI = cast(getNewFromOriginal(SI)); auto si2obj = getBaseObject(SI->getPointerOperand()); @@ -6619,10 +6619,10 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM, bool lastStore = true; bool interveningSync = false; allInstructionsBetween( - OrigLI, SI, origInst, [&](Instruction *potentialAlias) { + *OrigLI, SI, origInst, [&](Instruction *potentialAlias) { if (!potentialAlias->mayWriteToMemory()) return false; - if (!writesToMemoryReadBy(OrigAA, TLI, origInst, + if (!writesToMemoryReadBy(*OrigAA, TLI, origInst, potentialAlias)) return false; @@ -6640,7 +6640,7 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM, if (mid == SI) return false; - if (!writesToMemoryReadBy(OrigAA, TLI, origInst, + if (!writesToMemoryReadBy(*OrigAA, TLI, origInst, mid)) { return false; } @@ -6667,16 +6667,16 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM, if (!lastStore) return false; - auto scev2 = OrigSE.getSCEV(SI->getPointerOperand()); + auto scev2 = OrigSE->getSCEV(SI->getPointerOperand()); bool legal = scev1 == scev2; if (auto ar2 = dyn_cast(scev2)) { if (auto ar1 = dyn_cast(scev1)) { - if (ar2->getStart() != OrigSE.getCouldNotCompute() && + if (ar2->getStart() != OrigSE->getCouldNotCompute() && ar1->getStart() == ar2->getStart() && - ar2->getStepRecurrence(OrigSE) != - OrigSE.getCouldNotCompute() && - ar1->getStepRecurrence(OrigSE) == - ar2->getStepRecurrence(OrigSE)) { + ar2->getStepRecurrence(*OrigSE) != + OrigSE->getCouldNotCompute() && + ar1->getStepRecurrence(*OrigSE) == + ar2->getStepRecurrence(*OrigSE)) { LoopContext l1; getContext(getNewFromOriginal(ar1->getLoop()->getHeader()), @@ -6729,15 +6729,15 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM, ValueToValueMapTy ThreadLookup; bool legal = true; for (size_t i = 0; i < svals.size(); i++) { - auto ss = OrigSE.getSCEV(svals[i]); - auto ls = OrigSE.getSCEV(lvals[i]); + auto ss = OrigSE->getSCEV(svals[i]); + auto ls = OrigSE->getSCEV(lvals[i]); if (cast(ss->getType())->getBitWidth() > cast(ls->getType())->getBitWidth()) { - ls = OrigSE.getZeroExtendExpr(ls, ss->getType()); + ls = OrigSE->getZeroExtendExpr(ls, ss->getType()); } if (cast(ss->getType())->getBitWidth() < cast(ls->getType())->getBitWidth()) { - ls = OrigSE.getTruncateExpr(ls, ss->getType()); + ls = OrigSE->getTruncateExpr(ls, ss->getType()); } if (ls != ss) { if (auto II = dyn_cast(svals[i])) { @@ -6824,7 +6824,7 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM, auto origPH = cast_or_null(isOriginal(ctx)); assert(origPH); - if (OrigPDT.dominates(origPH, origInst->getParent())) { + if (OrigPDT->dominates(origPH, origInst->getParent())) { goto noSpeedCache; } @@ -6837,10 +6837,10 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM, bool failed = false; allInstructionsBetween( - OrigLI, &*origTerm, origInst, + *OrigLI, &*origTerm, origInst, [&](Instruction *I) -> bool { if (I->mayWriteToMemory() && - writesToMemoryReadBy(OrigAA, TLI, + writesToMemoryReadBy(*OrigAA, TLI, /*maybeReader*/ tmpload, /*maybeWriter*/ I)) { failed = true; @@ -6858,15 +6858,15 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM, bool failed = false; auto origPH = cast_or_null(isOriginal(nctx)); assert(origPH); - if (OrigPDT.dominates(origPH, origInst->getParent())) { + if (OrigPDT->dominates(origPH, origInst->getParent())) { break; } Instruction *origTerm = origPH->getTerminator(); allInstructionsBetween( - OrigLI, &*origTerm, origInst, + *OrigLI, &*origTerm, origInst, [&](Instruction *I) -> bool { if (I->mayWriteToMemory() && - writesToMemoryReadBy(OrigAA, TLI, + writesToMemoryReadBy(*OrigAA, TLI, /*maybeReader*/ tmpload, /*maybeWriter*/ I)) { failed = true; @@ -6958,7 +6958,7 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM, } } - auto scev1 = OrigSE.getSCEV(origInst->getPointerOperand()); + auto scev1 = OrigSE->getSCEV(origInst->getPointerOperand()); // Store in memcpy opt Value *lim = nullptr; BasicBlock *ctx = nullptr; @@ -6966,7 +6966,7 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM, Value *offset = nullptr; if (auto ar1 = dyn_cast(scev1)) { if (auto step = - dyn_cast(ar1->getStepRecurrence(OrigSE))) { + dyn_cast(ar1->getStepRecurrence(*OrigSE))) { if (step->getAPInt() != loadSize) goto noSpeedCache; @@ -6983,7 +6983,7 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM, auto origPH = cast_or_null(isOriginal(ctx)); assert(origPH); - if (OrigPDT.dominates(origPH, origInst->getParent())) { + if (OrigPDT->dominates(origPH, origInst->getParent())) { goto noSpeedCache; } @@ -7002,7 +7002,7 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM, SmallVector InsertedInstructions; { SCEVExpander OrigExp( - OrigSE, ctx->getParent()->getParent()->getDataLayout(), + *OrigSE, ctx->getParent()->getParent()->getDataLayout(), "enzyme"); OrigExp.setInsertPoint( @@ -7023,7 +7023,7 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM, // instructions. llvm::stable_sort(InsertedInstructions, [this](Instruction *A, Instruction *B) { - return OrigDT.dominates(A, B); + return OrigDT->dominates(A, B); }); for (auto a : InsertedInstructions) { assert(!isa(a)); @@ -7054,7 +7054,7 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM, available.clear(); for (auto I : llvm::reverse(InsertedInstructions)) { assert(I->getNumUses() == 0); - OrigSE.forgetValue(I); + OrigSE->forgetValue(I); I->eraseFromParent(); } #endif @@ -7067,9 +7067,9 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM, bool failed = false; allInstructionsBetween( - OrigLI, &*origTerm, origInst, [&](Instruction *I) -> bool { + *OrigLI, &*origTerm, origInst, [&](Instruction *I) -> bool { if (I->mayWriteToMemory() && - writesToMemoryReadBy(OrigAA, TLI, + writesToMemoryReadBy(*OrigAA, TLI, /*maybeReader*/ origInst, /*maybeWriter*/ I)) { failed = true; @@ -7091,14 +7091,14 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM, bool failed = false; auto origPH = cast_or_null(isOriginal(nctx)); assert(origPH); - if (OrigPDT.dominates(origPH, origInst->getParent())) { + if (OrigPDT->dominates(origPH, origInst->getParent())) { break; } Instruction *origTerm = origPH->getTerminator(); allInstructionsBetween( - OrigLI, &*origTerm, origInst, [&](Instruction *I) -> bool { + *OrigLI, &*origTerm, origInst, [&](Instruction *I) -> bool { if (I->mayWriteToMemory() && - writesToMemoryReadBy(OrigAA, TLI, + writesToMemoryReadBy(*OrigAA, TLI, /*maybeReader*/ origInst, /*maybeWriter*/ I)) { failed = true; @@ -7891,7 +7891,7 @@ void GradientUtils::computeMinCache() { for (BasicBlock &BB : *oldFunc) { if (notForAnalysis.count(&BB)) continue; - auto L = OrigLI.getLoopFor(&BB); + auto L = OrigLI->getLoopFor(&BB); auto invariant = [&](Value *V) { if (isa(V)) @@ -7899,20 +7899,20 @@ void GradientUtils::computeMinCache() { if (isa(V)) return true; if (auto I = dyn_cast(V)) { - if (!L->contains(OrigLI.getLoopFor(I->getParent()))) + if (!L->contains(OrigLI->getLoopFor(I->getParent()))) return true; } return false; }; for (Instruction &I : BB) { if (auto PN = dyn_cast(&I)) { - if (!OrigLI.isLoopHeader(&BB)) + if (!OrigLI->isLoopHeader(&BB)) continue; if (PN->getType()->isIntegerTy()) { bool legal = true; SmallPtrSet Increment; for (auto B : PN->blocks()) { - if (OrigLI.getLoopFor(B) == L) { + if (OrigLI->getLoopFor(B) == L) { if (auto BO = dyn_cast( PN->getIncomingValueForBlock(B))) { if (BO->getOpcode() == BinaryOperator::Add) { @@ -8000,7 +8000,7 @@ void GradientUtils::computeMinCache() { ValueToValueMapTy Available2; for (auto a : Available) Available2[a.first] = a.second; - for (Loop *L = OrigLI.getLoopFor(&BB); L != nullptr; + for (Loop *L = OrigLI->getLoopFor(&BB); L != nullptr; L = L->getParentLoop()) { for (auto v : LoopAvail[L]) { Available2[v] = v; @@ -8042,7 +8042,7 @@ void GradientUtils::computeMinCache() { ValueToValueMapTy Available2; for (auto a : Available) Available2[a.first] = a.second; - for (Loop *L = OrigLI.getLoopFor(cast(V)->getParent()); + for (Loop *L = OrigLI->getLoopFor(cast(V)->getParent()); L != nullptr; L = L->getParentLoop()) { for (auto v : LoopAvail[L]) { Available2[v] = v; @@ -8068,8 +8068,8 @@ void GradientUtils::computeMinCache() { SetVector MinReq; DifferentialUseAnalysis::minCut(oldFunc->getParent()->getDataLayout(), - OrigLI, Recomputes, Intermediates, Required, - MinReq, this, TLI); + *OrigLI, Recomputes, Intermediates, + Required, MinReq, this, TLI); SmallPtrSet NeedGraph; for (Value *V : MinReq) NeedGraph.insert(V); @@ -8098,7 +8098,7 @@ void GradientUtils::computeMinCache() { ValueToValueMapTy Available2; for (auto a : Available) Available2[a.first] = a.second; - for (Loop *L = OrigLI.getLoopFor(cast(V)->getParent()); + for (Loop *L = OrigLI->getLoopFor(cast(V)->getParent()); L != nullptr; L = L->getParentLoop()) { for (auto v : LoopAvail[L]) { Available2[v] = v; @@ -8746,13 +8746,13 @@ void GradientUtils::computeForwardingProperties(Instruction *V) { } // Find the outermost loop of all stores, and the allocation/lifetime - Loop *outer = OrigLI.getLoopFor(V->getParent()); + Loop *outer = OrigLI->getLoopFor(V->getParent()); if (LifetimeStarts.size() == 1) { - outer = OrigLI.getLoopFor((*LifetimeStarts.begin())->getParent()); + outer = OrigLI->getLoopFor((*LifetimeStarts.begin())->getParent()); } for (auto S : stores) { - outer = getAncestor(outer, OrigLI.getLoopFor(S->getParent())); + outer = getAncestor(outer, OrigLI->getLoopFor(S->getParent())); } // May now read pointers for storing into other pointers. Therefore we @@ -8766,8 +8766,8 @@ void GradientUtils::computeForwardingProperties(Instruction *V) { SmallVector results; mayExecuteAfter(results, LI, storingOps, outer); for (auto res : results) { - if (overwritesToMemoryReadBy(OrigAA, TLI, SE, OrigLI, OrigDT, LI, res, - outer)) { + if (overwritesToMemoryReadBy(*OrigAA, TLI, SE, *OrigLI, *OrigDT, LI, + res, outer)) { EmitWarning("NotPromotable", *LI, " Could not promote shadow allocation ", *V, " due to pointer load ", *LI, @@ -8821,7 +8821,7 @@ void GradientUtils::computeForwardingProperties(Instruction *V) { SmallVector results; mayExecuteAfter(results, LI, storingOps, outer); for (auto res : results) { - if (overwritesToMemoryReadBy(OrigAA, TLI, SE, OrigLI, OrigDT, LI, res, + if (overwritesToMemoryReadBy(*OrigAA, TLI, SE, *OrigLI, *OrigDT, LI, res, outer)) { EmitWarning("NotPromotable", *LI, " Could not promote allocation ", *V, " due to load ", *LI, @@ -8837,8 +8837,8 @@ void GradientUtils::computeForwardingProperties(Instruction *V) { SmallVector results; mayExecuteAfter(results, LI.loadCall, storingOps, outer); for (auto res : results) { - if (overwritesToMemoryReadBy(OrigAA, TLI, SE, OrigLI, OrigDT, LI.loadCall, - res, outer)) { + if (overwritesToMemoryReadBy(*OrigAA, TLI, SE, *OrigLI, *OrigDT, + LI.loadCall, res, outer)) { EmitWarning("NotPromotable", *LI.loadCall, " Could not promote allocation ", *V, " due to load-like call ", *LI.loadCall, @@ -9049,7 +9049,7 @@ void GradientUtils::computeGuaranteedFrees() { bool hasPDFree = false; if (dc->getParent() == CI->getParent() || - OrigPDT.dominates(CI->getParent(), dc->getParent())) { + OrigPDT->dominates(CI->getParent(), dc->getParent())) { hasPDFree = true; } diff --git a/enzyme/Enzyme/GradientUtils.h b/enzyme/Enzyme/GradientUtils.h index 6a8c9fd61b9e..7552bec38dab 100644 --- a/enzyme/Enzyme/GradientUtils.h +++ b/enzyme/Enzyme/GradientUtils.h @@ -128,10 +128,10 @@ class GradientUtils : public CacheUtility { DerivativeMode mode; llvm::Function *oldFunc; llvm::ValueMap invertedPointers; - llvm::DominatorTree &OrigDT; - llvm::PostDominatorTree &OrigPDT; - llvm::LoopInfo &OrigLI; - llvm::ScalarEvolution &OrigSE; + llvm::DominatorTree *OrigDT; + llvm::PostDominatorTree *OrigPDT; + llvm::LoopInfo *OrigLI; + llvm::ScalarEvolution *OrigSE; /// (Original) Blocks which dominate all returns llvm::SmallPtrSet BlocksDominatingAllReturns; @@ -353,7 +353,7 @@ class GradientUtils : public CacheUtility { } public: - llvm::AAResults &OrigAA; + llvm::AAResults *OrigAA; TypeAnalysis &TA; TypeResults TR; bool omp; From cc7ef6f2c885301d49e3ad452d1be018e9974be5 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 18 Feb 2024 23:13:29 -0500 Subject: [PATCH 078/106] Fix integration test if not built with compiler-rt (#1740) --- enzyme/test/Integration/ReverseMode/dbginfo.c | 15 +++++++++++++++ enzyme/test/Integration/ReverseMode/taylorlog.c | 15 +++++++++++++++ 2 files changed, 30 insertions(+) diff --git a/enzyme/test/Integration/ReverseMode/dbginfo.c b/enzyme/test/Integration/ReverseMode/dbginfo.c index 31d9bafe4439..b094dffb5c6d 100644 --- a/enzyme/test/Integration/ReverseMode/dbginfo.c +++ b/enzyme/test/Integration/ReverseMode/dbginfo.c @@ -13,6 +13,21 @@ double __enzyme_autodiff(void*, double, unsigned); +// May be needed if not built with compiler-rt +double __powidf2(double a, int b) { + const int recip = b < 0; + double r = 1; + while (1) { + if (b & 1) + r *= a; + b /= 2; + if (b == 0) + break; + a *= a; + } + return recip ? 1 / r : r; +} + static double taylorlog(double x, unsigned SINCOSN) { double sum = 0; for(int i=1; i<=SINCOSN; i++) { diff --git a/enzyme/test/Integration/ReverseMode/taylorlog.c b/enzyme/test/Integration/ReverseMode/taylorlog.c index fdecb8aac7af..522928e3f60f 100644 --- a/enzyme/test/Integration/ReverseMode/taylorlog.c +++ b/enzyme/test/Integration/ReverseMode/taylorlog.c @@ -13,6 +13,21 @@ double __enzyme_autodiff(void*, double, unsigned); +// May be needed if not built with compiler-rt +double __powidf2(double a, int b) { + const int recip = b < 0; + double r = 1; + while (1) { + if (b & 1) + r *= a; + b /= 2; + if (b == 0) + break; + a *= a; + } + return recip ? 1 / r : r; +} + static double taylorlog(double x, unsigned SINCOSN) { double sum = 0; for(int i=1; i<=SINCOSN; i++) { From b0cee973864ca6a6bbf6e6a7fe96cab00a7f7550 Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 19 Feb 2024 18:55:08 -0500 Subject: [PATCH 079/106] Test assert-less build (#1742) --- enzyme/CMakeLists.txt | 2 +- enzyme/Enzyme/ActivityAnalysis.cpp | 7 ++++-- enzyme/Enzyme/AdjointGenerator.h | 4 +++ enzyme/Enzyme/CApi.cpp | 2 ++ enzyme/Enzyme/CacheUtility.cpp | 3 +++ enzyme/Enzyme/CallDerivatives.cpp | 1 + enzyme/Enzyme/Clang/EnzymeClang.cpp | 4 +-- enzyme/Enzyme/DiffeGradientUtils.cpp | 9 +++++++ enzyme/Enzyme/DifferentialUseAnalysis.cpp | 7 ++++-- enzyme/Enzyme/EnzymeLogic.cpp | 2 ++ enzyme/Enzyme/FunctionUtils.cpp | 2 ++ enzyme/Enzyme/GradientUtils.cpp | 25 ++++++++++++++++--- enzyme/Enzyme/GradientUtils.h | 2 ++ enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.cpp | 3 +-- enzyme/Enzyme/TypeAnalysis/BaseType.h | 2 ++ enzyme/Enzyme/TypeAnalysis/RustDebugInfo.cpp | 5 ++-- enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp | 5 ++++ enzyme/Enzyme/TypeAnalysis/TypeTree.h | 2 ++ enzyme/Enzyme/Utils.h | 9 +++++-- enzyme/tools/enzyme-tblgen/blas-tblgen.cpp | 2 ++ enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp | 4 +-- 21 files changed, 84 insertions(+), 18 deletions(-) diff --git a/enzyme/CMakeLists.txt b/enzyme/CMakeLists.txt index 13a9b4973c54..32d09c19e302 100644 --- a/enzyme/CMakeLists.txt +++ b/enzyme/CMakeLists.txt @@ -14,7 +14,7 @@ add_definitions(-DENZYME_VERSION_MINOR=${ENZYME_MINOR_VERSION}) add_definitions(-DENZYME_VERSION_PATCH=${ENZYME_PATCH_VERSION}) set(CMAKE_POSITION_INDEPENDENT_CODE ON) -SET(CMAKE_CXX_FLAGS "-Wall -fno-rtti ${CMAKE_CXX_FLAGS} -Werror=unused-variable -Werror=dangling-else") +SET(CMAKE_CXX_FLAGS "-Wall -fno-rtti ${CMAKE_CXX_FLAGS} -Werror=unused-variable -Werror=dangling-else -Werror=unused-but-set-variable -Werror=return-type -Werror=nonnull") SET(CMAKE_CXX_FLAGS_RELWITHDEBINFO "-O2 -g -ggdb") SET(CMAKE_CXX_FLAGS_RELEASE "-O2") diff --git a/enzyme/Enzyme/ActivityAnalysis.cpp b/enzyme/Enzyme/ActivityAnalysis.cpp index 243e71dec04c..274c2fcea183 100644 --- a/enzyme/Enzyme/ActivityAnalysis.cpp +++ b/enzyme/Enzyme/ActivityAnalysis.cpp @@ -1049,9 +1049,11 @@ bool ActivityAnalyzer::isConstantValue(TypeResults const &TR, Value *Val) { } assert(TR.getFunction() == I->getParent()->getParent()); } +#ifndef NDEBUG if (auto Arg = dyn_cast(Val)) { assert(TR.getFunction() == Arg->getParent()); } +#endif // Void values are definitionally inactive if (Val->getType()->isVoidTy()) @@ -2305,6 +2307,9 @@ bool ActivityAnalyzer::isConstantValue(TypeResults const &TR, Value *Val) { // this value is inactive, we are inactive Since we won't look at uses to // prove, we can inductively assume this is inactive if (directions & UP) { + if (!UpHypothesis) + UpHypothesis = + std::shared_ptr(new ActivityAnalyzer(*this, UP)); if (directions == UP && !isa(Val)) { if (isInstructionInactiveFromOrigin(TR, Val, true)) { InsertConstantValue(TR, Val); @@ -2320,8 +2325,6 @@ bool ActivityAnalyzer::isConstantValue(TypeResults const &TR, Value *Val) { } } } else { - UpHypothesis = - std::shared_ptr(new ActivityAnalyzer(*this, UP)); UpHypothesis->ConstantValues.insert(Val); if (UpHypothesis->isInstructionInactiveFromOrigin(TR, Val, true)) { insertConstantsFrom(TR, *UpHypothesis); diff --git a/enzyme/Enzyme/AdjointGenerator.h b/enzyme/Enzyme/AdjointGenerator.h index 6a0af0d2b78b..dbcd0947d342 100644 --- a/enzyme/Enzyme/AdjointGenerator.h +++ b/enzyme/Enzyme/AdjointGenerator.h @@ -408,6 +408,7 @@ class AdjointGenerator : public llvm::InstVisitor { constantval |= gutils->isConstantValue(&I); Type *type = gutils->getShadowType(I.getType()); + (void)type; auto *newi = dyn_cast(gutils->getNewFromOriginal(&I)); @@ -621,6 +622,7 @@ class AdjointGenerator : public llvm::InstVisitor { if (primalNeededInReverse) { inst = gutils->cacheForReverse(BuilderZ, newi, getIndex(&I, CacheType::Self, BuilderZ)); + (void)inst; assert(inst->getType() == type); if (Mode == DerivativeMode::ReverseModeGradient || @@ -3777,6 +3779,7 @@ class AdjointGenerator : public llvm::InstVisitor { setDiffe(&I, Constant::getNullValue(gutils->getShadowType(I.getType())), Builder2); } + (void)vdiff; switch (ID) { @@ -5201,6 +5204,7 @@ class AdjointGenerator : public llvm::InstVisitor { // Note sometimes whattype mistakenly says something should be // constant [because composed of integer pointers alone] + (void)argType; assert(whatType(argType, Mode) == DIFFE_TYPE::DUP_ARG || whatType(argType, Mode) == DIFFE_TYPE::CONSTANT); } else { diff --git a/enzyme/Enzyme/CApi.cpp b/enzyme/Enzyme/CApi.cpp index 006965496f43..8c1c295b54f6 100644 --- a/enzyme/Enzyme/CApi.cpp +++ b/enzyme/Enzyme/CApi.cpp @@ -1184,6 +1184,7 @@ LLVMValueRef EnzymeCloneFunctionWithoutReturnOrArgs(LLVMValueRef FC, for (auto s : sub) { uint64_t ival; bool b = s.getAsInteger(10, ival); + (void)b; assert(!b); previdx.push_back(ival); } @@ -1241,6 +1242,7 @@ LLVMValueRef EnzymeComputeByteOffsetOfGEP(LLVMBuilderRef B_r, LLVMValueRef V_r, APInt Offset(width, 0); bool success = collectOffset(cast(gep), DL, width, VariableOffsets, Offset); + (void)success; assert(success); Value *start = ConstantInt::get(T, Offset); for (auto &pair : VariableOffsets) diff --git a/enzyme/Enzyme/CacheUtility.cpp b/enzyme/Enzyme/CacheUtility.cpp index 68c6e46784cf..cba31f314000 100644 --- a/enzyme/Enzyme/CacheUtility.cpp +++ b/enzyme/Enzyme/CacheUtility.cpp @@ -200,6 +200,7 @@ std::pair FindCanonicalIV(Loop *L, Type *Ty) { } llvm::errs() << *Header << "\n"; assert(0 && "Could not find canonical IV"); + return std::pair(nullptr, nullptr); } // Attempt to rewrite all phinode's in the loop in terms of the @@ -1330,8 +1331,10 @@ void CacheUtility::storeInstructionInCache(LimitContext ctx, IRBuilder<> &BuilderM, Value *val, AllocaInst *cache, MDNode *TBAA) { assert(BuilderM.GetInsertBlock()->getParent() == newFunc); +#ifndef NDEBUG if (auto inst = dyn_cast(val)) assert(inst->getParent()->getParent() == newFunc); +#endif IRBuilder<> v(BuilderM.GetInsertBlock()); v.SetInsertPoint(BuilderM.GetInsertBlock(), BuilderM.GetInsertPoint()); v.setFastMathFlags(getFast()); diff --git a/enzyme/Enzyme/CallDerivatives.cpp b/enzyme/Enzyme/CallDerivatives.cpp index 8c77a51f954d..6280dc9f1b44 100644 --- a/enzyme/Enzyme/CallDerivatives.cpp +++ b/enzyme/Enzyme/CallDerivatives.cpp @@ -4047,6 +4047,7 @@ bool AdjointGenerator::handleKnownCallDerivatives( return true; } assert(!unnecessaryValues.count(rmat.first)); + (void)primalNeededInReverse; assert(primalNeededInReverse); } } diff --git a/enzyme/Enzyme/Clang/EnzymeClang.cpp b/enzyme/Enzyme/Clang/EnzymeClang.cpp index ed01f1bf5739..a34a6429dcf7 100644 --- a/enzyme/Enzyme/Clang/EnzymeClang.cpp +++ b/enzyme/Enzyme/Clang/EnzymeClang.cpp @@ -141,10 +141,10 @@ class EnzymePlugin final : public clang::ASTConsumer { using namespace clang; DeclGroupRef::iterator it; - Visitor v(CI); + // Visitor v(CI); // Forcibly require emission of all libdevice for (it = dg.begin(); it != dg.end(); ++it) { - v.TraverseDecl(*it); + // v.TraverseDecl(*it); if (auto FD = dyn_cast(*it)) { if (!FD->hasAttr()) continue; diff --git a/enzyme/Enzyme/DiffeGradientUtils.cpp b/enzyme/Enzyme/DiffeGradientUtils.cpp index 43a14051c4fd..552d2894d759 100644 --- a/enzyme/Enzyme/DiffeGradientUtils.cpp +++ b/enzyme/Enzyme/DiffeGradientUtils.cpp @@ -164,10 +164,12 @@ DiffeGradientUtils *DiffeGradientUtils::CreateFromClone( AllocaInst *DiffeGradientUtils::getDifferential(Value *val) { assert(val); +#ifndef NDEBUG if (auto arg = dyn_cast(val)) assert(arg->getParent() == oldFunc); if (auto inst = dyn_cast(val)) assert(inst->getParent()->getParent() == oldFunc); +#endif assert(inversionAllocs); Type *type = getShadowType(val->getType()); @@ -195,10 +197,12 @@ AllocaInst *DiffeGradientUtils::getDifferential(Value *val) { } Value *DiffeGradientUtils::diffe(Value *val, IRBuilder<> &BuilderM) { +#ifndef NDEBUG if (auto arg = dyn_cast(val)) assert(arg->getParent() == oldFunc); if (auto inst = dyn_cast(val)) assert(inst->getParent()->getParent() == oldFunc); +#endif if (isConstantValue(val)) { llvm::errs() << *newFunc << "\n"; @@ -336,6 +340,7 @@ DiffeGradientUtils::addToDiffe(Value *val, Value *dif, IRBuilder<> &BuilderM, llvm::errs() << "} start=" << start << " size=" << size << " storeSize=" << storeSize << " val=" << *val << "\n"; assert(0 && "unhandled accumulate with partial sizes"); + return {}; } SmallVector @@ -345,10 +350,12 @@ DiffeGradientUtils::addToDiffe(Value *val, Value *dif, IRBuilder<> &BuilderM, assert(mode == DerivativeMode::ReverseModeGradient || mode == DerivativeMode::ReverseModeCombined); +#ifndef NDEBUG if (auto arg = dyn_cast(val)) assert(arg->getParent() == oldFunc); if (auto inst = dyn_cast(val)) assert(inst->getParent()->getParent() == oldFunc); +#endif SmallVector addedSelects; @@ -659,6 +666,7 @@ DiffeGradientUtils::addToDiffe(Value *val, Value *dif, IRBuilder<> &BuilderM, void DiffeGradientUtils::setDiffe(Value *val, Value *toset, IRBuilder<> &BuilderM) { +#ifndef NDEBUG if (auto arg = dyn_cast(val)) assert(arg->getParent() == oldFunc); if (auto inst = dyn_cast(val)) @@ -668,6 +676,7 @@ void DiffeGradientUtils::setDiffe(Value *val, Value *toset, llvm::errs() << *val << "\n"; } assert(!isConstantValue(val)); +#endif toset = SanitizeDerivatives(val, toset, BuilderM); if (mode == DerivativeMode::ForwardMode || mode == DerivativeMode::ForwardModeSplit) { diff --git a/enzyme/Enzyme/DifferentialUseAnalysis.cpp b/enzyme/Enzyme/DifferentialUseAnalysis.cpp index 8f2e27b3140e..b0b8c48ab5c9 100644 --- a/enzyme/Enzyme/DifferentialUseAnalysis.cpp +++ b/enzyme/Enzyme/DifferentialUseAnalysis.cpp @@ -54,9 +54,11 @@ bool DifferentialUseAnalysis::is_use_directly_needed_in_reverse( const SmallPtrSetImpl &oldUnreachable, QueryType qtype, bool *recursiveUse) { TypeResults const &TR = gutils->TR; +#ifndef NDEBUG if (auto ainst = dyn_cast(val)) { assert(ainst->getParent()->getParent() == gutils->oldFunc); } +#endif bool shadow = qtype == QueryType::Shadow || qtype == QueryType::ShadowByConstPrimal; @@ -79,8 +81,7 @@ bool DifferentialUseAnalysis::is_use_directly_needed_in_reverse( if (!user) { if (EnzymePrintDiffUse) - llvm::errs() << " Need: of " << *val << " in reverse as unknown user " - << *user << "\n"; + llvm::errs() << " Need: of " << *val << " in reverse as nullptr user\n"; return true; } @@ -794,12 +795,14 @@ void DifferentialUseAnalysis::minCut(const DataLayout &DL, LoopInfo &OrigLI, } } } +#ifndef NDEBUG for (auto R : Required) { assert(Intermediates.count(R)); } for (auto R : Recomputes) { assert(Intermediates.count(R)); } +#endif Graph Orig = G; diff --git a/enzyme/Enzyme/EnzymeLogic.cpp b/enzyme/Enzyme/EnzymeLogic.cpp index 120722dc7fb6..0767b3e52529 100644 --- a/enzyme/Enzyme/EnzymeLogic.cpp +++ b/enzyme/Enzyme/EnzymeLogic.cpp @@ -3632,6 +3632,7 @@ Function *EnzymeLogic::CreatePrimalAndGradient( if (hasMetadata(key.todiff, "enzyme_gradient")) { std::set seen; +#ifndef NDEBUG DIFFE_TYPE subretType = whatType(key.todiff->getReturnType(), DerivativeMode::ReverseModeGradient, /*intAreConstant*/ false, seen); @@ -3639,6 +3640,7 @@ Function *EnzymeLogic::CreatePrimalAndGradient( key.todiff->getReturnType()->isEmptyTy()) subretType = DIFFE_TYPE::CONSTANT; assert(subretType == key.retType); +#endif if (key.mode == DerivativeMode::ReverseModeCombined) { auto res = getDefaultFunctionTypeForGradient( diff --git a/enzyme/Enzyme/FunctionUtils.cpp b/enzyme/Enzyme/FunctionUtils.cpp index fdba018bd600..218ca9549dfc 100644 --- a/enzyme/Enzyme/FunctionUtils.cpp +++ b/enzyme/Enzyme/FunctionUtils.cpp @@ -6295,6 +6295,7 @@ class Constraints : public std::enable_shared_from_this { assert(t != Type::None); assert(c.size() != 0); assert(c.size() != 1); +#ifndef NDEBUG SmallVector tmp(c.begin(), c.end()); for (unsigned i = 0; i < tmp.size(); i++) for (unsigned j = 0; j < i; j++) @@ -6317,6 +6318,7 @@ class Constraints : public std::enable_shared_from_this { if (auto s = dyn_cast(tmp[j]->node)) assert(s->getLoop() != tmp[i]->Loop); } +#endif } bool operator==(const Constraints &rhs) const { diff --git a/enzyme/Enzyme/GradientUtils.cpp b/enzyme/Enzyme/GradientUtils.cpp index 9025b890a2f3..57f21d8d9785 100644 --- a/enzyme/Enzyme/GradientUtils.cpp +++ b/enzyme/Enzyme/GradientUtils.cpp @@ -607,12 +607,14 @@ BasicBlock *GradientUtils::getOriginalFromNew(const BasicBlock *newinst) const { Value *GradientUtils::isOriginal(const Value *newinst) const { if (isa(newinst) || isa(newinst)) return const_cast(newinst); +#ifndef NDEBUG if (auto arg = dyn_cast(newinst)) { assert(arg->getParent() == newFunc); } if (auto inst = dyn_cast(newinst)) { assert(inst->getParent()->getParent() == newFunc); } +#endif auto found = newToOriginalFn.find(newinst); if (found == newToOriginalFn.end()) return nullptr; @@ -2519,11 +2521,13 @@ Value *GradientUtils::cacheForReverse(IRBuilder<> &BuilderQ, Value *malloc, return malloc; } +#ifndef NDEBUG if (auto CI = dyn_cast(malloc)) { if (auto F = CI->getCalledFunction()) { assert(F->getName() != "omp_get_thread_num"); } } +#endif if (malloc->getType()->isTokenTy()) { llvm::errs() << " oldFunc: " << *oldFunc << "\n"; @@ -3425,8 +3429,9 @@ BasicBlock *GradientUtils::prepRematerializedLoopEntry(LoopContext &lc) { lctx, placeholder->getType(), placeholder->getName(), /*shouldFree*/ true); assert(cache); + Value *placeholder_tmp = placeholder; found = insert_or_assign( - scopeMap, (Value *&)placeholder, + scopeMap, placeholder_tmp, std::pair, LimitContext>(cache, lctx)); } auto cache = found->second.first; @@ -4754,12 +4759,14 @@ void GradientUtils::setPtrDiffe(Instruction *orig, Value *ptr, Value *newval, SyncScope::ID syncScope, Value *mask, ArrayRef noAlias, ArrayRef scopes) { +#ifndef NDEBUG if (auto inst = dyn_cast(ptr)) { assert(inst->getParent()->getParent() == oldFunc); } if (auto arg = dyn_cast(ptr)) { assert(arg->getParent() == oldFunc); } +#endif Value *origptr = ptr; @@ -5034,12 +5041,14 @@ llvm::Value *GradientUtils::recursiveFAdd(llvm::IRBuilder<> &B, Value *GradientUtils::invertPointerM(Value *const oval, IRBuilder<> &BuilderM, bool nullShadow) { assert(oval); +#ifndef NDEBUG if (auto inst = dyn_cast(oval)) { assert(inst->getParent()->getParent() == oldFunc); } if (auto arg = dyn_cast(oval)) { assert(arg->getParent() == oldFunc); } +#endif if (isa(oval)) { return applyChainRule(oval->getType(), BuilderM, [&]() { return oval; }); @@ -6830,7 +6839,7 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM, Instruction *origTerm = origPH->getTerminator(); if (!origTerm) - llvm::errs() << *origTerm << "\n"; + llvm::errs() << *origPH << "\n"; assert(origTerm); IRBuilder<> OB(origTerm); LoadInst *tmpload = OB.CreateLoad(AT, orig_liobj, "'tmpload"); @@ -7031,6 +7040,7 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM, unwrapM(a, v, available, UnwrapMode::AttemptSingleUnwrap, /*scope*/ nullptr, /*cache*/ false)); assert(uw->getType() == a->getType()); +#ifndef NDEBUG for (size_t i = 0; i < uw->getNumOperands(); i++) { auto op = uw->getOperand(i); if (auto arg = dyn_cast(op)) @@ -7038,6 +7048,7 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM, else if (auto inst = dyn_cast(op)) assert(inst->getParent()->getParent() == newFunc); } +#endif available[a] = uw; unwrappedLoads.erase(cast(uw)); } @@ -7259,7 +7270,8 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM, AllocaInst *cache = createCacheForScope( lctx, inst->getType(), inst->getName(), /*shouldFree*/ true); assert(cache); - insert_or_assign(scopeMap, (Value *&)inst, + Value *inst_tmp = inst; + insert_or_assign(scopeMap, inst_tmp, std::pair, LimitContext>( cache, lctx)); } @@ -7333,9 +7345,11 @@ void GradientUtils::branchToCorrespondingTarget( if (replacePHIs->size() == 0) return; +#ifndef NDEBUG for (auto x : *replacePHIs) { assert(targetToPreds.find(x.first) != targetToPreds.end()); } +#endif } if (targetToPreds.size() == 1) { @@ -8177,11 +8191,13 @@ void GradientUtils::forceActiveDetection() { bool GradientUtils::isConstantValue(Value *val) const { if (auto inst = dyn_cast(val)) { + (void)inst; assert(inst->getParent()->getParent() == oldFunc); return ATA->isConstantValue(TR, val); } if (auto arg = dyn_cast(val)) { + (void)arg; assert(arg->getParent() == oldFunc); return ATA->isConstantValue(TR, val); } @@ -8892,6 +8908,7 @@ void GradientUtils::replaceAWithB(Value *A, Value *B, bool storeInCache) { // Check that the replacement doesn't already exist in the mapping // thereby resulting in a conflict. +#ifndef NDEBUG { auto found = newToOriginalFn.find(A); if (found != newToOriginalFn.end()) { @@ -8899,6 +8916,7 @@ void GradientUtils::replaceAWithB(Value *A, Value *B, bool storeInCache) { assert(foundB == newToOriginalFn.end()); } } +#endif CacheUtility::replaceAWithB(A, B, storeInCache); } @@ -9167,6 +9185,7 @@ llvm::CallInst *freeKnownAllocation(llvm::IRBuilder<> &builder, libfunc = LibFunc_malloc; } else { bool res = TLI.getLibFunc(allocationfn, libfunc); + (void)res; assert(res && "ought find known allocation fn"); } diff --git a/enzyme/Enzyme/GradientUtils.h b/enzyme/Enzyme/GradientUtils.h index 7552bec38dab..98cae4145639 100644 --- a/enzyme/Enzyme/GradientUtils.h +++ b/enzyme/Enzyme/GradientUtils.h @@ -601,11 +601,13 @@ class GradientUtils : public CacheUtility { llvm::ArrayRef diffs, llvm::IRBuilder<> &Builder, Func rule) { if (width > 1) { +#ifndef NDEBUG for (auto diff : diffs) { assert(diff); assert(llvm::cast(diff->getType())->getNumElements() == width); } +#endif llvm::Type *wrappedType = llvm::ArrayType::get(diffType, width); llvm::Value *res = llvm::UndefValue::get(wrappedType); for (unsigned int i = 0; i < getWidth(); ++i) { diff --git a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.cpp b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.cpp index 3025b5963fcf..7d4ccead7764 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.cpp +++ b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.cpp @@ -199,8 +199,7 @@ FunctionOpInterface mlir::enzyme::MEnzymeLogic::CreateForwardDiff( continue; } - auto term = oBB.getTerminator(); - assert(term); + assert(oBB.getTerminator()); auto first = oBB.begin(); auto last = oBB.empty() ? oBB.end() : std::prev(oBB.end()); diff --git a/enzyme/Enzyme/TypeAnalysis/BaseType.h b/enzyme/Enzyme/TypeAnalysis/BaseType.h index 2f5275c4b6f8..71d6e0910408 100644 --- a/enzyme/Enzyme/TypeAnalysis/BaseType.h +++ b/enzyme/Enzyme/TypeAnalysis/BaseType.h @@ -57,6 +57,7 @@ static inline std::string to_string(BaseType t) { return "Unknown"; } assert(0 && "unknown inttype"); + return ""; } /// Convert string to BaseType @@ -72,5 +73,6 @@ static inline BaseType parseBaseType(llvm::StringRef str) { if (str == "Unknown") return BaseType::Unknown; assert(0 && "Unknown BaseType string"); + return BaseType::Unknown; } #endif diff --git a/enzyme/Enzyme/TypeAnalysis/RustDebugInfo.cpp b/enzyme/Enzyme/TypeAnalysis/RustDebugInfo.cpp index 899ce46d9f9e..2376c9b23353 100644 --- a/enzyme/Enzyme/TypeAnalysis/RustDebugInfo.cpp +++ b/enzyme/Enzyme/TypeAnalysis/RustDebugInfo.cpp @@ -83,7 +83,6 @@ TypeTree parseDIType(DICompositeType &Type, Instruction &I, DataLayout &DL) { assert(0 && "There shouldn't be non-constant-size arrays in Rust"); } } - return Result; } else if (Type.getTag() == dwarf::DW_TAG_structure_type || Type.getTag() == dwarf::DW_TAG_union_type) { DINodeArray Elements = Type.getElements(); @@ -108,11 +107,11 @@ TypeTree parseDIType(DICompositeType &Type, Instruction &I, DataLayout &DL) { firstSubTT = !firstSubTT; } } - return Result; } else { assert(0 && "Composite types other than arrays, structs and unions are not " "supported by Rust debug info parser"); } + return Result; } TypeTree parseDIType(DIDerivedType &Type, Instruction &I, DataLayout &DL) { @@ -134,6 +133,7 @@ TypeTree parseDIType(DIDerivedType &Type, Instruction &I, DataLayout &DL) { assert(0 && "Derived types other than pointers and members are not " "supported by Rust debug info parser"); } + return {}; } TypeTree parseDIType(DIType &Type, Instruction &I, DataLayout &DL) { @@ -151,6 +151,7 @@ TypeTree parseDIType(DIType &Type, Instruction &I, DataLayout &DL) { assert(0 && "Types other than floating-points, integers, arrays, pointers, " "slices, and structs are not supported by debug info parser"); } + return {}; } bool isU8PointerType(DIType &type) { diff --git a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp index 3204b5667086..86a5243c438a 100644 --- a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp +++ b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp @@ -1860,6 +1860,7 @@ void TypeAnalyzer::visitGEPOperator(GEPOperator &gep) { MapVector VariableOffsets; bool legalOffset = collectOffset(&gep, DL, BitWidth, VariableOffsets, constOffset); + (void)legalOffset; assert(legalOffset); SmallVector, 4> idnext; @@ -5816,12 +5817,14 @@ FnTypeInfo TypeResults::getCallInfo(CallBase &CI, Function &fn) const { } TypeTree TypeResults::query(Value *val) const { +#ifndef NDEBUG if (auto inst = dyn_cast(val)) { assert(inst->getParent()->getParent() == analyzer->fntypeinfo.Function); } if (auto arg = dyn_cast(val)) { assert(arg->getParent() == analyzer->fntypeinfo.Function); } +#endif return analyzer->getAnalysis(val); } @@ -5989,8 +5992,10 @@ ConcreteType TypeResults::firstPointer(size_t num, Value *val, Instruction *I, if (auto arg = dyn_cast(val)) { llvm::errs() << *arg->getParent() << "\n"; for (auto &pair : res.analysis) { +#ifndef NDEBUG if (auto in = dyn_cast(pair.first)) assert(in->getParent()->getParent() == arg->getParent()); +#endif llvm::errs() << "val: " << *pair.first << " - " << pair.second.str() << " int: " + to_string(res.knownIntegralValues(pair.first)) diff --git a/enzyme/Enzyme/TypeAnalysis/TypeTree.h b/enzyme/Enzyme/TypeAnalysis/TypeTree.h index b69b83e5b355..3843623ca32a 100644 --- a/enzyme/Enzyme/TypeAnalysis/TypeTree.h +++ b/enzyme/Enzyme/TypeAnalysis/TypeTree.h @@ -343,11 +343,13 @@ class TypeTree : public std::enable_shared_from_this { /// Whether this TypeTree contains any information bool isKnown() const { +#ifndef NDEBUG for (const auto &pair : mapping) { // we should assert here as we shouldn't keep any unknown maps for // efficiency assert(pair.second.isKnown()); } +#endif return mapping.size() != 0; } diff --git a/enzyme/Enzyme/Utils.h b/enzyme/Enzyme/Utils.h index f4f72f8cf52e..37edfc1a985d 100644 --- a/enzyme/Enzyme/Utils.h +++ b/enzyme/Enzyme/Utils.h @@ -1158,6 +1158,7 @@ static inline llvm::Optional getAllocationIndexFromCall(T *op) bool b = AttrList.getAttribute("enzyme_allocator") .getValueAsString() .getAsInteger(10, res); + (void)b; assert(!b); #if LLVM_VERSION_MAJOR >= 16 return std::optional(res); @@ -1172,6 +1173,7 @@ static inline llvm::Optional getAllocationIndexFromCall(T *op) bool b = called->getFnAttribute("enzyme_allocator") .getValueAsString() .getAsInteger(10, res); + (void)b; assert(!b); #if LLVM_VERSION_MAJOR >= 16 return std::optional(res); @@ -1228,6 +1230,7 @@ static inline std::vector getDeallocationIndicesFromCall(T *op) { for (auto ind : inds) { ssize_t Result; bool b = ind.getAsInteger(10, Result); + (void)b; assert(!b); vinds.push_back(Result); } @@ -1355,10 +1358,11 @@ static inline llvm::Value *getBaseObject(llvm::Value *V) { auto AttrList = Call->getAttributes().getAttributes( llvm::AttributeList::FunctionIndex); if (AttrList.hasAttribute("enzyme_pointermath")) { - size_t res; + size_t res = 0; bool failed = AttrList.getAttribute("enzyme_pointermath") .getValueAsString() .getAsInteger(10, res); + (void)failed; assert(!failed); V = Call->getArgOperand(res); continue; @@ -1386,10 +1390,11 @@ static inline llvm::Value *getBaseObject(llvm::Value *V) { auto AttrList = fn->getAttributes().getAttributes( llvm::AttributeList::FunctionIndex); if (AttrList.hasAttribute("enzyme_pointermath")) { - size_t res; + size_t res = 0; bool failed = AttrList.getAttribute("enzyme_pointermath") .getValueAsString() .getAsInteger(10, res); + (void)failed; assert(!failed); V = Call->getArgOperand(res); continue; diff --git a/enzyme/tools/enzyme-tblgen/blas-tblgen.cpp b/enzyme/tools/enzyme-tblgen/blas-tblgen.cpp index 39cb55384ad9..3689c77d4d74 100644 --- a/enzyme/tools/enzyme-tblgen/blas-tblgen.cpp +++ b/enzyme/tools/enzyme-tblgen/blas-tblgen.cpp @@ -444,6 +444,7 @@ void emit_helper(const TGPattern &pattern, raw_ostream &os) { break; } } + (void)hasInt; assert(hasInt); os << " Type* cublas_retty = nullptr;\n" @@ -482,6 +483,7 @@ void emit_scalar_types(const TGPattern &pattern, raw_ostream &os) { break; } } + (void)foundInt; assert(foundInt && "no int type found in blas call"); os << " // fpType already given by blas type (s, d, c, z) \n" diff --git a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp index 84dce5f2a08c..04f63e1ad3ac 100644 --- a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp +++ b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp @@ -1340,13 +1340,13 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os, StringRef name = cast(lst->getValues()[0])->getValue(); if (lst->size() >= 2) { auto min = cast(lst->getValues()[1])->getValue(); - int min_int; + int min_int = 100000; min.getAsInteger(10, min_int); if (min.size() != 0 && LLVM_VERSION_MAJOR < min_int) continue; if (lst->size() >= 3) { auto max = cast(lst->getValues()[2])->getValue(); - int max_int; + int max_int = 0; max.getAsInteger(10, max_int); if (max.size() != 0 && LLVM_VERSION_MAJOR > max_int) continue; From d04f10346fc75a58eb96604db62de53d07fab6fe Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 19 Feb 2024 20:14:07 -0500 Subject: [PATCH 080/106] Add ResultTypes tablegen --- enzyme/Enzyme/MLIR/Implementations/Common.td | 2 ++ 1 file changed, 2 insertions(+) diff --git a/enzyme/Enzyme/MLIR/Implementations/Common.td b/enzyme/Enzyme/MLIR/Implementations/Common.td index 83b89f7cf0a4..33d9f12c2f37 100644 --- a/enzyme/Enzyme/MLIR/Implementations/Common.td +++ b/enzyme/Enzyme/MLIR/Implementations/Common.td @@ -87,6 +87,8 @@ class ConstantFP : Ope string type = type_; } +def ResultTypes : GlobalExprgetResultTypes()">; + class ArithInst : Inst; class MathInst : Inst; From 1ea2e0be2db660ee54bde212194f7446b5479e8d Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 20 Feb 2024 09:31:20 -0500 Subject: [PATCH 081/106] TypeTree speedup shiftindicies (#1744) --- enzyme/Enzyme/TypeAnalysis/TypeTree.h | 173 ++++++++++++++++++++++---- 1 file changed, 152 insertions(+), 21 deletions(-) diff --git a/enzyme/Enzyme/TypeAnalysis/TypeTree.h b/enzyme/Enzyme/TypeAnalysis/TypeTree.h index 3843623ca32a..2cb4a7506c94 100644 --- a/enzyme/Enzyme/TypeAnalysis/TypeTree.h +++ b/enzyme/Enzyme/TypeAnalysis/TypeTree.h @@ -724,17 +724,51 @@ class TypeTree : public std::enable_shared_from_this { } /// Replace mappings in the range in [offset, offset+maxSize] with those in - // [addOffset, addOffset + maxSize]. In other worse, select all mappings in + // [addOffset, addOffset + maxSize]. In other words, select all mappings in // [offset, offset+maxSize] then add `addOffset` TypeTree ShiftIndices(const llvm::DataLayout &dl, const int offset, const int maxSize, size_t addOffset = 0) const { + + // If we have no terms 1+ layer deep return the current result as a shift + // won't change anything. This also makes the latercode simpler as it + // can assume at least a first index exists. + if (minIndices.size() == 0) + return *this; + + // If we have no size in return, simply return an empty type tree. Again + // this simplifies later code which can assume that a minus one expantion + // will always result in an added variable (which would not be the case + // on a size == 0). + if (maxSize == 0) + return TypeTree(); + TypeTree Result; + // The normal orIn / insert methods do collision checking, which is slow + // (and presently O(n)). This is because an expansion of a -1 which could + // conflict with a fixed value. Consider calling this + // ShiftIndicies(offset=0, maxSize=2, addOffset=0, tt={[-1]:Integer, + // [1]:Anything}) the -1 would expand to [0]:Int, [1]:Int, which would need + // to be merged with [1]:Anything + // + // The only possible values which can cause a conflict are minus -1's. + // As a result, we start with a fast insertion (aka without check) of + // non-expanded values, since they just do a literal shift which needs no + // extra checking, besides bounds checks. + // + // Since we're doing things manually, we also need to manually preserve TT + // invariants. Specifically, TT limits all values to have offsets < + // MAX_OFFSET, unless it is the smallest offset at that depth. (e.g. so we + // can still hava typetree {[123456]:Int}, even if limit is 100). + // + // First compute the minimum 0th index to be kept. + Result.minIndices.resize(minIndices.size(), INT_MAX); + for (const auto &pair : mapping) { if (pair.first.size() == 0) { if (pair.second == BaseType::Pointer || pair.second == BaseType::Anything) { - Result.insert(pair.first, pair.second); + Result.mapping.emplace(pair.first, pair.second); continue; } @@ -743,55 +777,152 @@ class TypeTree : public std::enable_shared_from_this { llvm_unreachable("ShiftIndices called on a nonpointer/anything"); } - std::vector next(pair.first); + int next0 = pair.first[0]; + + if (next0 == -1) { + if (maxSize == -1) { + // Max size does not clip the next index + + // If we have a follow up offset add, we lose the -1 since we only + // represent [0, inf) with -1 not the [addOffset, inf) required here + if (addOffset != 0) { + next0 = addOffset; + } + + } else { + // We're going to insert addOffset + 0...maxSize so the new minIndex + // is addOffset + Result.minIndices[0] = addOffset; + for (size_t i = 1, sz = pair.first.size(); i < sz; i++) + if (pair.first[i] < Result.minIndices[i]) + Result.minIndices[i] = pair.first[i]; + continue; + } + } else { + // Too small for range + if (next0 < offset) { + continue; + } + next0 -= offset; + + if (maxSize != -1) { + if (next0 >= maxSize) + continue; + } + + next0 += addOffset; + } + if (next0 < Result.minIndices[0]) + Result.minIndices[0] = next0; + for (size_t i = 1, sz = pair.first.size(); i < sz; i++) + if (pair.first[i] < Result.minIndices[i]) + Result.minIndices[i] = pair.first[i]; + } + + // Max depth of actual inserted values + size_t maxInsertedDepth = 0; + + // Insert all + for (const auto &pair : mapping) { + if (pair.first.size() == 0) + continue; - if (next[0] == -1) { + int next0 = pair.first[0]; + + if (next0 == -1) { if (maxSize == -1) { // Max size does not clip the next index // If we have a follow up offset add, we lose the -1 since we only // represent [0, inf) with -1 not the [addOffset, inf) required here if (addOffset != 0) { - next[0] = addOffset; + next0 = addOffset; } } else { - // This needs to become 0...maxSize as seen below + // This needs to become 0...maxSize handled separately as it is the + // only insertion that could have collisions + continue; } } else { // Too small for range - if (next[0] < offset) { + if (next0 < offset) { continue; } - next[0] -= offset; + next0 -= offset; if (maxSize != -1) { - if (next[0] >= maxSize) + if (next0 >= maxSize) continue; } - next[0] += addOffset; + next0 += addOffset; } - size_t chunk = 1; - auto op = operator[]({pair.first[0]}); - if (auto flt = op.isFloat()) { - chunk = dl.getTypeSizeInBits(flt) / 8; - } else if (op == BaseType::Pointer) { - chunk = dl.getPointerSizeInBits() / 8; + // If after moving this would not merit being kept for being a min index + // or being within the max type offset, skip it. + if (next0 > MaxTypeOffset) { + bool minIndex = next0 == Result.minIndices[0]; + if (!minIndex) + for (size_t i = 1; i < pair.first.size(); i++) { + if (pair.first[i] == Result.minIndices[i]) { + minIndex = true; + break; + } + } + if (!minIndex) + continue; } - if (next[0] == -1 && maxSize != -1) { + std::vector next(pair.first); + next[0] = next0; + Result.mapping.emplace(next, pair.second); + if (next.size() > maxInsertedDepth) + maxInsertedDepth = next.size(); + } + + // Insert and expand the minus one, if needed + if (maxSize != -1) + for (const auto &pair : mapping) { + if (pair.first.size() == 0) + continue; + if (pair.first[0] != -1) + continue; + + size_t chunk = 1; + std::vector next(pair.first); + auto op = operator[]({next[0]}); + if (auto flt = op.isFloat()) { + chunk = dl.getTypeSizeInBits(flt) / 8; + } else if (op == BaseType::Pointer) { + chunk = dl.getPointerSizeInBits() / 8; + } auto offincr = (chunk - offset % chunk) % chunk; + bool inserted = false; for (int i = offincr; i < maxSize; i += chunk) { next[0] = i + addOffset; - Result.orIn(next, pair.second); + ConcreteType prev(pair.second); + // We can use faster checks here, since we know there can be no + // -1's that we would conflict with, only conflicts from previous + // fixed value insertions. + auto found = Result.mapping.find(next); + if (found != Result.mapping.end()) { + // orIn returns if changed, update the value in the map if so + // with the new value. + if (prev.orIn(found->second, /*pointerIntSame*/ false)) + found->second = prev; + } else { + Result.mapping.emplace(next, pair.second); + } + inserted = true; } - } else { - Result.orIn(next, pair.second); + if (inserted && next.size() > maxInsertedDepth) + maxInsertedDepth = next.size(); } - } + // Resize minIndices down if we dropped any higher-depth indices for being + // out of scope. + Result.minIndices.resize(maxInsertedDepth); return Result; } From e8ca2b1de3b770c767145d027b357bed97178bb0 Mon Sep 17 00:00:00 2001 From: "Oleksandr \"Alex\" Zinenko" Date: Tue, 20 Feb 2024 16:03:49 +0100 Subject: [PATCH 082/106] Split tests into two packages in Bazel build. (#1741) This separates MLIR tests and regular Enzyme integration tests into independent packages, making sure only a subset of targets is rebuilt and retested when requested to decrease overhead. --- enzyme/BUILD | 53 ++--------------------------------- enzyme/test/BUILD | 32 +++++++++++++++++++++ enzyme/test/Integration/BUILD | 29 +++++++++++++++++++ enzyme/test/MLIR/BUILD | 25 +++++++++++++++++ 4 files changed, 88 insertions(+), 51 deletions(-) create mode 100644 enzyme/test/BUILD create mode 100644 enzyme/test/Integration/BUILD create mode 100644 enzyme/test/MLIR/BUILD diff --git a/enzyme/BUILD b/enzyme/BUILD index 7c78c75a4051..6f4459f0ebd2 100644 --- a/enzyme/BUILD +++ b/enzyme/BUILD @@ -1,7 +1,5 @@ -load("@llvm-project//llvm:lit_test.bzl", "lit_test", "package_path") load("@llvm-project//llvm:tblgen.bzl", "gentbl") load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") -load("@bazel_skylib//rules:expand_template.bzl", "expand_template") licenses(["notice"]) @@ -640,52 +638,5 @@ cc_binary( ], ) -# Generates lit config input file by applying path placeholder substitutions -# similar to the configure_lit_site_cfg CMake macro. -expand_template( - name = "lit_site_cfg_py", - testonly = True, - out = "test/lit.site.cfg.py", - substitutions = { - "@LLVM_VERSION_MAJOR@": "18", - "@LIT_SITE_CFG_IN_HEADER@": "# Autogenerated, do not edit.", - "@LLVM_BINARY_DIR@": package_path("@llvm-project//llvm:BUILD"), - "@LLVM_TOOLS_BINARY_DIR@": package_path("@llvm-project//llvm:BUILD"), - "@LLVM_LIBS_DIR@": package_path("@llvm-project//llvm:BUILD"), - "@ENZYME_SOURCE_DIR@": "", - "@ENZYME_BINARY_DIR@": "", - "@TARGET_TRIPLE@": "", - "@TARGETS_TO_BUILD@": "ALL", - "@LLVM_SHLIBEXT@": ".so", - }, - template = "test/lit.site.cfg.py.in", - visibility = ["//visibility:private"], -) - -[ - lit_test( - name = "%s.test" % src, - srcs = [src], - data = [ - ":enzyme-clang", - ":enzyme-clang++", - ":enzyme-opt", - ":enzymemlir-opt", - ":test/lit.cfg.py", - ":test/lit.site.cfg.py", - "@llvm-project//clang:builtin_headers_gen", - "@llvm-project//llvm:FileCheck", - "@llvm-project//llvm:count", - "@llvm-project//llvm:lli", - "@llvm-project//llvm:not", - ] + glob(["test/**/*.h"]), - ) - for src in glob( - [ - "test/**/*.mlir", - "test/Integration/**/*.c", - "test/Integration/**/.cpp", - ], - exclude = ["test/**/*omp*.c"], - ) -] +exports_files(["run_lit.sh"]) + diff --git a/enzyme/test/BUILD b/enzyme/test/BUILD new file mode 100644 index 000000000000..47143966810d --- /dev/null +++ b/enzyme/test/BUILD @@ -0,0 +1,32 @@ +# Enzyme tests. + +load("@llvm-project//llvm:lit_test.bzl", "package_path") +load("@bazel_skylib//rules:expand_template.bzl", "expand_template") + +# Generates lit config input file by applying path placeholder substitutions +# similar to the configure_lit_site_cfg CMake macro. +expand_template( + name = "lit_site_cfg_py", + testonly = True, + out = "lit.site.cfg.py", + substitutions = { + "@LLVM_VERSION_MAJOR@": "18", + "@LIT_SITE_CFG_IN_HEADER@": "# Autogenerated, do not edit.", + "@LLVM_BINARY_DIR@": package_path("@llvm-project//llvm:BUILD"), + "@LLVM_TOOLS_BINARY_DIR@": package_path("@llvm-project//llvm:BUILD"), + "@LLVM_LIBS_DIR@": package_path("@llvm-project//llvm:BUILD"), + "@ENZYME_SOURCE_DIR@": "", + "@ENZYME_BINARY_DIR@": "", + "@TARGET_TRIPLE@": "", + "@TARGETS_TO_BUILD@": "ALL", + "@LLVM_SHLIBEXT@": ".so", + }, + template = "lit.site.cfg.py.in", + visibility = [":__subpackages__"], +) + +exports_files( + ["lit.cfg.py"], + visibility = [":__subpackages__"], +) + diff --git a/enzyme/test/Integration/BUILD b/enzyme/test/Integration/BUILD new file mode 100644 index 000000000000..ab14f8fca82e --- /dev/null +++ b/enzyme/test/Integration/BUILD @@ -0,0 +1,29 @@ +# Enzyme integration tests. + +load("@llvm-project//llvm:lit_test.bzl", "lit_test") + +[ + lit_test( + name = "%s.test" % src, + srcs = [src], + data = [ + "//:enzyme-clang", + "//:enzyme-clang++", + "//:enzyme-opt", + "//test:lit.cfg.py", + "//test:lit.site.cfg.py", + "@llvm-project//clang:builtin_headers_gen", + "@llvm-project//llvm:FileCheck", + "@llvm-project//llvm:count", + "@llvm-project//llvm:lli", + "@llvm-project//llvm:not", + ] + glob(["**/*.h"]), + ) + for src in glob( + [ + "**/*.c", + "**/.cpp", + ], + exclude = ["**/*omp*.c"], + ) +] diff --git a/enzyme/test/MLIR/BUILD b/enzyme/test/MLIR/BUILD new file mode 100644 index 000000000000..0af1496b3a5c --- /dev/null +++ b/enzyme/test/MLIR/BUILD @@ -0,0 +1,25 @@ +# MLIR-specific tests for Enzyme. + +load("@llvm-project//llvm:lit_test.bzl", "lit_test") + +[ + lit_test( + name = "%s.test" % src, + srcs = [src], + data = [ + "//:enzymemlir-opt", + "//test:lit.cfg.py", + "//test:lit.site.cfg.py", + "@llvm-project//clang:builtin_headers_gen", + "@llvm-project//llvm:FileCheck", + "@llvm-project//llvm:count", + "@llvm-project//llvm:lli", + "@llvm-project//llvm:not", + ] + glob(["**/*.h"]), + ) + for src in glob( + [ + "**/*.mlir", + ], + ) +] From f600e2e6dfdb4312a36e7f9043f2d4231b4f1795 Mon Sep 17 00:00:00 2001 From: "Oleksandr \"Alex\" Zinenko" Date: Tue, 20 Feb 2024 17:20:02 +0100 Subject: [PATCH 083/106] silence warnings in bazel (#1745) * Split tests into two packages in Bazel build. This separates MLIR tests and regular Enzyme integration tests into independent packages, making sure only a subset of targets is rebuilt and retested when requested to decrease overhead. * silence warnings in bazel There are a bunch of these in the codebase --- enzyme/BUILD | 2 ++ 1 file changed, 2 insertions(+) diff --git a/enzyme/BUILD b/enzyme/BUILD index 6f4459f0ebd2..e582b435d387 100644 --- a/enzyme/BUILD +++ b/enzyme/BUILD @@ -161,6 +161,8 @@ cc_library( "-DENZYME_VERSION_MAJOR=0", "-DENZYME_VERSION_MINOR=0", "-DENZYME_VERSION_PATCH=79", + "-Wno-unused-variable", + "-Wno-return-type", ], data = ["@llvm-project//clang:builtin_headers_gen"], visibility = ["//visibility:public"], From 1beb98b51442d50652eaa3ffb9574f4720d611f1 Mon Sep 17 00:00:00 2001 From: Romain Moyard Date: Tue, 20 Feb 2024 20:42:44 -0500 Subject: [PATCH 084/106] [ActivityAnalysis] Remove isConstantValue call in activity analysis (#1608) * Remove cop2 * Add integration test * Add test * Update enzyme/test/ActivityAnalysis/integration.ll * Update test * Update test * Add missing return * Format and test * Back gt * Test fix trial * Lower test llvm to 15 * Update enzyme/test/ActivityAnalysis/integration.ll * Update activity printer for opaque type * update -> * Update activity analysis * Update if/elif llvm * Update llvm versioning --- enzyme/Enzyme/ActivityAnalysis.cpp | 8 +- enzyme/Enzyme/ActivityAnalysisPrinter.cpp | 38 +++++---- enzyme/test/ActivityAnalysis/integration.ll | 93 +++++++++++++++++++++ 3 files changed, 121 insertions(+), 18 deletions(-) create mode 100644 enzyme/test/ActivityAnalysis/integration.ll diff --git a/enzyme/Enzyme/ActivityAnalysis.cpp b/enzyme/Enzyme/ActivityAnalysis.cpp index 274c2fcea183..0f8ee3c6161d 100644 --- a/enzyme/Enzyme/ActivityAnalysis.cpp +++ b/enzyme/Enzyme/ActivityAnalysis.cpp @@ -2073,14 +2073,14 @@ bool ActivityAnalyzer::isConstantValue(TypeResults const &TR, Value *Val) { << "\n"; if (auto SI = dyn_cast(I)) { bool cop = !Hypothesis->isConstantValue(TR, SI->getValueOperand()); - bool cop2 = !Hypothesis->isConstantValue(TR, SI->getPointerOperand()); + // bool cop2 = !Hypothesis->isConstantValue(TR, + // SI->getPointerOperand()); if (EnzymePrintActivity) - llvm::errs() << " -- store potential activity: " << (int)cop << "," - << (int)cop2 << "," + llvm::errs() << " -- store potential activity: " << (int)cop << " - " << *SI << " of " << " Val=" << *Val << "\n"; potentialStore = I; - if (cop && cop2) + if (cop) // && cop2) potentiallyActiveStore = SI; } else if (auto MTI = dyn_cast(I)) { bool cop = !Hypothesis->isConstantValue(TR, MTI->getArgOperand(1)); diff --git a/enzyme/Enzyme/ActivityAnalysisPrinter.cpp b/enzyme/Enzyme/ActivityAnalysisPrinter.cpp index a69d7c9e56e8..bfca7860fc45 100644 --- a/enzyme/Enzyme/ActivityAnalysisPrinter.cpp +++ b/enzyme/Enzyme/ActivityAnalysisPrinter.cpp @@ -89,14 +89,19 @@ bool printActivityAnalysis(llvm::Function &F, TargetLibraryInfo &TLI) { if (a.getType()->isFPOrFPVectorTy()) { dt = ConcreteType(a.getType()->getScalarType()); } else if (a.getType()->isPointerTy()) { -#if LLVM_VERSION_MAJOR >= 17 -#else - auto et = a.getType()->getPointerElementType(); - if (et->isFPOrFPVectorTy()) { - dt = TypeTree(ConcreteType(et->getScalarType())).Only(-1, nullptr); - } else if (et->isPointerTy()) { - dt = TypeTree(ConcreteType(BaseType::Pointer)).Only(-1, nullptr); +#if LLVM_VERSION_MAJOR < 17 +#if LLVM_VERSION_MAJOR >= 13 + if (a.getContext().supportsTypedPointers()) { +#endif + auto et = a.getType()->getPointerElementType(); + if (et->isFPOrFPVectorTy()) { + dt = TypeTree(ConcreteType(et->getScalarType())).Only(-1, nullptr); + } else if (et->isPointerTy()) { + dt = TypeTree(ConcreteType(BaseType::Pointer)).Only(-1, nullptr); + } +#if LLVM_VERSION_MAJOR >= 13 } +#endif #endif } else if (a.getType()->isIntOrIntVectorTy()) { dt = ConcreteType(BaseType::Integer); @@ -113,14 +118,19 @@ bool printActivityAnalysis(llvm::Function &F, TargetLibraryInfo &TLI) { if (F.getReturnType()->isFPOrFPVectorTy()) { dt = ConcreteType(F.getReturnType()->getScalarType()); } else if (F.getReturnType()->isPointerTy()) { -#if LLVM_VERSION_MAJOR >= 17 -#else - auto et = F.getReturnType()->getPointerElementType(); - if (et->isFPOrFPVectorTy()) { - dt = TypeTree(ConcreteType(et->getScalarType())).Only(-1, nullptr); - } else if (et->isPointerTy()) { - dt = TypeTree(ConcreteType(BaseType::Pointer)).Only(-1, nullptr); +#if LLVM_VERSION_MAJOR < 17 +#if LLVM_VERSION_MAJOR >= 13 + if (F.getContext().supportsTypedPointers()) { +#endif + auto et = F.getReturnType()->getPointerElementType(); + if (et->isFPOrFPVectorTy()) { + dt = TypeTree(ConcreteType(et->getScalarType())).Only(-1, nullptr); + } else if (et->isPointerTy()) { + dt = TypeTree(ConcreteType(BaseType::Pointer)).Only(-1, nullptr); + } +#if LLVM_VERSION_MAJOR >= 13 } +#endif #endif } else if (F.getReturnType()->isIntOrIntVectorTy()) { dt = ConcreteType(BaseType::Integer); diff --git a/enzyme/test/ActivityAnalysis/integration.ll b/enzyme/test/ActivityAnalysis/integration.ll new file mode 100644 index 000000000000..28ce009a9b5b --- /dev/null +++ b/enzyme/test/ActivityAnalysis/integration.ll @@ -0,0 +1,93 @@ +; RUN: if [ %llvmver -ge 15 ]; then %opt < %s %OPnewLoadEnzyme -passes="print-activity-analysis" -activity-analysis-func=f.preprocess -S | FileCheck %s; fi + +declare void @free(ptr) + +declare ptr @malloc(i64) + +; This function just returns 2*input, its derivate should be 2.0. +define void @f.preprocess(ptr %param, i64 %mallocsize, ptr %res) { + + ; arithmetic block, changing anything here makes the bug go away + %buffer1 = call ptr @malloc(i64 %mallocsize) + %tmp = call ptr @malloc(i64 72) + %ptrtoint = ptrtoint ptr %tmp to i64 + %and = and i64 %ptrtoint, -64 + %inttoptr = inttoptr i64 %and to ptr + %loadarg = load double, ptr %param + %storedargmul = fmul double %loadarg, 4.000000e+00 + store double %storedargmul, ptr %inttoptr + call void @free(ptr %tmp) + store double %storedargmul, ptr %buffer1 + + ; prep arg 0 by setting the aligned pointer to the input + %arg0 = alloca { ptr, ptr, i64 } + %arg0_aligned = getelementptr inbounds { ptr, ptr, i64 }, ptr %arg0, i64 0, i32 1 + store ptr %param, ptr %arg0_aligned + + ; prep arg 1 by setting the aligned pointer to buffer1 + %arg1 = alloca { ptr, ptr, i64, [1 x i64], [1 x i64] } + %arg1_aligned = getelementptr inbounds { ptr, ptr, i64, [1 x i64], [1 x i64] }, ptr %arg1, i64 0, i32 1 + store ptr %buffer1, ptr %arg1_aligned + + ; prep arg 2 by setting the aligned pointer to buffer2 + %arg2 = alloca { ptr, ptr, i64 } + %arg2_aligned = getelementptr inbounds { ptr, ptr, i64 }, ptr %arg2, i64 0, i32 1 + %buffer2 = call ptr @malloc(i64 8) + store ptr %buffer2, ptr %arg2_aligned + + ; nested call, required for bug + call void @nested(ptr %arg0, ptr %arg1, ptr %arg2) + + ; return a result from this function, needs to be positioned after arithmetic block for bug + %x = load double, ptr %param + %y = fmul double %x, 2.0 + store double %y, ptr %res + + ret void +} + +; Identity function, 2nd argument required for bug (but not used) +define void @nested(ptr %arg0, ptr %arg1, ptr %arg2) { + + ; load aligned pointer from %arg0 & load argument value + %loadarg = load { ptr, ptr, i64 }, ptr %arg0 + %extractarg = extractvalue { ptr, ptr, i64 } %loadarg, 1 + %loadextractarg = load double, ptr %extractarg + + ; load aligned pointer from %arg2 & store result value + %loadarg2 = load { ptr, ptr, i64 }, ptr %arg2 + %extractarg2 = extractvalue { ptr, ptr, i64 } %loadarg2, 1 + store double %loadextractarg, ptr %extractarg2 + + ret void +} + +; CHECK: ptr %param: icv:0 +; CHECK-NEXT: i64 %mallocsize: icv:1 +; CHECK-NEXT: ptr %res: icv:0 + +; CHECK: %buffer1 = call ptr @malloc(i64 %mallocsize): icv:0 ici:1 +; CHECK-NEXT: %tmp = call ptr @malloc(i64 72): icv:1 ici:1 +; CHECK-NEXT: %ptrtoint = ptrtoint ptr %tmp to i64: icv:1 ici:1 +; CHECK-NEXT: %and = and i64 %ptrtoint, -64: icv:1 ici:1 +; CHECK-NEXT: %inttoptr = inttoptr i64 %and to ptr: icv:1 ici:1 +; CHECK-NEXT: %loadarg = load double, ptr %param, align 8: icv:0 ici:0 +; CHECK-NEXT: %storedargmul = fmul double %loadarg, 4.000000e+00: icv:0 ici:0 +; CHECK-NEXT: store double %storedargmul, ptr %inttoptr, align 8: icv:1 ici:1 +; CHECK-NEXT: call void @free(ptr %tmp): icv:1 ici:1 +; CHECK-NEXT: store double %storedargmul, ptr %buffer1, align 8: icv:1 ici:0 +; CHECK-NEXT: %arg0 = alloca { ptr, ptr, i64 }, align 8: icv:0 ici:1 +; CHECK-NEXT: %arg0_aligned = getelementptr inbounds { ptr, ptr, i64 }, ptr %arg0, i64 0, i32 1: icv:0 ici:1 +; CHECK-NEXT: store ptr %param, ptr %arg0_aligned, align 8: icv:1 ici:0 +; CHECK-NEXT: %arg1 = alloca { ptr, ptr, i64, [1 x i64], [1 x i64] }, align 8: icv:0 ici:1 +; CHECK-NEXT: %arg1_aligned = getelementptr inbounds { ptr, ptr, i64, [1 x i64], [1 x i64] }, ptr %arg1, i64 0, i32 1: icv:0 ici:1 +; CHECK-NEXT: store ptr %buffer1, ptr %arg1_aligned, align 8: icv:1 ici:0 +; CHECK-NEXT: %arg2 = alloca { ptr, ptr, i64 }, align 8: icv:0 ici:1 +; CHECK-NEXT: %arg2_aligned = getelementptr inbounds { ptr, ptr, i64 }, ptr %arg2, i64 0, i32 1: icv:0 ici:1 +; CHECK-NEXT: %buffer2 = call ptr @malloc(i64 8): icv:0 ici:1 +; CHECK-NEXT: store ptr %buffer2, ptr %arg2_aligned, align 8: icv:1 ici:0 +; CHECK-NEXT: call void @nested(ptr %arg0, ptr %arg1, ptr %arg2): icv:1 ici:0 +; CHECK-NEXT: %x = load double, ptr %param, align 8: icv:0 ici:0 +; CHECK-NEXT: %y = fmul double %x, 2.000000e+00: icv:0 ici:0 +; CHECK-NEXT: store double %y, ptr %res, align 8: icv:1 ici:0 +; CHECK-NEXT: ret void: icv:1 ici:1 From d4cb8b3dd921ead3a5f3c8e29fd071077402c36b Mon Sep 17 00:00:00 2001 From: William Moses Date: Wed, 21 Feb 2024 09:51:07 -0500 Subject: [PATCH 085/106] TypeAnalysis permit passing info as fn parm attrs (#1746) --- enzyme/Enzyme/AdjointGenerator.h | 53 +++++++------ enzyme/Enzyme/EnzymeLogic.cpp | 36 ++++++++- enzyme/Enzyme/FunctionUtils.cpp | 22 +++++- enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp | 67 ++++++++++++++-- enzyme/Enzyme/TypeAnalysis/TypeTree.h | 87 ++++++++++++++++++++- 5 files changed, 231 insertions(+), 34 deletions(-) diff --git a/enzyme/Enzyme/AdjointGenerator.h b/enzyme/Enzyme/AdjointGenerator.h index dbcd0947d342..9f1d3cd65dfc 100644 --- a/enzyme/Enzyme/AdjointGenerator.h +++ b/enzyme/Enzyme/AdjointGenerator.h @@ -4757,10 +4757,11 @@ class AdjointGenerator : public llvm::InstVisitor { // call.getParamAttr(i, Attribute::StructRet).getValueAsType())); #endif } - if (call.getAttributes().hasParamAttr(i, "enzymejl_returnRoots")) { - structAttrs[args.size()].push_back( - call.getParamAttr(i, "enzymejl_returnRoots")); - } + for (auto attr : {"enzymejl_returnRoots", "enzymejl_parmtype", + "enzymejl_parmtype_ref", "enzyme_type"}) + if (call.getAttributes().hasParamAttr(i, attr)) { + structAttrs[args.size()].push_back(call.getParamAttr(i, attr)); + } for (auto ty : PrimalParamAttrsToPreserve) if (call.getAttributes().hasParamAttr(i, ty)) { auto attr = call.getAttributes().getParamAttr(i, ty); @@ -4815,15 +4816,16 @@ class AdjointGenerator : public llvm::InstVisitor { structAttrs[args.size()].push_back(attr); } - if (call.getAttributes().hasParamAttr(i, "enzymejl_returnRoots")) { - if (gutils->getWidth() == 1) { - structAttrs[args.size()].push_back( - call.getParamAttr(i, "enzymejl_returnRoots")); - } else { - structAttrs[args.size()].push_back( - Attribute::get(call.getContext(), "enzyme_sret_v")); + for (auto attr : {"enzymejl_returnRoots", "enzymejl_parmtype", + "enzymejl_parmtype_ref", "enzyme_type"}) + if (call.getAttributes().hasParamAttr(i, attr)) { + if (gutils->getWidth() == 1) { + structAttrs[args.size()].push_back(call.getParamAttr(i, attr)); + } else if (attr == "enzymejl_returnRoots") { + structAttrs[args.size()].push_back( + Attribute::get(call.getContext(), "enzymejl_returnRoots_v")); + } } - } if (call.paramHasAttr(i, Attribute::StructRet)) { if (gutils->getWidth() == 1) { structAttrs[args.size()].push_back( @@ -5050,10 +5052,11 @@ class AdjointGenerator : public llvm::InstVisitor { if (call.isByValArgument(i)) { preByVal[pre_args.size()] = call.getParamByValType(i); } - if (call.getAttributes().hasParamAttr(i, "enzymejl_returnRoots")) { - structAttrs[pre_args.size()].push_back( - call.getParamAttr(i, "enzymejl_returnRoots")); - } + for (auto attr : {"enzymejl_returnRoots", "enzymejl_parmtype", + "enzymejl_parmtype_ref", "enzyme_type"}) + if (call.getAttributes().hasParamAttr(i, attr)) { + structAttrs[pre_args.size()].push_back(call.getParamAttr(i, attr)); + } if (call.paramHasAttr(i, Attribute::StructRet)) { structAttrs[pre_args.size()].push_back( #if LLVM_VERSION_MAJOR >= 12 @@ -5146,15 +5149,17 @@ class AdjointGenerator : public llvm::InstVisitor { structAttrs[pre_args.size()].push_back(attr); } - if (call.getAttributes().hasParamAttr(i, "enzymejl_returnRoots")) { - if (gutils->getWidth() == 1) { - structAttrs[pre_args.size()].push_back( - call.getParamAttr(i, "enzymejl_returnRoots")); - } else { - structAttrs[pre_args.size()].push_back( - Attribute::get(call.getContext(), "enzymejl_returnRoots_v")); + for (auto attr : {"enzymejl_returnRoots", "enzymejl_parmtype", + "enzymejl_parmtype_ref", "enzyme_type"}) + if (call.getAttributes().hasParamAttr(i, attr)) { + if (gutils->getWidth() == 1) { + structAttrs[pre_args.size()].push_back( + call.getParamAttr(i, attr)); + } else if (attr == "enzymejl_returnRoots") { + structAttrs[pre_args.size()].push_back( + Attribute::get(call.getContext(), "enzymejl_returnRoots_v")); + } } - } if (call.paramHasAttr(i, Attribute::StructRet)) { if (gutils->getWidth() == 1) { structAttrs[pre_args.size()].push_back( diff --git a/enzyme/Enzyme/EnzymeLogic.cpp b/enzyme/Enzyme/EnzymeLogic.cpp index 0767b3e52529..76fc71921654 100644 --- a/enzyme/Enzyme/EnzymeLogic.cpp +++ b/enzyme/Enzyme/EnzymeLogic.cpp @@ -82,6 +82,7 @@ #if LLVM_VERSION_MAJOR >= 14 #define addAttribute addAttributeAtIndex +#define getAttribute getAttributeAtIndex #define removeAttribute removeAttributeAtIndex #endif @@ -2784,7 +2785,8 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal( NewF->addParamAttr(attrIndex, Attribute::NoAlias); } for (auto name : {"enzyme_sret", "enzyme_sret_v", "enzymejl_returnRoots", - "enzymejl_returnRoots_v"}) + "enzymejl_returnRoots_v", "enzymejl_parmtype", + "enzymejl_parmtype_ref", "enzyme_type"}) if (nf->getAttributes().hasParamAttr(attrIndex, name)) { NewF->addParamAttr(attrIndex, nf->getAttributes().getParamAttr(attrIndex, name)); @@ -2796,6 +2798,38 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal( ++attrIndex; } +#if LLVM_VERSION_MAJOR >= 14 + for (auto attr : {"enzyme_ta_norecur"}) + if (nf->getAttributes().hasAttributeAtIndex(AttributeList::FunctionIndex, + attr)) { + NewF->addFnAttr( + nf->getAttributes().getAttribute(AttributeList::FunctionIndex, attr)); + } + + for (auto attr : + {"enzyme_type", "enzymejl_parmtype", "enzymejl_parmtype_ref"}) + if (nf->getAttributes().hasAttributeAtIndex(AttributeList::ReturnIndex, + attr)) { + NewF->addAttribute( + AttributeList::ReturnIndex, + nf->getAttributes().getAttribute(AttributeList::ReturnIndex, attr)); + } +#else + for (auto attr : {"enzyme_ta_norecur"}) + if (nf->getAttributes().hasAttribute(AttributeList::FunctionIndex, attr)) { + NewF->addFnAttr( + nf->getAttributes().getAttribute(AttributeList::FunctionIndex, attr)); + } + + for (auto attr : + {"enzyme_type", "enzymejl_parmtype", "enzymejl_parmtype_ref"}) + if (nf->getAttributes().hasAttribute(AttributeList::ReturnIndex, attr)) { + NewF->addAttribute( + AttributeList::ReturnIndex, + nf->getAttributes().getAttribute(AttributeList::ReturnIndex, attr)); + } +#endif + SmallVector Returns; #if LLVM_VERSION_MAJOR >= 13 CloneFunctionInto(NewF, nf, VMap, CloneFunctionChangeType::LocalChangesOnly, diff --git a/enzyme/Enzyme/FunctionUtils.cpp b/enzyme/Enzyme/FunctionUtils.cpp index 218ca9549dfc..3e2ade41b1c9 100644 --- a/enzyme/Enzyme/FunctionUtils.cpp +++ b/enzyme/Enzyme/FunctionUtils.cpp @@ -2180,8 +2180,24 @@ Function *PreProcessCache::CloneFunctionWithReturns( VMapO->getMDMap() = VMap.getMDMap(); } + for (auto attr : {"enzyme_ta_norecur"}) + if (F->getAttributes().hasAttribute(AttributeList::FunctionIndex, attr)) { + NewF->addAttribute( + AttributeList::FunctionIndex, + F->getAttributes().getAttribute(AttributeList::FunctionIndex, attr)); + } + + for (auto attr : + {"enzyme_type", "enzymejl_parmtype", "enzymejl_parmtype_ref"}) + if (F->getAttributes().hasAttribute(AttributeList::ReturnIndex, attr)) { + NewF->addAttribute( + AttributeList::ReturnIndex, + F->getAttributes().getAttribute(AttributeList::ReturnIndex, attr)); + } + bool hasPtrInput = false; unsigned ii = 0, jj = 0; + for (auto i = F->arg_begin(), j = NewF->arg_begin(); i != F->arg_end();) { if (F->hasParamAttribute(ii, Attribute::StructRet)) { NewF->addParamAttr(jj, Attribute::get(F->getContext(), "enzyme_sret")); @@ -2204,7 +2220,8 @@ Function *PreProcessCache::CloneFunctionWithReturns( // Attribute::ElementType)); #endif } - for (auto attr : {"enzymejl_parmtype", "enzymejl_parmtype_ref"}) + for (auto attr : + {"enzymejl_parmtype", "enzymejl_parmtype_ref", "enzyme_type"}) if (F->getAttributes().hasParamAttr(ii, attr)) { NewF->addParamAttr(jj, F->getAttributes().getParamAttr(ii, attr)); for (auto ty : PrimalParamAttrsToPreserve) @@ -2250,7 +2267,8 @@ Function *PreProcessCache::CloneFunctionWithReturns( NewF->addParamAttr(jj + 1, attr); } - for (auto attr : {"enzymejl_parmtype", "enzymejl_parmtype_ref"}) + for (auto attr : + {"enzymejl_parmtype", "enzymejl_parmtype_ref", "enzyme_type"}) if (F->getAttributes().hasParamAttr(ii, attr)) { if (width == 1) NewF->addParamAttr(jj + 1, diff --git a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp index 86a5243c438a..2515169ae0e4 100644 --- a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp +++ b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp @@ -61,6 +61,12 @@ #include +#if LLVM_VERSION_MAJOR >= 14 +#define getAttribute getAttributeAtIndex +#define hasAttribute hasAttributeAtIndex +#define addAttribute addAttributeAtIndex +#endif + using namespace llvm; extern "C" { @@ -1207,18 +1213,57 @@ void TypeAnalyzer::considerTBAA() { } if (CallBase *call = dyn_cast(&I)) { +#if LLVM_VERSION_MAJOR >= 14 + size_t num_args = call->arg_size(); +#else + size_t num_args = call->getNumArgOperands(); +#endif + + if (call->getAttributes().hasAttribute(AttributeList::ReturnIndex, + "enzyme_type")) { + auto attr = call->getAttributes().getAttribute( + AttributeList::ReturnIndex, "enzyme_type"); + auto TT = + TypeTree::parse(attr.getValueAsString(), call->getContext()); + updateAnalysis(call, TT, call); + } + for (size_t i = 0; i < num_args; i++) { + if (call->getAttributes().hasParamAttr(i, "enzyme_type")) { + auto attr = call->getAttributes().getParamAttr(i, "enzyme_type"); + auto TT = + TypeTree::parse(attr.getValueAsString(), call->getContext()); + updateAnalysis(call->getArgOperand(i), TT, call); + } + } + Function *F = call->getCalledFunction(); + + if (F) { + if (F->getAttributes().hasAttribute(AttributeList::ReturnIndex, + "enzyme_type")) { + auto attr = F->getAttributes().getAttribute( + AttributeList::ReturnIndex, "enzyme_type"); + auto TT = + TypeTree::parse(attr.getValueAsString(), call->getContext()); + updateAnalysis(call, TT, call); + } + size_t f_num_args = F->arg_size(); + for (size_t i = 0; i < f_num_args; i++) { + if (F->getAttributes().hasParamAttr(i, "enzyme_type")) { + auto attr = F->getAttributes().getParamAttr(i, "enzyme_type"); + auto TT = + TypeTree::parse(attr.getValueAsString(), call->getContext()); + updateAnalysis(call->getArgOperand(i), TT, call); + } + } + } + if (auto castinst = dyn_cast(call->getCalledOperand())) { if (castinst->isCast()) if (auto fn = dyn_cast(castinst->getOperand(0))) { F = fn; } } -#if LLVM_VERSION_MAJOR >= 14 - size_t num_args = call->arg_size(); -#else - size_t num_args = call->getNumArgOperands(); -#endif if (F && F->getName().contains("__enzyme_float")) { assert(num_args == 1 || num_args == 2); assert(call->getArgOperand(0)->getType()->isPointerTy()); @@ -4356,9 +4401,16 @@ void TypeAnalyzer::visitCallBase(CallBase &call) { } } + if (call.hasFnAttr("enzyme_ta_norecur")) + return; + Function *ci = getFunctionFromCall(&call); if (ci) { + if (ci->getAttributes().hasAttribute(AttributeList::FunctionIndex, + "enzyme_ta_norecur")) + return; + StringRef funcName = getFuncNameFromCall(&call); auto blasMetaData = extractBLAS(funcName); @@ -4367,6 +4419,9 @@ void TypeAnalyzer::visitCallBase(CallBase &call) { #include "BlasTA.inc" } + // Manual TT specification is non-interprocedural and already handled once + // at the start. + // When compiling Enzyme against standard LLVM, and not Intel's // modified version of LLVM, the intrinsic `llvm.intel.subscript` is // not fully understood by LLVM. One of the results of this is that the @@ -5596,7 +5651,7 @@ bool TypeAnalyzer::mustRemainInteger(Value *val, bool *returned) { FnTypeInfo TypeAnalyzer::getCallInfo(CallBase &call, Function &fn) { FnTypeInfo typeInfo(&fn); - int argnum = 0; + size_t argnum = 0; for (auto &arg : fn.args()) { if (argnum >= call.arg_size()) { typeInfo.Arguments.insert( diff --git a/enzyme/Enzyme/TypeAnalysis/TypeTree.h b/enzyme/Enzyme/TypeAnalysis/TypeTree.h index 2cb4a7506c94..bfcd9f488ab5 100644 --- a/enzyme/Enzyme/TypeAnalysis/TypeTree.h +++ b/enzyme/Enzyme/TypeAnalysis/TypeTree.h @@ -83,6 +83,92 @@ class TypeTree : public std::enable_shared_from_this { } } + static TypeTree parse(llvm::StringRef str, llvm::LLVMContext &ctx) { + using namespace llvm; + assert(str[0] == '{'); + str = str.substr(1); + + TypeTree Result; + while (true) { + while (str[0] == ' ') + str = str.substr(1); + if (str[0] == '}') + break; + + assert(str[0] == '['); + str = str.substr(1); + + std::vector idxs; + while (true) { + while (str[0] == ' ') + str = str.substr(1); + if (str[0] == ']') { + str = str.substr(1); + break; + } + + int idx; + bool failed = str.consumeInteger(10, idx); + (void)failed; + assert(!failed); + idxs.push_back(idx); + + while (str[0] == ' ') + str = str.substr(1); + + if (str[0] == ',') { + str = str.substr(1); + } + } + + while (str[0] == ' ') + str = str.substr(1); + + assert(str[0] == ':'); + str = str.substr(1); + + while (str[0] == ' ') + str = str.substr(1); + + auto endval = str.find(','); + auto endval2 = str.find('}'); + auto endval3 = str.find(' '); + + if (endval2 != StringRef::npos && + (endval == StringRef::npos || endval2 < endval)) + endval = endval2; + if (endval3 != StringRef::npos && + (endval == StringRef::npos || endval3 < endval)) + endval = endval3; + assert(endval != StringRef::npos); + + auto tystr = str.substr(0, endval); + str = str.substr(endval); + + ConcreteType CT(tystr, ctx); + Result.mapping.emplace(idxs, CT); + if (Result.minIndices.size() < idxs.size()) { + for (size_t i = Result.minIndices.size(), end = idxs.size(); i < end; + ++i) { + Result.minIndices.push_back(idxs[i]); + } + } + for (size_t i = 0, end = idxs.size(); i < end; ++i) { + if (idxs[i] < Result.minIndices[i]) + Result.minIndices[i] = idxs[i]; + } + + while (str[0] == ' ') + str = str.substr(1); + + if (str[0] == ',') { + str = str.substr(1); + } + } + + return Result; + } + /// Utility helper to lookup the mapping const ConcreteTypeMapType &getMapping() const { return mapping; } @@ -1228,7 +1314,6 @@ class TypeTree : public std::enable_shared_from_this { if (found != RHS.mapping.end()) { RightCT = found->second; } - bool SubLegal = true; changed |= CT.binopIn(SubLegal, RightCT, Op); if (!SubLegal) { From 638ac37a037d1b63c401f06ead22771945590854 Mon Sep 17 00:00:00 2001 From: William Moses Date: Wed, 21 Feb 2024 10:51:26 -0500 Subject: [PATCH 086/106] TypeAnalysis speed up Canonicalize (#1747) * TypeAnalysis speed up Canonicalize * replace old impl --- enzyme/Enzyme/TypeAnalysis/TypeTree.h | 45 ++++++++++++++++++++------- 1 file changed, 34 insertions(+), 11 deletions(-) diff --git a/enzyme/Enzyme/TypeAnalysis/TypeTree.h b/enzyme/Enzyme/TypeAnalysis/TypeTree.h index bfcd9f488ab5..22a4c25d9361 100644 --- a/enzyme/Enzyme/TypeAnalysis/TypeTree.h +++ b/enzyme/Enzyme/TypeAnalysis/TypeTree.h @@ -702,22 +702,22 @@ class TypeTree : public std::enable_shared_from_this { staging[next][pair.second].insert(pair.first[0]); } - mapping.clear(); + // TypeTree mappings which did not get combined + std::map, ConcreteType> unCombinedToAdd; - for (auto &pair : staging) { + // TypeTree mappings which did get combined into an outer -1 + std::map, ConcreteType> combinedToAdd; + + for (const auto &pair : staging) { auto &pnext = pair.first; - for (auto &pair2 : pair.second) { + for (const auto &pair2 : pair.second) { auto dt = pair2.first; const auto &set = pair2.second; - // llvm::errs() << " - set: {"; - // for(auto s : set) llvm::errs() << s << ", "; - // llvm::errs() << "} len=" << len << "\n"; - - bool legalCombine = set.count(-1); + bool legalCombine = false; // See if we can canonicalize the outermost index into a -1 - if (!legalCombine) { + if (!set.count(-1)) { size_t chunk = 1; if (pnext.size() > 0) { chunk = dl.getPointerSizeInBits() / 8; @@ -745,15 +745,38 @@ class TypeTree : public std::enable_shared_from_this { next.push_back(v); if (legalCombine) { - insert(next, dt, /*intsAreLegalPointerSub*/ true); + combinedToAdd.emplace(next, dt); } else { for (auto e : set) { next[0] = e; - insert(next, dt); + unCombinedToAdd.emplace(next, dt); } } } } + + // If we combined nothing, just return since there are no + // changes. + if (combinedToAdd.size() == 0) { + return; + } + + // Non-combined ones do not conflict, since they were already in + // a TT which we can assume contained no conflicts. + mapping = std::move(unCombinedToAdd); + minIndices[0] = -1; + + // Fusing several terms into a minus one can create a conflict + // if the prior minus one was already in the map + // time, or also generated by fusion. + // E.g. {-1:Anything, [0]:Pointer} on 8 -> create a [-1]:Pointer + // which conflicts + // Alternatively [-1,-1,-1]:Pointer, and generated a [-1,0,-1] fusion + for (const auto &pair : combinedToAdd) { + insert(pair.first, pair.second); + } + + return; } /// Keep only pointers (or anything's) to a repeated value (represented by -1) From 72342461d2e37cfdc9765084f4860f1d2b6da7d4 Mon Sep 17 00:00:00 2001 From: Tim Gymnich Date: Mon, 26 Feb 2024 20:35:07 +0100 Subject: [PATCH 087/106] Update enzyme-ci.yml (#1754) --- .github/workflows/enzyme-ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/enzyme-ci.yml b/.github/workflows/enzyme-ci.yml index 4a1ef073160d..fcb94bb604f3 100644 --- a/.github/workflows/enzyme-ci.yml +++ b/.github/workflows/enzyme-ci.yml @@ -61,7 +61,7 @@ jobs: strategy: fail-fast: false matrix: - llvm: ["11", "12", "13", "14", "15"] + llvm: ["12", "13", "14", "15", "16"] build: ["Release", "Debug"] # "RelWithDebInfo" timeout-minutes: 30 From 9884107d3df26d1e8d826c4dada22f7f1971cb0b Mon Sep 17 00:00:00 2001 From: Tim Gymnich Date: Mon, 26 Feb 2024 20:35:24 +0100 Subject: [PATCH 088/106] Create dependabot.yml (#1755) --- .github/dependabot.yml | 12 ++++++++++++ 1 file changed, 12 insertions(+) create mode 100644 .github/dependabot.yml diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 000000000000..edb6037854cc --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,12 @@ +# To get started with Dependabot version updates, you'll need to specify which +# package ecosystems to update and where the package manifests are located. +# Please see the documentation for all configuration options: +# https://docs.github.com/code-security/dependabot/dependabot-version-updates/configuration-options-for-the-dependabot.yml-file + +version: 2 +updates: +# Maintain dependencies for GitHub Actions + - package-ecosystem: "github-actions" + directory: "/" + schedule: + interval: "monthly" From 26c43370e4d56730d0b00e4e3b2705b28f7df375 Mon Sep 17 00:00:00 2001 From: Tim Gymnich Date: Mon, 26 Feb 2024 20:35:59 +0100 Subject: [PATCH 089/106] Export compile_commands.json to enable clangd support (#1753) --- .devcontainer/devcontainer.json | 6 ++++-- .gitignore | 1 + enzyme/CMakeLists.txt | 1 + 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json index d201b73d532d..a9b75aefae89 100644 --- a/.devcontainer/devcontainer.json +++ b/.devcontainer/devcontainer.json @@ -1,5 +1,5 @@ // available ubuntu versions: [20, 22] -// available llvm versions: [9, 10, 11, 12, 13, 14, 15] +// available llvm versions: [11, 12, 13, 14, 15, 16, 17, 18] { "name": "Enzyme", "image": "ghcr.io/enzymead/enzyme-dev-docker/ubuntu-22-llvm-16:latest", @@ -14,7 +14,9 @@ "customizations": { "vscode": { "extensions": [ - "ms-vscode.cpptools-extension-pack" + "llvm-vs-code-extensions.vscode-clangd", + "BazelBuild.vscode-bazel", + "twxs.cmake" ] } } diff --git a/.gitignore b/.gitignore index aad84254cf78..33a6c3f8cdd8 100644 --- a/.gitignore +++ b/.gitignore @@ -9,3 +9,4 @@ enzyme/benchmarks/ReverseMode/*/*.o enzyme/benchmarks/ReverseMode/*/*.exe enzyme/benchmarks/ReverseMode/*/results.txt enzyme/benchmarks/ReverseMode/*/results.json +.cache diff --git a/enzyme/CMakeLists.txt b/enzyme/CMakeLists.txt index 32d09c19e302..077f4f3a9554 100644 --- a/enzyme/CMakeLists.txt +++ b/enzyme/CMakeLists.txt @@ -13,6 +13,7 @@ add_definitions(-DENZYME_VERSION_MAJOR=${ENZYME_MAJOR_VERSION}) add_definitions(-DENZYME_VERSION_MINOR=${ENZYME_MINOR_VERSION}) add_definitions(-DENZYME_VERSION_PATCH=${ENZYME_PATCH_VERSION}) +set(CMAKE_EXPORT_COMPILE_COMMANDS ON) set(CMAKE_POSITION_INDEPENDENT_CODE ON) SET(CMAKE_CXX_FLAGS "-Wall -fno-rtti ${CMAKE_CXX_FLAGS} -Werror=unused-variable -Werror=dangling-else -Werror=unused-but-set-variable -Werror=return-type -Werror=nonnull") SET(CMAKE_CXX_FLAGS_RELWITHDEBINFO "-O2 -g -ggdb") From 2dec42bf7df5015da2e78a1dd3f1de6e5b2db6ce Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 27 Feb 2024 00:48:11 +0100 Subject: [PATCH 090/106] Bump actions/cache from 3 to 4 (#1757) Bumps [actions/cache](https://github.com/actions/cache) from 3 to 4. - [Release notes](https://github.com/actions/cache/releases) - [Changelog](https://github.com/actions/cache/blob/main/RELEASES.md) - [Commits](https://github.com/actions/cache/compare/v3...v4) --- updated-dependencies: - dependency-name: actions/cache dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/enzyme-mlir.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/enzyme-mlir.yml b/.github/workflows/enzyme-mlir.yml index b4c581a3a720..f36ed1671c46 100644 --- a/.github/workflows/enzyme-mlir.yml +++ b/.github/workflows/enzyme-mlir.yml @@ -46,7 +46,7 @@ jobs: - name: Cache MLIR id: cache-mlir - uses: actions/cache@v3 + uses: actions/cache@v4 with: path: llvm-project/mlir-build key: ${{ matrix.llbuild }}-${{ matrix.os }}-mlir-${{ steps.mlir-commit.outputs.sha_short }} From d1b26fd50085f47f9b0a93071db5716de69a664c Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 26 Feb 2024 23:49:15 +0000 Subject: [PATCH 091/106] Bump mattnotmitt/doxygen-action from 1.9.2 to 1.9.8 (#1758) Bumps [mattnotmitt/doxygen-action](https://github.com/mattnotmitt/doxygen-action) from 1.9.2 to 1.9.8. - [Release notes](https://github.com/mattnotmitt/doxygen-action/releases) - [Commits](https://github.com/mattnotmitt/doxygen-action/compare/v1.9.2...v1.9.8) --- updated-dependencies: - dependency-name: mattnotmitt/doxygen-action dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/doxygen.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/doxygen.yml b/.github/workflows/doxygen.yml index 659ed659589a..8b6278c99a26 100644 --- a/.github/workflows/doxygen.yml +++ b/.github/workflows/doxygen.yml @@ -11,7 +11,7 @@ jobs: steps: - uses: actions/checkout@v3 - - uses: mattnotmitt/doxygen-action@v1.9.2 + - uses: mattnotmitt/doxygen-action@v1.9.8 with: working-directory: 'enzyme/' doxyfile-path: 'doxygen.cfg' From 53c69356cffa225b19eb6118304bf6de131b3be0 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 26 Feb 2024 23:56:21 +0000 Subject: [PATCH 092/106] Bump actions/checkout from 3 to 4 (#1759) Bumps [actions/checkout](https://github.com/actions/checkout) from 3 to 4. - [Release notes](https://github.com/actions/checkout/releases) - [Changelog](https://github.com/actions/checkout/blob/main/CHANGELOG.md) - [Commits](https://github.com/actions/checkout/compare/v3...v4) --- updated-dependencies: - dependency-name: actions/checkout dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/bcload.yml | 2 +- .github/workflows/benchmark.yml | 2 +- .github/workflows/ccpp.yml | 2 +- .github/workflows/doxygen.yml | 2 +- .github/workflows/enzyme-ci.yml | 6 +++--- .github/workflows/enzyme-julia.yml | 4 ++-- .github/workflows/enzyme-mlir.yml | 4 ++-- .github/workflows/format.yml | 2 +- .github/workflows/fortran.yml | 2 +- .github/workflows/tagger.yml | 4 ++-- 10 files changed, 15 insertions(+), 15 deletions(-) diff --git a/.github/workflows/bcload.yml b/.github/workflows/bcload.yml index 33e409fd8bfe..9cbbdb7094ff 100644 --- a/.github/workflows/bcload.yml +++ b/.github/workflows/bcload.yml @@ -27,7 +27,7 @@ jobs: sudo sed -i 's/add_executable(llvm-omp-device-info IMPORTED)//g' /usr/lib/llvm-${{matrix.llvm}}/lib/cmake/llvm/LLVMExports*.cmake sudo sed -i 's/llvm-omp-device-info//g' /usr/lib/llvm-${{matrix.llvm}}/lib/cmake/llvm/LLVMExports*.cmake fi - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: mkdir run: cd enzyme && rm -rf build && mkdir build - name: cmake diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index 46fa9edb1bed..d859468708b5 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -37,7 +37,7 @@ jobs: sudo sed -i 's/add_executable(llvm-omp-device-info IMPORTED)//g' /usr/lib/llvm-${{matrix.llvm}}/lib/cmake/llvm/LLVMExports*.cmake sudo sed -i 's/llvm-omp-device-info//g' /usr/lib/llvm-${{matrix.llvm}}/lib/cmake/llvm/LLVMExports*.cmake fi - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: mkdir run: rm -rf build && mkdir build - name: cmake diff --git a/.github/workflows/ccpp.yml b/.github/workflows/ccpp.yml index 46c64d5b369e..6efe084fdc07 100644 --- a/.github/workflows/ccpp.yml +++ b/.github/workflows/ccpp.yml @@ -34,7 +34,7 @@ jobs: sudo sed -i 's/add_executable(llvm-omp-device-info IMPORTED)//g' /usr/lib/llvm-${{matrix.llvm}}/lib/cmake/llvm/LLVMExports*.cmake sudo sed -i 's/llvm-omp-device-info//g' /usr/lib/llvm-${{matrix.llvm}}/lib/cmake/llvm/LLVMExports*.cmake fi - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: mkdir run: rm -rf build && mkdir build - name: cmake diff --git a/.github/workflows/doxygen.yml b/.github/workflows/doxygen.yml index 8b6278c99a26..12e64dc5d8ce 100644 --- a/.github/workflows/doxygen.yml +++ b/.github/workflows/doxygen.yml @@ -9,7 +9,7 @@ jobs: docs: runs-on: ubuntu-20.04 steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - uses: mattnotmitt/doxygen-action@v1.9.8 with: diff --git a/.github/workflows/enzyme-ci.yml b/.github/workflows/enzyme-ci.yml index fcb94bb604f3..5f8e9ea6e524 100644 --- a/.github/workflows/enzyme-ci.yml +++ b/.github/workflows/enzyme-ci.yml @@ -32,7 +32,7 @@ jobs: sudo sed -i 's/add_executable(llvm-omp-device-info IMPORTED)//g' /usr/lib/llvm-${{matrix.llvm}}/lib/cmake/llvm/LLVMExports*.cmake sudo sed -i 's/llvm-omp-device-info//g' /usr/lib/llvm-${{matrix.llvm}}/lib/cmake/llvm/LLVMExports*.cmake fi - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: mkdir run: rm -rf build && mkdir build - name: cmake @@ -71,7 +71,7 @@ jobs: brew update brew install llvm@${{ matrix.llvm }} make cmake sudo python3 -m pip install --upgrade pip lit requests - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: mkdir run: rm -rf build && mkdir build - name: cmake @@ -109,7 +109,7 @@ jobs: run: | brew install llvm@${{ matrix.llvm }} make cmake gcc sudo python3 -m pip install --upgrade pip lit - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: mkdir run: rm -rf build && mkdir build - name: cmake diff --git a/.github/workflows/enzyme-julia.yml b/.github/workflows/enzyme-julia.yml index fcb6ff659a07..4aabf7204f5b 100644 --- a/.github/workflows/enzyme-julia.yml +++ b/.github/workflows/enzyme-julia.yml @@ -28,8 +28,8 @@ jobs: - x64 timeout-minutes: 60 steps: - - uses: actions/checkout@v3 - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 + - uses: actions/checkout@v4 with: repository: 'wsmoses/Enzyme.jl' path: ./jl diff --git a/.github/workflows/enzyme-mlir.yml b/.github/workflows/enzyme-mlir.yml index f36ed1671c46..aaced6f4e694 100644 --- a/.github/workflows/enzyme-mlir.yml +++ b/.github/workflows/enzyme-mlir.yml @@ -29,11 +29,11 @@ jobs: sudo apt-get update sudo apt-get install -y binutils ninja-build cmake gcc g++ python3 python3-dev - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: path: 'Enzyme' - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: repository: 'llvm/llvm-project' ref: 'bc82cfb38d83f1afeb2c290aa472c2e2e88919cb' diff --git a/.github/workflows/format.yml b/.github/workflows/format.yml index 55b2c99749be..c01757d0ff86 100644 --- a/.github/workflows/format.yml +++ b/.github/workflows/format.yml @@ -11,7 +11,7 @@ jobs: runs-on: ubuntu-20.04 steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - uses: DoozyX/clang-format-lint-action@v0.16.2 with: source: 'enzyme/Enzyme enzyme/tools/enzyme-tblgen' diff --git a/.github/workflows/fortran.yml b/.github/workflows/fortran.yml index a31294c90b4f..e5e960345f63 100644 --- a/.github/workflows/fortran.yml +++ b/.github/workflows/fortran.yml @@ -40,7 +40,7 @@ jobs: sudo apt-get update && sudo apt-get install -y intel-oneapi-compiler-fortran-${{ matrix.ifx }} intel-oneapi-mpi-${{ matrix.mpi }} intel-oneapi-mpi-devel-${{ matrix.mpi }} source /opt/intel/oneapi/setvars.sh printenv >> $GITHUB_ENV - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: generate build system run: | rm -rf build && mkdir build && cd build diff --git a/.github/workflows/tagger.yml b/.github/workflows/tagger.yml index 9d712fb0cd6c..d40b3a1d943b 100644 --- a/.github/workflows/tagger.yml +++ b/.github/workflows/tagger.yml @@ -17,12 +17,12 @@ jobs: private_key: ${{ secrets.APP_PRIVATE_KEY }} repository: JuliaPackaging/Yggdrasil - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: repository: 'JuliaPackaging/Yggdrasil' path: ygg - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: path: enz - name: replace From 1588442942fd2ed5971087034725d330b65cc1dd Mon Sep 17 00:00:00 2001 From: Tim Gymnich Date: Tue, 27 Feb 2024 01:32:59 +0100 Subject: [PATCH 093/106] Update tagger.yml (#1762) --- .github/workflows/tagger.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/tagger.yml b/.github/workflows/tagger.yml index d40b3a1d943b..e4a015c4145a 100644 --- a/.github/workflows/tagger.yml +++ b/.github/workflows/tagger.yml @@ -10,12 +10,12 @@ jobs: name: Enzyme Tag CI runs-on: ubuntu-latest steps: - - uses: tibdex/github-app-token@v1 + - uses: actions/create-github-app-token@v1 id: generate_token with: app_id: ${{ secrets.APP_ID }} private_key: ${{ secrets.APP_PRIVATE_KEY }} - repository: JuliaPackaging/Yggdrasil + repositories: JuliaPackaging/Yggdrasil - uses: actions/checkout@v4 with: From 3e2de5de6c03b9fe138d657ebfde42fa05af3a4e Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 26 Feb 2024 21:01:36 -0500 Subject: [PATCH 094/106] C++ error message for incorrect custom gradient type (#1764) --- enzyme/Enzyme/AdjointGenerator.h | 4 ++-- enzyme/Enzyme/EnzymeLogic.cpp | 38 +++++++++++++++++++++++++++----- 2 files changed, 34 insertions(+), 8 deletions(-) diff --git a/enzyme/Enzyme/AdjointGenerator.h b/enzyme/Enzyme/AdjointGenerator.h index 9f1d3cd65dfc..a3abf6991e95 100644 --- a/enzyme/Enzyme/AdjointGenerator.h +++ b/enzyme/Enzyme/AdjointGenerator.h @@ -4821,7 +4821,7 @@ class AdjointGenerator : public llvm::InstVisitor { if (call.getAttributes().hasParamAttr(i, attr)) { if (gutils->getWidth() == 1) { structAttrs[args.size()].push_back(call.getParamAttr(i, attr)); - } else if (attr == "enzymejl_returnRoots") { + } else if (attr == std::string("enzymejl_returnRoots")) { structAttrs[args.size()].push_back( Attribute::get(call.getContext(), "enzymejl_returnRoots_v")); } @@ -5155,7 +5155,7 @@ class AdjointGenerator : public llvm::InstVisitor { if (gutils->getWidth() == 1) { structAttrs[pre_args.size()].push_back( call.getParamAttr(i, attr)); - } else if (attr == "enzymejl_returnRoots") { + } else if (attr == std::string("enzymejl_returnRoots")) { structAttrs[pre_args.size()].push_back( Attribute::get(call.getContext(), "enzymejl_returnRoots_v")); } diff --git a/enzyme/Enzyme/EnzymeLogic.cpp b/enzyme/Enzyme/EnzymeLogic.cpp index 76fc71921654..f930d4e1375b 100644 --- a/enzyme/Enzyme/EnzymeLogic.cpp +++ b/enzyme/Enzyme/EnzymeLogic.cpp @@ -3937,14 +3937,40 @@ Function *EnzymeLogic::CreatePrimalAndGradient( hasTape = false; // res.first.push_back(StructType::get(todiff->getContext(), {})); } else { - llvm::errs() << "expected args: ["; + std::string s; + llvm::raw_string_ostream ss(s); + ss << "Bad function type of custom reverse pass for function " + << key.todiff->getName() << " of type " + << *key.todiff->getFunctionType() << "\n"; + ss << " expected gradient function to have argument types ["; + bool seen = false; for (auto a : res.first) { - llvm::errs() << *a << " "; + if (seen) + ss << ", "; + seen = true; + ss << *a; + } + ss << "]\n"; + ss << " Instead found " << foundcalled->getName() << " of type " + << *foundcalled->getFunctionType() << "\n"; + Value *toshow = key.todiff; + if (context.req) { + toshow = context.req; + ss << " at context: " << *context.req; + } else { + ss << *key.todiff << "\n"; + } + if (CustomErrorHandler) { + CustomErrorHandler(ss.str().c_str(), wrap(toshow), + ErrorType::NoDerivative, nullptr, wrap(key.todiff), + wrap(context.ip)); + } else if (context.req) { + EmitFailure("NoDerivative", context.req->getDebugLoc(), context.req, + ss.str()); + } else { + assert(0 && "bad type for custom gradient"); + llvm_unreachable("bad type for custom gradient"); } - llvm::errs() << "]\n"; - llvm::errs() << *foundcalled << "\n"; - assert(0 && "bad type for custom gradient"); - llvm_unreachable("bad type for custom gradient"); } auto st = dyn_cast(foundcalled->getReturnType()); From 744ede0c8525271cb7e977f91ff601e794f9c3ec Mon Sep 17 00:00:00 2001 From: Matin Raayai <30674652+matinraayai@users.noreply.github.com> Date: Tue, 27 Feb 2024 00:56:50 -0500 Subject: [PATCH 095/106] Fixed incorrect usage of llvm::Function::splice. (#1751) * Fixed incorrect usage of llvm::Function::splice. * Put an #ifdef to take into account the LLVM version when using llvm::Function::splice. * fix format --------- Co-authored-by: William S. Moses --- enzyme/Enzyme/Enzyme.cpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/enzyme/Enzyme/Enzyme.cpp b/enzyme/Enzyme/Enzyme.cpp index 055b6f394842..47e9f2bba91f 100644 --- a/enzyme/Enzyme/Enzyme.cpp +++ b/enzyme/Enzyme/Enzyme.cpp @@ -2110,8 +2110,12 @@ class EnzymeBase { // Move the truncated body into the original function F.deleteBody(); +#if LLVM_VERSION_MAJOR >= 16 + F.splice(F.begin(), TruncatedFunc); +#else F.getBasicBlockList().splice(F.begin(), TruncatedFunc->getBasicBlockList()); +#endif RemapFunction(F, Mapping, RF_NoModuleLevelChanges | RF_IgnoreMissingLocals); TruncatedFunc->deleteBody(); From ca06ef36c94015695a13b786cdce89a269ec338f Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 27 Feb 2024 01:16:36 -0500 Subject: [PATCH 096/106] Allow custom importing of files and syntactic sugar (#1752) * Allow custom importing of files and syntactic sugar * Fix build on older llvm vers * Update sugar.cpp * Update EnzymeClang.cpp * fix * fixup * dump * more printing * print * fixup --------- Co-authored-by: Ivan Radanov Ivanov --- enzyme/BUILD | 15 + enzyme/Enzyme/CMakeLists.txt | 6 + enzyme/Enzyme/Clang/EnzymeClang.cpp | 37 ++ enzyme/Enzyme/Clang/include_utils.td | 458 ++++++++++++++++++ enzyme/Enzyme/Enzyme.cpp | 14 +- enzyme/Enzyme/Utils.cpp | 333 ++++++++++--- enzyme/Enzyme/Utils.h | 2 + enzyme/test/Integration/ReverseMode/sugar.cpp | 91 ++++ enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp | 29 +- enzyme/tools/enzyme-tblgen/enzyme-tblgen.h | 1 + 10 files changed, 916 insertions(+), 70 deletions(-) create mode 100644 enzyme/Enzyme/Clang/include_utils.td create mode 100644 enzyme/test/Integration/ReverseMode/sugar.cpp diff --git a/enzyme/BUILD b/enzyme/BUILD index e582b435d387..dc36cb4ad0c5 100644 --- a/enzyme/BUILD +++ b/enzyme/BUILD @@ -142,6 +142,20 @@ gentbl( ], ) +gentbl( + name = "include-utils", + tbl_outs = [( + "-gen-header-strings", + "IncludeUtils.inc", + )], + tblgen = ":enzyme-tblgen", + td_file = "Enzyme/Clang/include_utils.td", + td_srcs = ["Enzyme/Clang/include_utils.td"], + deps = [ + ":enzyme-tblgen", + ], +) + cc_library( name = "EnzymeStatic", srcs = glob( @@ -167,6 +181,7 @@ cc_library( data = ["@llvm-project//clang:builtin_headers_gen"], visibility = ["//visibility:public"], deps = [ + "include-utils", ":binop-derivatives", ":blas-attributor", ":blas-derivatives", diff --git a/enzyme/Enzyme/CMakeLists.txt b/enzyme/Enzyme/CMakeLists.txt index 1cd6e84c5be1..b27e4beb08cd 100644 --- a/enzyme/Enzyme/CMakeLists.txt +++ b/enzyme/Enzyme/CMakeLists.txt @@ -37,6 +37,10 @@ add_public_tablegen_target(BlasDeclarationsIncGen) add_public_tablegen_target(BlasTAIncGen) add_public_tablegen_target(BlasDiffUseIncGen) +set(LLVM_TARGET_DEFINITIONS Clang/include_utils.td) +enzyme_tablegen(IncludeUtils.inc -gen-header-strings) +add_public_tablegen_target(IncludeUtilsIncGen) + include_directories(${CMAKE_CURRENT_BINARY_DIR}) set(LLVM_LINK_COMPONENTS Demangle) @@ -74,6 +78,7 @@ if (${Clang_FOUND}) LLVM ) target_compile_definitions(ClangEnzyme-${LLVM_VERSION_MAJOR} PUBLIC ENZYME_RUNPASS) +add_dependencies(ClangEnzyme-${LLVM_VERSION_MAJOR} IncludeUtilsIncGen) endif() add_llvm_library( LLDEnzyme-${LLVM_VERSION_MAJOR} ${ENZYME_SRC} Clang/EnzymePassLoader.cpp @@ -107,6 +112,7 @@ if (${Clang_FOUND}) clang ) target_compile_definitions(ClangEnzyme-${LLVM_VERSION_MAJOR} PUBLIC ENZYME_RUNPASS) +add_dependencies(ClangEnzyme-${LLVM_VERSION_MAJOR} IncludeUtilsIncGen) endif() add_llvm_library( LLDEnzyme-${LLVM_VERSION_MAJOR} ${ENZYME_SRC} Clang/EnzymePassLoader.cpp diff --git a/enzyme/Enzyme/Clang/EnzymeClang.cpp b/enzyme/Enzyme/Clang/EnzymeClang.cpp index a34a6429dcf7..0072c958b517 100644 --- a/enzyme/Enzyme/Clang/EnzymeClang.cpp +++ b/enzyme/Enzyme/Clang/EnzymeClang.cpp @@ -25,16 +25,20 @@ #include "clang/AST/Attr.h" #include "clang/AST/DeclGroup.h" #include "clang/AST/RecursiveASTVisitor.h" +#include "clang/Basic/FileManager.h" #include "clang/Basic/MacroBuilder.h" #include "clang/Frontend/CompilerInstance.h" #include "clang/Frontend/FrontendAction.h" #include "clang/Frontend/FrontendPluginRegistry.h" +#include "clang/Lex/HeaderSearch.h" #include "clang/Lex/PreprocessorOptions.h" #include "clang/Sema/Sema.h" #include "clang/Sema/SemaDiagnostic.h" #include "../Utils.h" +#include "IncludeUtils.inc" + using namespace clang; #if LLVM_VERSION_MAJOR >= 18 @@ -134,6 +138,39 @@ class EnzymePlugin final : public clang::ASTConsumer { Builder.defineMacro("ENZYME_VERSION_PATCH", std::to_string(ENZYME_VERSION_PATCH)); CI.getPreprocessor().setPredefines(Predefines.str()); + + auto baseFS = &CI.getFileManager().getVirtualFileSystem(); + llvm::vfs::OverlayFileSystem *fuseFS( + new llvm::vfs::OverlayFileSystem(baseFS)); + IntrusiveRefCntPtr fs( + new llvm::vfs::InMemoryFileSystem()); + + struct tm y2k = {}; + + y2k.tm_hour = 0; + y2k.tm_min = 0; + y2k.tm_sec = 0; + y2k.tm_year = 100; + y2k.tm_mon = 0; + y2k.tm_mday = 1; + time_t timer = mktime(&y2k); + for (const auto &pair : include_headers) { + fs->addFile(StringRef(pair[0]), timer, + llvm::MemoryBuffer::getMemBuffer( + StringRef(pair[1]), StringRef(pair[0]), + /*RequiresNullTerminator*/ true)); + } + + fuseFS->pushOverlay(fs); + fuseFS->pushOverlay(baseFS); + CI.getFileManager().setVirtualFileSystem(fuseFS); + + auto DE = CI.getFileManager().getDirectoryRef("/enzymeroot"); + assert(DE); + auto DL = DirectoryLookup(*DE, SrcMgr::C_User, + /*isFramework=*/false); + CI.getPreprocessor().getHeaderSearchInfo().AddSearchPath(DL, + /*isAngled=*/true); } ~EnzymePlugin() {} void HandleTranslationUnit(ASTContext &context) override {} diff --git a/enzyme/Enzyme/Clang/include_utils.td b/enzyme/Enzyme/Clang/include_utils.td new file mode 100644 index 000000000000..1c99d219ce69 --- /dev/null +++ b/enzyme/Enzyme/Clang/include_utils.td @@ -0,0 +1,458 @@ +class Headers { + string filename = filename_; + string contents = contents_; +} + +def : Headers<"/enzymeroot/enzyme/utils", [{ +#pragma once + +extern int enzyme_dup; +extern int enzyme_dupnoneed; +extern int enzyme_out; +extern int enzyme_const; + +extern int enzyme_const_return; +extern int enzyme_active_return; +extern int enzyme_dup_return; + +extern int enzyme_primal_return; +extern int enzyme_noret; + +template +Return __enzyme_autodiff(T...); + +template +Return __enzyme_fwddiff(T...); + +#include + +namespace enzyme { + + struct nodiff{}; + + template + struct ReverseMode { + + }; + using Reverse = ReverseMode; + using ReverseWithPrimal = ReverseMode; + + template < typename T > + struct Active{ + T value; + Active(T &&v) : value(v) {} + operator T&() { return value; } + }; + + template < typename T > + struct Duplicated{ + T value; + T shadow; + Duplicated(T &&v, T&& s) : value(v), shadow(s) {} + }; + + template < typename T > + struct Const{ + T value; + Const(T &&v) : value(v) {} + operator T&() { return value; } + }; + + template < typename T > + struct type_info { + static constexpr bool is_active = false; + using type = nodiff; + }; + + template < typename T > + struct type_info < Active >{ + static constexpr bool is_active = true; + using type = T; + }; + + template < typename ... T > + struct concatenated; + + template < typename ... S, typename T, typename ... rest > + struct concatenated < tuple < S ... >, T, rest ... > { + using type = typename concatenated< tuple< S ..., T>, rest ... >::type; + }; + + template < typename T > + struct concatenated < T > { + using type = T; + }; + + // Yikes! + // slightly cleaner in C++20, with std::remove_cvref + template < typename ... T > + struct autodiff_return; + + template < typename RetType, typename ... T > + struct autodiff_return, RetType, T...> + { + using type = tuple, + typename type_info< + typename remove_cvref< T >::type + >::type ... + >::type>; + }; + + template < typename RetType, typename ... T > + struct autodiff_return, RetType, T...> + { + using type = tuple< + typename type_info::type, + typename concatenated< tuple< >, + typename type_info< + typename remove_cvref< T >::type + >::type ... + >::type + >; + }; + + template < typename T > + __attribute__((always_inline)) + auto expand_args(const enzyme::Duplicated & arg) { + return enzyme::tuple{enzyme_dup, arg.value, arg.shadow}; + } + + template < typename T > + __attribute__((always_inline)) + auto expand_args(const enzyme::Active & arg) { + return enzyme::tuple{enzyme_out, arg.value}; + } + + template < typename T > + __attribute__((always_inline)) + auto expand_args(const enzyme::Const & arg) { + return enzyme::tuple{enzyme_const, arg.value}; + } + + template < typename T > + __attribute__((always_inline)) + auto primal_args(const enzyme::Duplicated & arg) { + return enzyme::tuple{arg.value}; + } + + template < typename T > + __attribute__((always_inline)) + auto primal_args(const enzyme::Active & arg) { + return enzyme::tuple{arg.value}; + } + + template < typename T > + __attribute__((always_inline)) + auto primal_args(const enzyme::Const & arg) { + return enzyme::tuple{arg.value}; + } + + namespace detail { + template + __attribute__((always_inline)) + constexpr decltype(auto) push_return_last(T &&t); + + template + __attribute__((always_inline)) + constexpr decltype(auto) push_return_last(tuple> &&t) { + return tuple>{get<0>(t)}; + } + + template + __attribute__((always_inline)) + constexpr decltype(auto) push_return_last(tuple> &&t) { + return tuple{get<1>(t), get<0>(t)}; + } + + template + __attribute__((always_inline)) + constexpr decltype(auto) rev_apply_impl(void* f, int* ret_attr, Tuple&& t, std::index_sequence) { + return push_return_last(__enzyme_autodiff(f, ret_attr, enzyme::get(impl::forward(t))...)); + } + + template + __attribute__((always_inline)) + constexpr decltype(auto) primal_apply_impl(function &&f, Tuple&& t, std::index_sequence) { + return f(enzyme::get(impl::forward(t))...); + } + + template < typename T > + struct default_ret_activity { + using type = Const; + }; + + template <> + struct default_ret_activity { + using type = Active; + }; + + template <> + struct default_ret_activity { + using type = Active; + }; + + template < typename T > + struct ret_global; + + template + struct ret_global> { + static constexpr int* value = &enzyme_const_return; + }; + + template + struct ret_global> { + static constexpr int* value = &enzyme_active_return; + }; + + template + struct ret_global> { + static constexpr int* value = &enzyme_dup_return; + }; + + template + struct ret_used; + + template + struct ret_used, RetAct> { + static constexpr int* value = &enzyme_primal_return; + }; + + template + struct ret_used, RetAct> { + static constexpr int* value = &enzyme_noret; + }; + + } // namespace detail + + + + template < typename return_type, typename function, typename ... enz_arg_types > + __attribute__((always_inline)) + auto primal_impl(function && f, enzyme::tuple< enz_arg_types ... > && arg_tup) { + using Tuple = enzyme::tuple< enz_arg_types ... >; + return detail::primal_apply_impl(f, impl::forward(arg_tup), std::make_index_sequence>{}); + } + + template < typename function, typename ... arg_types> + auto primal_call(function && f, arg_types && ... args) { + return primal_impl(impl::forward(f), enzyme::tuple_cat(primal_args(args)...)); + } + + template < typename return_type, typename function, typename RetActivity, typename ... enz_arg_types > + __attribute__((always_inline)) + auto rev_autodiff_impl(function && f, enzyme::tuple< enz_arg_types ... > && arg_tup) { + using Tuple = enzyme::tuple< enz_arg_types ... >; + return detail::rev_apply_impl((void*)f, detail::ret_global::value, impl::forward(arg_tup), std::make_index_sequence>{}); + } + + template < typename DiffMode, typename RetActivity, typename function, typename ... arg_types> + __attribute__((always_inline)) + auto autodiff(function && f, arg_types && ... args) { + using return_type = typename autodiff_return::type; + return rev_autodiff_impl(impl::forward(f), enzyme::tuple_cat(enzyme::tuple{detail::ret_used::value}, expand_args(args)...)); + } + + template < typename DiffMode, typename function, typename ... arg_types> + __attribute__((always_inline)) + auto autodiff(function && f, arg_types && ... args) { + using primal_return_type = decltype(primal_call(impl::forward(f), impl::forward(args)...)); + using RetActivity = typename detail::default_ret_activity::type; + using return_type = typename autodiff_return::type; + return rev_autodiff_impl(impl::forward(f), enzyme::tuple_cat(enzyme::tuple{detail::ret_used::value}, expand_args(args)...)); + } +} +}]>; + +def : Headers<"/enzymeroot/enzyme/type_traits", [{ +#pragma once + +#include + +namespace enzyme { + +// this is already in C++20, but we reimplement it here for older C++ versions +template < typename T > +struct remove_cvref { + using type = + typename std::remove_reference< + typename std::remove_cv< + T + >::type + >::type; +}; + +template < typename T > +using remove_cvref_t = typename remove_cvref::type; + +namespace impl { + template + __attribute__((always_inline)) + constexpr _Tp&& + forward(std::remove_reference_t<_Tp>& __t) noexcept + { return static_cast<_Tp&&>(__t); } + + /** + * @brief Forward an rvalue. + * @return The parameter cast to the specified type. + * + * This function is used to implement "perfect forwarding". + */ + template + __attribute__((always_inline)) + constexpr _Tp&& + forward(std::remove_reference_t<_Tp>&& __t) noexcept + { + static_assert(!std::is_lvalue_reference<_Tp>::value, + "enzyme::impl::forward must not be used to convert an rvalue to an lvalue"); + return static_cast<_Tp&&>(__t); + } + +} + +} +}]>; + +def : Headers<"/enzymeroot/enzyme/tuple", [{ +#pragma once + +///////////// +// tuple.h // +///////////// + +// why reinvent the wheel and implement a tuple class? +// - ensure data is laid out in the same order the types are specified +// see: https://github.com/EnzymeAD/Enzyme/issues/1191#issuecomment-1556239213 +// - CUDA compatibility: std::tuple has some compatibility issues when used +// in a __device__ context (this may get better in c++20 with the improved +// constexpr support for std::tuple). Owning the implementation lets +// us add __host__ __device__ annotations to any part of it + +#include // for std::integer_sequence + +#include + +#define _NOEXCEPT noexcept +namespace enzyme { + +template +struct Index {}; + +template +struct value_at_position { + __attribute__((always_inline)) + T & operator[](Index) { return value; } + + __attribute__((always_inline)) + constexpr const T & operator[](Index) const { return value; } + T value; +}; + +template +struct tuple_base; + +template +struct tuple_base, T...> + : public value_at_position... { + using value_at_position::operator[]...; +}; + +template +struct tuple : public tuple_base, T...> {}; + +template +__attribute__((always_inline)) +tuple(T ...) -> tuple; + +template < int i, typename Tuple > +__attribute__((always_inline)) +decltype(auto) get(Tuple && tup) { + constexpr bool is_lvalue = std::is_lvalue_reference_v; + constexpr bool is_const = std::is_const_v>; + using T = remove_cvref_t< decltype(tup[Index{ } ]) >; + if constexpr ( is_lvalue && is_const) { return static_cast(tup[Index{} ]); } + if constexpr ( is_lvalue && !is_const) { return static_cast(tup[Index{} ]); } + if constexpr (!is_lvalue && is_const) { return static_cast(tup[Index{} ]); } + if constexpr (!is_lvalue && !is_const) { return static_cast(tup[Index{} ]); } +} + +template < int i, typename ... T> +__attribute__((always_inline)) +decltype(auto) get(const tuple< T ... > & tup) { + return tup[Index{} ]; +} + +template +struct tuple_size; + +template +struct tuple_size> : std::integral_constant {}; + +template +static constexpr size_t tuple_size_v = tuple_size::value; + +template +__attribute__((always_inline)) +constexpr auto forward_as_tuple(T&&... args) noexcept { + return tuple{impl::forward(args)...}; +} + +namespace impl { + +template +struct make_tuple_from_fwd_tuple; + +template +struct make_tuple_from_fwd_tuple> { + template + __attribute__((always_inline)) + static constexpr auto f(FWD_TUPLE&& fwd) { + return tuple{get(impl::forward(fwd))...}; + } +}; + +template +struct concat_with_fwd_tuple; + +template < typename Tuple > +using iseq = std::make_index_sequence > >; + +template +struct concat_with_fwd_tuple, std::index_sequence> { + template + __attribute__((always_inline)) + static constexpr auto f(FWD_TUPLE&& fwd, TUPLE&& t) { + return forward_as_tuple(get(impl::forward(fwd))..., get(impl::forward(t))...); + } +}; + +template +__attribute__((always_inline)) +static constexpr auto tuple_cat(Tuple&& ret) { + return make_tuple_from_fwd_tuple< iseq< Tuple > >::f(impl::forward< Tuple >(ret)); +} + +template +__attribute__((always_inline)) +static constexpr auto tuple_cat(FWD_TUPLE&& fwd, first&& t, rest&&... ts) { + return tuple_cat(concat_with_fwd_tuple< iseq, iseq >::f(impl::forward(fwd), impl::forward(t)), impl::forward(ts)...); +} + +} // namespace impl + +template +__attribute__((always_inline)) +constexpr auto tuple_cat(Tuples&&... tuples) { + return impl::tuple_cat(impl::forward(tuples)...); +} + +} // namespace enzyme +#undef _NOEXCEPT +}]>; + +def : Headers<"/enzymeroot/enzyme/enzyme", [{ +#ifdef __cplusplus +#include "enzyme/utils" +#else +#warning "Enzyme wrapper templates only available in C++" +#endif +}]>; diff --git a/enzyme/Enzyme/Enzyme.cpp b/enzyme/Enzyme/Enzyme.cpp index 47e9f2bba91f..70f173f5734a 100644 --- a/enzyme/Enzyme/Enzyme.cpp +++ b/enzyme/Enzyme/Enzyme.cpp @@ -435,6 +435,9 @@ std::optional getMetadataName(llvm::Value *res) Optional getMetadataName(llvm::Value *res) #endif { + if (auto S = simplifyLoad(res)) + return getMetadataName(S); + if (auto av = dyn_cast(res)) { return cast(av->getMetadata())->getString(); } else if ((isa(res) || isa(res)) && @@ -463,12 +466,11 @@ Optional getMetadataName(llvm::Value *res) return gv->getName(); } else if (auto gv = dyn_cast(res)) { return gv->getName(); - } else { - if (isa(res)) { - return recursePhiReads(cast(res)); - } - return {}; + } else if (isa(res)) { + return recursePhiReads(cast(res)); } + + return {}; } static Value *adaptReturnedVector(Value *ret, Value *diffret, @@ -3197,6 +3199,7 @@ AnalysisKey EnzymeNewPM::Key; #include "PreserveNVVM.h" #include "TypeAnalysis/TypeAnalysisPrinter.h" #include "llvm/Passes/PassBuilder.h" +#include "llvm/Transforms/IPO/AlwaysInliner.h" #if LLVM_VERSION_MAJOR >= 15 #include "llvm/Transforms/AggressiveInstCombine/AggressiveInstCombine.h" #include "llvm/Transforms/IPO/CalledValuePropagation.h" @@ -3427,6 +3430,7 @@ void augmentPassBuilder(llvm::PassBuilder &PB) { #else prePass(MPM); #endif + MPM.addPass(llvm::AlwaysInlinerPass()); FunctionPassManager OptimizerPM; FunctionPassManager OptimizerPM2; #if LLVM_VERSION_MAJOR >= 16 diff --git a/enzyme/Enzyme/Utils.cpp b/enzyme/Enzyme/Utils.cpp index ff44cbaa715d..283460673922 100644 --- a/enzyme/Enzyme/Utils.cpp +++ b/enzyme/Enzyme/Utils.cpp @@ -2264,7 +2264,258 @@ bool writesToMemoryReadBy(llvm::AAResults &AA, llvm::TargetLibraryInfo &TLI, llvm_unreachable("unknown inst2"); } -Function *GetFunctionFromValue(Value *fn) { +// Find the base pointer of ptr and the offset in bytes from the start of +// the returned base pointer to this value. +AllocaInst *getBaseAndOffset(Value *ptr, size_t &offset) { + offset = 0; + while (true) { + if (auto CI = dyn_cast(ptr)) { + ptr = CI->getOperand(0); + continue; + } + if (auto CI = dyn_cast(ptr)) { + auto &DL = CI->getParent()->getParent()->getParent()->getDataLayout(); + MapVector VariableOffsets; + auto width = sizeof(size_t) * 8; + APInt Offset(width, 0); + bool success = collectOffset(cast(CI), DL, width, + VariableOffsets, Offset); + if (!success || VariableOffsets.size() != 0 || Offset.isNegative()) { + return nullptr; + } + offset += Offset.getZExtValue(); + ptr = CI->getOperand(0); + continue; + } + if (isa(ptr)) { + break; + } + if (auto LI = dyn_cast(ptr)) { + if (auto S = simplifyLoad(LI)) { + ptr = S; + continue; + } + } + return nullptr; + } + return cast(ptr); +} + +// Find all user instructions of AI, returning tuples of Unlike a simple get users, this will recurse through any +// constant gep offsets and casts +SmallVector, 1> +findAllUsersOf(Value *AI) { + SmallVector, 1> todo; + todo.emplace_back(AI, 0); + + SmallVector, 1> users; + while (todo.size()) { + auto pair = todo.pop_back_val(); + Value *ptr = pair.first; + size_t suboff = pair.second; + + for (auto U : ptr->users()) { + if (auto CI = dyn_cast(U)) { + todo.emplace_back(CI, suboff); + continue; + } + if (auto CI = dyn_cast(U)) { + auto &DL = CI->getParent()->getParent()->getParent()->getDataLayout(); + MapVector VariableOffsets; + auto width = sizeof(size_t) * 8; + APInt Offset(width, 0); + bool success = collectOffset(cast(CI), DL, width, + VariableOffsets, Offset); + + if (!success || VariableOffsets.size() != 0 || Offset.isNegative()) { + users.emplace_back(cast(U), ptr, suboff); + continue; + } + todo.emplace_back(CI, suboff + Offset.getZExtValue()); + continue; + } + users.emplace_back(cast(U), ptr, suboff); + continue; + } + } + return users; +} + +// Given a pointer, find all values of size `valSz` which could be loaded from +// that pointer when indexed at offset. If it is impossible to guarantee that +// the set contains all such values, set legal to false +SmallVector getAllLoadedValuesFrom(AllocaInst *ptr0, size_t offset, + size_t valSz, bool &legal) { + SmallVector options; + + auto todo = findAllUsersOf(ptr0); + std::set> seen; + + while (todo.size()) { + auto pair = todo.pop_back_val(); + if (seen.count(pair)) + continue; + seen.insert(pair); + Instruction *U = std::get<0>(pair); + Value *ptr = std::get<1>(pair); + size_t suboff = std::get<2>(pair); + + // Read only users do not set the memory inside of ptr + if (isa(U)) { + continue; + } + if (auto MTI = dyn_cast(U)) + if (MTI->getOperand(0) != ptr) { + continue; + } + if (auto I = dyn_cast(U)) { + if (!I->mayWriteToMemory() && I->getType()->isVoidTy()) + continue; + } + + if (auto SI = dyn_cast(U)) { + auto &DL = SI->getParent()->getParent()->getParent()->getDataLayout(); + + // We are storing into the ptr + if (SI->getPointerOperand() == ptr) { + auto storeSz = + (DL.getTypeStoreSizeInBits(SI->getValueOperand()->getType()) + 7) / + 8; + // If store is before the load would start + if (storeSz + suboff <= offset) + continue; + // if store starts after load would start + if (offset + valSz <= suboff) + continue; + + if (valSz == storeSz) { + options.push_back(SI->getValueOperand()); + continue; + } + } + + // We capture our pointer of interest, if it is stored into an alloca, + // all loads of said alloca would potentially store into. + if (SI->getValueOperand() == ptr) { + if (suboff == 0) { + size_t mid_offset = 0; + if (auto AI2 = + getBaseAndOffset(SI->getPointerOperand(), mid_offset)) { + bool sublegal = true; + auto ptrSz = (DL.getTypeStoreSizeInBits(ptr->getType()) + 7) / 8; + auto subPtrs = + getAllLoadedValuesFrom(AI2, mid_offset, ptrSz, sublegal); + if (!sublegal) { + legal = false; + return options; + } + for (auto subPtr : subPtrs) { + for (const auto &pair3 : findAllUsersOf(subPtr)) { + todo.emplace_back(pair3); + } + } + continue; + } + } + } + } + + if (auto II = dyn_cast(U)) { + if (II->getIntrinsicID() == Intrinsic::lifetime_start || + II->getIntrinsicID() == Intrinsic::lifetime_end) + continue; + } + + // If we copy into the ptr at a location that includes the offset, consider + // all sub uses + if (auto MTI = dyn_cast(U)) { + if (auto CI = dyn_cast(MTI->getLength())) { + if (MTI->getOperand(0) == ptr && suboff == 0 && + CI->getValue().uge(offset + valSz)) { + size_t midoffset = 0; + auto AI2 = getBaseAndOffset(MTI->getOperand(1), midoffset); + if (!AI2) { + legal = false; + return options; + } + if (midoffset != 0) { + legal = false; + return options; + } + for (const auto &pair3 : findAllUsersOf(AI2)) { + todo.emplace_back(pair3); + } + continue; + } + } + } + + legal = false; + return options; + } + + return options; +} + +// Perform mem2reg/sroa to identify the innermost value being represented. +Value *simplifyLoad(Value *V, size_t valSz) { + if (auto LI = dyn_cast(V)) { + if (valSz == 0) { + auto &DL = LI->getParent()->getParent()->getParent()->getDataLayout(); + valSz = (DL.getTypeStoreSizeInBits(LI->getType()) + 7) / 8; + } + + Value *ptr = LI->getPointerOperand(); + size_t offset = 0; + + if (auto ptr2 = simplifyLoad(ptr)) { + ptr = ptr2; + } + auto AI = getBaseAndOffset(ptr, offset); + if (!AI) { + return nullptr; + } + + bool legal = true; + auto opts = getAllLoadedValuesFrom(AI, offset, valSz, legal); + + if (!legal) { + return nullptr; + } + std::set res; + for (auto opt : opts) { + Value *v2 = simplifyLoad(opt, valSz); + if (v2) + res.insert(v2); + else + res.insert(opt); + } + if (res.size() != 1) { + return nullptr; + } + Value *retval = *res.begin(); + return retval; + } + if (auto EVI = dyn_cast(V)) { + bool allZero = true; + for (auto idx : EVI->getIndices()) { + if (idx != 0) + allZero = false; + } + if (valSz == 0) { + auto &DL = EVI->getParent()->getParent()->getParent()->getDataLayout(); + valSz = (DL.getTypeStoreSizeInBits(EVI->getType()) + 7) / 8; + } + if (allZero) + if (auto LI = dyn_cast(EVI->getAggregateOperand())) { + return simplifyLoad(LI, valSz); + } + } + return nullptr; +} + +Value *GetFunctionValFromValue(Value *fn) { while (!isa(fn)) { if (auto ci = dyn_cast(fn)) { fn = ci->getOperand(0); @@ -2294,6 +2545,7 @@ Function *GetFunctionFromValue(Value *fn) { } if (ret.size() == 1) { auto val = *ret.begin(); + val = GetFunctionValFromValue(val); if (isa(val)) { fn = val; continue; @@ -2315,6 +2567,14 @@ Function *GetFunctionFromValue(Value *fn) { } if (ret.size() == 1) { auto val = *ret.begin(); + while (isa(val)) { + auto v2 = simplifyLoad(val); + if (v2) { + val = v2; + continue; + } + break; + } if (isa(val)) { fn = val; continue; @@ -2326,73 +2586,18 @@ Function *GetFunctionFromValue(Value *fn) { } } } - if (auto LI = dyn_cast(fn)) { - auto obj = getBaseObject(LI->getPointerOperand()); - if (isa(obj)) { - std::set> done; - SmallVector, 1> todo; - Value *stored = nullptr; - bool legal = true; - for (auto U : obj->users()) { - if (auto I = dyn_cast(U)) - todo.push_back(std::make_pair(I, obj)); - else { - legal = false; - break; - } - } - while (legal && todo.size()) { - auto tup = todo.pop_back_val(); - if (done.count(tup)) - continue; - done.insert(tup); - auto cur = tup.first; - auto prev = tup.second; - if (auto SI = dyn_cast(cur)) - if (SI->getPointerOperand() == prev) { - if (stored == SI->getValueOperand()) - continue; - else if (stored == nullptr) { - stored = SI->getValueOperand(); - continue; - } else { - legal = false; - break; - } - } - - if (isPointerArithmeticInst(cur, /*includephi*/ true)) { - for (auto U : cur->users()) { - if (auto I = dyn_cast(U)) - todo.push_back(std::make_pair(I, cur)); - else { - legal = false; - break; - } - } - continue; - } - - if (isa(cur)) - continue; - - if (!cur->mayWriteToMemory() && cur->getType()->isVoidTy()) - continue; - - legal = false; - break; - } - - if (legal && stored) { - fn = stored; - continue; - } - } + if (auto S = simplifyLoad(fn)) { + fn = S; + continue; } break; } - return dyn_cast(fn); + return fn; +} + +Function *GetFunctionFromValue(Value *fn) { + return dyn_cast(GetFunctionValFromValue(fn)); } #if LLVM_VERSION_MAJOR >= 16 diff --git a/enzyme/Enzyme/Utils.h b/enzyme/Enzyme/Utils.h index 37edfc1a985d..5a4b3c31cef7 100644 --- a/enzyme/Enzyme/Utils.h +++ b/enzyme/Enzyme/Utils.h @@ -1248,6 +1248,8 @@ void ErrorIfRuntimeInactive(llvm::IRBuilder<> &B, llvm::Value *primal, llvm::Function *GetFunctionFromValue(llvm::Value *fn); +llvm::Value *simplifyLoad(llvm::Value *LI, size_t valSz = 0); + static inline bool shouldDisableNoWrite(const llvm::CallInst *CI) { auto F = getFunctionFromCall(CI); auto funcName = getFuncNameFromCall(CI); diff --git a/enzyme/test/Integration/ReverseMode/sugar.cpp b/enzyme/test/Integration/ReverseMode/sugar.cpp new file mode 100644 index 000000000000..8524342e1bfc --- /dev/null +++ b/enzyme/test/Integration/ReverseMode/sugar.cpp @@ -0,0 +1,91 @@ +// RUN: if [ %llvmver -ge 11 ]; then %clang++ -std=c++17 -O0 %s -S -emit-llvm -o - %loadClangEnzyme | %lli - ; fi +// RUN: if [ %llvmver -ge 11 ]; then %clang++ -std=c++17 -O1 %s -S -emit-llvm -o - %loadClangEnzyme | %lli - ; fi +// RUN: if [ %llvmver -ge 11 ]; then %clang++ -std=c++17 -O2 %s -S -emit-llvm -o - %loadClangEnzyme | %lli - ; fi +// RUN: if [ %llvmver -ge 11 ]; then %clang++ -std=c++17 -O3 %s -S -emit-llvm -o - %loadClangEnzyme | %lli - ; fi +// RUN: if [ %llvmver -ge 12 ]; then %clang++ -std=c++17 -O0 %s -mllvm -print-before-all -mllvm -print-after-all -mllvm -print-module-scope -S -emit-llvm -o - %newLoadClangEnzyme | %lli - ; fi +// RUN: if [ %llvmver -ge 12 ]; then %clang++ -std=c++17 -O1 %s -S -emit-llvm -o - %newLoadClangEnzyme | %lli - ; fi +// RUN: if [ %llvmver -ge 12 ]; then %clang++ -std=c++17 -O2 %s -S -emit-llvm -o - %newLoadClangEnzyme | %lli - ; fi +// RUN: if [ %llvmver -ge 12 ]; then %clang++ -std=c++17 -O3 %s -S -emit-llvm -o - %newLoadClangEnzyme | %lli - ; fi + +#include "../test_utils.h" + +#include + +double foo(double x, double y) { return x * y; } + +double square(double x) { return x * x; } + +struct pair { + double x; + double y; +}; + +int main() { + + { + enzyme::Active x1{3.1}; + enzyme::tuple< enzyme::tuple > dsq = enzyme::autodiff>(square, x1); + double dd = enzyme::get<0>(enzyme::get<0>(dsq)); + printf("dsq = %f\n", dd); + APPROX_EQ(dd, 3.1*2, 1e-10); + } + + { + enzyme::Active x1{3.1}; + enzyme::tuple< enzyme::tuple > dsq = enzyme::autodiff(square, x1); + double dd = enzyme::get<0>(enzyme::get<0>(dsq)); + printf("dsq2 = %f\n", dd); + APPROX_EQ(dd, 3.1*2, 1e-10); + } + + { + enzyme::Active x1{3.1}; + enzyme::tuple< enzyme::tuple, double > dsq = enzyme::autodiff>(square, x1); + double dd = enzyme::get<0>(enzyme::get<0>(dsq)); + printf("dsq3 = %f\n", dd); + APPROX_EQ(dd, 3.1*2, 1e-10); + double prim = enzyme::get<1>(dsq); + printf("dsq3_prim = %f\n", prim); + APPROX_EQ(prim, 3.1*3.1, 1e-10); + } + + { + enzyme::Active x1{3.1}; + enzyme::tuple< enzyme::tuple, double > dsq = enzyme::autodiff(square, x1); + double dd = enzyme::get<0>(enzyme::get<0>(dsq)); + printf("dsq4 = %f\n", dd); + APPROX_EQ(dd, 3.1*2, 1e-10); + double prim = enzyme::get<1>(dsq); + printf("dsq4_prim = %f\n", prim); + APPROX_EQ(prim, 3.1*3.1, 1e-10); + } + + { + auto y = enzyme::autodiff(foo, enzyme::Active(3.1), enzyme::Active(2.7)); + auto y1 = enzyme::get<0>(enzyme::get<0>(y)); + auto y2 = enzyme::get<1>(enzyme::get<0>(y)); + printf("dmul %f %f\n", y1, y2); + APPROX_EQ(y1, 2.7, 1e-10); + APPROX_EQ(y2, 3.1, 1e-10); + } + + { + auto y = enzyme::autodiff(foo, enzyme::Active(3.1), enzyme::Active(2.7)); + auto y1 = enzyme::get<0>(enzyme::get<0>(y)); + auto y2 = enzyme::get<1>(enzyme::get<0>(y)); + auto prim = enzyme::get<1>(y); + printf("dmul2 %f %f\n", y1, y2); + printf("dmul_prim %f\n", prim); + APPROX_EQ(y1, 2.7, 1e-10); + APPROX_EQ(y2, 3.1, 1e-10); + APPROX_EQ(prim, 2.7*3.1, 1e-10); + } + + { + auto &&[z1, z2] = __enzyme_autodiff((void*)foo, enzyme_out, 3.1, enzyme_out, 2.7); + printf("dmul2 %f %f\n", z1, z2); + APPROX_EQ(z1, 2.7, 1e-10); + APPROX_EQ(z2, 3.1, 1e-10); + } + +} diff --git a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp index 04f63e1ad3ac..143c85ea684e 100644 --- a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp +++ b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp @@ -65,7 +65,9 @@ static cl::opt cl::values(clEnumValN(MLIRDerivatives, "gen-mlir-derivatives", "Generate MLIR derivative")), cl::values(clEnumValN(CallDerivatives, "gen-call-derivatives", - "Generate call derivative"))); + "Generate call derivative")), + cl::values(clEnumValN(GenHeaderVariables, "gen-header-strings", + "Generate header strings"))); void getFunction(const Twine &curIndent, raw_ostream &os, StringRef callval, StringRef FT, StringRef cconv, Init *func, @@ -1248,6 +1250,24 @@ void printDiffUse( } } +static void emitHeaderIncludes(const RecordKeeper &recordKeeper, + raw_ostream &os) { + const auto &patterns = recordKeeper.getAllDerivedDefinitions("Headers"); + os << "const char* include_headers[][2] = {\n"; + bool seen = false; + for (Record *pattern : patterns) { + if (seen) + os << ",\n"; + auto filename = pattern->getValueAsString("filename"); + auto contents = pattern->getValueAsString("contents"); + os << "{\"" << filename << "\"\n,"; + os << "R\"(" << contents << ")\"\n"; + os << "}"; + seen = true; + } + os << "};\n"; +} + static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os, ActionType intrinsic) { emitSourceFileHeader("Rewriters", os); @@ -1268,6 +1288,7 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os, case BinopDerivatives: patternNames = "BinopPattern"; break; + case GenHeaderVariables: case GenBlasDerivatives: case UpdateBlasDecl: case UpdateBlasTA: @@ -1299,6 +1320,7 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os, case UpdateBlasDecl: case UpdateBlasTA: case GenBlasDiffUse: + case GenHeaderVariables: llvm_unreachable("Cannot use blas updaters inside emitDerivatives"); case MLIRDerivatives: { auto opName = pattern->getValueAsString("opName"); @@ -2089,6 +2111,7 @@ void emitDiffUse(const RecordKeeper &recordKeeper, raw_ostream &os, case UpdateBlasDecl: case UpdateBlasTA: case GenBlasDiffUse: + case GenHeaderVariables: llvm_unreachable("Cannot use blas updaters inside emitDiffUse"); case CallDerivatives: patternNames = "CallPattern"; @@ -2127,6 +2150,7 @@ void emitDiffUse(const RecordKeeper &recordKeeper, raw_ostream &os, case UpdateBlasDecl: case UpdateBlasTA: case GenBlasDiffUse: + case GenHeaderVariables: llvm_unreachable("Cannot use blas updaters inside emitDerivatives"); case CallDerivatives: { os << " if (("; @@ -2283,6 +2307,9 @@ static bool EnzymeTableGenMain(raw_ostream &os, RecordKeeper &records) { case UpdateBlasTA: emitBlasTAUpdater(records, os); return false; + case GenHeaderVariables: + emitHeaderIncludes(records, os); + return false; default: errs() << "unknown tablegen action!\n"; diff --git a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.h b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.h index 368644ba0b5d..742a96d023ae 100644 --- a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.h +++ b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.h @@ -24,6 +24,7 @@ enum ActionType { UpdateBlasDecl, UpdateBlasTA, GenBlasDiffUse, + GenHeaderVariables, }; void emitDiffUse(const llvm::RecordKeeper &recordKeeper, llvm::raw_ostream &os, From f36e5a15909ef3fcede6372c41161d184e3b6162 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 27 Feb 2024 11:45:59 +0000 Subject: [PATCH 097/106] Bump peter-evans/create-pull-request from 3 to 6 (#1760) Bumps [peter-evans/create-pull-request](https://github.com/peter-evans/create-pull-request) from 3 to 6. - [Release notes](https://github.com/peter-evans/create-pull-request/releases) - [Commits](https://github.com/peter-evans/create-pull-request/compare/v3...v6) --- updated-dependencies: - dependency-name: peter-evans/create-pull-request dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Tim Gymnich --- .github/workflows/tagger.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/tagger.yml b/.github/workflows/tagger.yml index e4a015c4145a..46f8874e92e3 100644 --- a/.github/workflows/tagger.yml +++ b/.github/workflows/tagger.yml @@ -37,7 +37,7 @@ jobs: git add . - name: Create Pull Request id: cpr - uses: peter-evans/create-pull-request@v3 + uses: peter-evans/create-pull-request@v6 with: path: ygg commit-message: "Upgrade enzyme to ${{ github.ref }}" From c38423bacd84bdc1684bad9b9b4f100294e28364 Mon Sep 17 00:00:00 2001 From: "Oleksandr \"Alex\" Zinenko" Date: Tue, 27 Feb 2024 15:23:06 +0100 Subject: [PATCH 098/106] Update MLIR to 2c9b6c1b36b8185299de083c3058e0c1e7760442 (#1765) Remove most uses of `ConversionPatternRewriter` that were spurious anyway. Update one use to use `IRRewriter` instead since the conversion rewriter should no longer be used directly. Clean up includes accordingly. --- .github/workflows/enzyme-mlir.yml | 2 +- .../Implementations/LinalgAutoDiffOpInterfaceImpl.cpp | 4 +--- enzyme/Enzyme/MLIR/Passes/AddToOpToIndexAndLoad.cpp | 9 --------- enzyme/Enzyme/MLIR/Passes/AddToOpToSplit.cpp | 3 --- enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp | 11 ++--------- enzyme/Enzyme/MLIR/Passes/ShadowedGradientToCache.cpp | 10 ---------- 6 files changed, 4 insertions(+), 35 deletions(-) diff --git a/.github/workflows/enzyme-mlir.yml b/.github/workflows/enzyme-mlir.yml index aaced6f4e694..34a7828af856 100644 --- a/.github/workflows/enzyme-mlir.yml +++ b/.github/workflows/enzyme-mlir.yml @@ -36,7 +36,7 @@ jobs: - uses: actions/checkout@v4 with: repository: 'llvm/llvm-project' - ref: 'bc82cfb38d83f1afeb2c290aa472c2e2e88919cb' + ref: '2c9b6c1b36b8185299de083c3058e0c1e7760442' path: 'llvm-project' - name: Get MLIR commit hash diff --git a/enzyme/Enzyme/MLIR/Implementations/LinalgAutoDiffOpInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/LinalgAutoDiffOpInterfaceImpl.cpp index 334d7325d7ec..8ecb851af90b 100644 --- a/enzyme/Enzyme/MLIR/Implementations/LinalgAutoDiffOpInterfaceImpl.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/LinalgAutoDiffOpInterfaceImpl.cpp @@ -79,9 +79,7 @@ struct GenericOpInterfaceReverse cast(gutils->getNewFromOriginal(linalgOp)); // Replace the op by a linalg.generic op if necessary - // TODO : IRRewriter rewriter(builder.getContext()/*, - // builder.getListener()*/); - ConversionPatternRewriter rewriter(builder.getContext()); + IRRewriter rewriter(builder.getContext(), builder.getListener()); auto failiureOrLinalgOp = generalizeNamedOp(rewriter, newOp); if (!failed(failiureOrLinalgOp)) { linalg::GenericOp replacement = failiureOrLinalgOp.value(); diff --git a/enzyme/Enzyme/MLIR/Passes/AddToOpToIndexAndLoad.cpp b/enzyme/Enzyme/MLIR/Passes/AddToOpToIndexAndLoad.cpp index 448dd67b8dd3..5baf2f982d9d 100644 --- a/enzyme/Enzyme/MLIR/Passes/AddToOpToIndexAndLoad.cpp +++ b/enzyme/Enzyme/MLIR/Passes/AddToOpToIndexAndLoad.cpp @@ -11,18 +11,12 @@ // procedure to the MemRef dialect. //===----------------------------------------------------------------------===// -#include "Dialect/Dialect.h" #include "Dialect/Ops.h" #include "PassDetails.h" #include "Passes/Passes.h" -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Transforms/DialectConversion.h" -#include "mlir/Rewrite/PatternApplicator.h" #include "llvm/Support/raw_ostream.h" #include "Interfaces/AutoDiffTypeInterface.h" @@ -51,9 +45,6 @@ SmallVector applyAffineMap(AffineMap aMap, SmallVector indices, struct AddToOpToIndexAndLoadPass : public enzyme::AddToOpToIndexAndLoadPassBase { void runOnOperation() override { - MLIRContext *context = &getContext(); - ConversionPatternRewriter rewriter(context); - getOperation()->walk([&](Operation *op) { auto loc = op->getLoc(); auto enzymeAdjoint = dyn_cast(op); diff --git a/enzyme/Enzyme/MLIR/Passes/AddToOpToSplit.cpp b/enzyme/Enzyme/MLIR/Passes/AddToOpToSplit.cpp index de2ebeba376d..01bcd6683a96 100644 --- a/enzyme/Enzyme/MLIR/Passes/AddToOpToSplit.cpp +++ b/enzyme/Enzyme/MLIR/Passes/AddToOpToSplit.cpp @@ -105,9 +105,6 @@ void processGenericDuplication(Operation *op, OpBuilder &builder, Location loc, struct AddToOpToSplitPass : public enzyme::AddToOpToSplitPassBase { void runOnOperation() override { - MLIRContext *context = &getContext(); - ConversionPatternRewriter rewriter(context); - getOperation()->walk([&](Operation *op) { auto enzymeAdjoint = dyn_cast(op); auto loc = op->getLoc(); diff --git a/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp b/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp index 1e2c55640280..e43db26b21a2 100644 --- a/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp +++ b/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp @@ -11,16 +11,13 @@ //===----------------------------------------------------------------------===// #include "Dialect/Ops.h" -#include "Interfaces/GradientUtils.h" #include "Interfaces/GradientUtilsReverse.h" #include "PassDetails.h" #include "Passes/Passes.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/IRMapping.h" -#include "mlir/Interfaces/ControlFlowInterfaces.h" -#include "mlir/Interfaces/FunctionInterfaces.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/Builders.h" +#include "mlir/Interfaces/FunctionInterfaces.h" #define DEBUG_TYPE "enzyme" @@ -221,13 +218,9 @@ std::unique_ptr createDifferentiatePass() { } // namespace enzyme } // namespace mlir -#include "mlir/IR/PatternMatch.h" -#include "mlir/Transforms/DialectConversion.h" - void DifferentiatePass::runOnOperation() { SymbolTableCollection symbolTable; symbolTable.getSymbolTable(getOperation()); - ConversionPatternRewriter B(getOperation()->getContext()); getOperation()->walk( [&](FunctionOpInterface op) { lowerEnzymeCalls(symbolTable, op); }); } diff --git a/enzyme/Enzyme/MLIR/Passes/ShadowedGradientToCache.cpp b/enzyme/Enzyme/MLIR/Passes/ShadowedGradientToCache.cpp index b8a42f339280..d4e374d6c68f 100644 --- a/enzyme/Enzyme/MLIR/Passes/ShadowedGradientToCache.cpp +++ b/enzyme/Enzyme/MLIR/Passes/ShadowedGradientToCache.cpp @@ -11,19 +11,12 @@ // procedure to the MemRef dialect. //===----------------------------------------------------------------------===// -#include "Dialect/Dialect.h" #include "Dialect/Ops.h" #include "PassDetails.h" #include "Passes/Passes.h" -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Transforms/DialectConversion.h" -#include "mlir/Rewrite/PatternApplicator.h" - #include "llvm/Support/raw_ostream.h" using namespace mlir; @@ -34,9 +27,6 @@ struct ShadowedGradientToCachePass : public enzyme::ShadowedGradientToCachePassBase< ShadowedGradientToCachePass> { void runOnOperation() override { - MLIRContext *context = &getContext(); - ConversionPatternRewriter rewriter(context); - getOperation()->walk([&](Operation *op) { if (auto initOp = dyn_cast(op)) { if (auto type = From 21ead64a8879163f4a14c67a32e9f655c8eb99b7 Mon Sep 17 00:00:00 2001 From: "Oleksandr \"Alex\" Zinenko" Date: Tue, 27 Feb 2024 16:16:37 +0100 Subject: [PATCH 099/106] Fix clang-tidy findings (#1766) Includes a real memory problem: ValueRange is a non-owning container, assining the returned value of type SmallVector to ValueRange will lead to ValueRange starting with a dangling pointer. --- enzyme/Enzyme/MLIR/Passes/AddToOpToIndexAndLoad.cpp | 4 ++-- enzyme/Enzyme/MLIR/Passes/Passes.h | 7 +++++-- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/enzyme/Enzyme/MLIR/Passes/AddToOpToIndexAndLoad.cpp b/enzyme/Enzyme/MLIR/Passes/AddToOpToIndexAndLoad.cpp index 5baf2f982d9d..648509f58e52 100644 --- a/enzyme/Enzyme/MLIR/Passes/AddToOpToIndexAndLoad.cpp +++ b/enzyme/Enzyme/MLIR/Passes/AddToOpToIndexAndLoad.cpp @@ -85,7 +85,7 @@ struct AddToOpToIndexAndLoadPass // auto load = cacheBuilder.create(loc, inputs[i], map[i], // indices); auto store = cacheBuilder.create(loc, load, // inputs[i], map[i], indices); - ValueRange mapAppliedIndices = + SmallVector mapAppliedIndices = applyAffineMap(map[num_ins + i], indices, cacheBuilder, loc); auto load = cacheBuilder.create(loc, outs[i], mapAppliedIndices); @@ -96,7 +96,7 @@ struct AddToOpToIndexAndLoadPass } for (int i = 0; i < retargs.size(); i++) { - ValueRange mapAppliedIndices = + SmallVector mapAppliedIndices = applyAffineMap(map[num_ins + i], indices, cacheBuilder, loc); auto load = cacheBuilder.create(loc, outs[i], mapAppliedIndices); diff --git a/enzyme/Enzyme/MLIR/Passes/Passes.h b/enzyme/Enzyme/MLIR/Passes/Passes.h index 25362d01294a..80c88373090c 100644 --- a/enzyme/Enzyme/MLIR/Passes/Passes.h +++ b/enzyme/Enzyme/MLIR/Passes/Passes.h @@ -65,12 +65,15 @@ class MemRefDialect; namespace func { class FuncDialect; -} +} // end namespace func +namespace affine { class AffineDialect; +} // end namespace affine + namespace LLVM { class LLVMDialect; -} +} // end namespace LLVM #define GEN_PASS_REGISTRATION #include "Passes/Passes.h.inc" From 8d75756c0792ffabf07b239440bc5c1fa48fb799 Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 27 Feb 2024 16:53:36 -0500 Subject: [PATCH 100/106] Correct complex inv (#1767) --- enzyme/Enzyme/InstructionDerivatives.td | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/enzyme/Enzyme/InstructionDerivatives.td b/enzyme/Enzyme/InstructionDerivatives.td index 32b30bd8f9b4..6491fa032798 100644 --- a/enzyme/Enzyme/InstructionDerivatives.td +++ b/enzyme/Enzyme/InstructionDerivatives.td @@ -162,12 +162,20 @@ def CFNeg : SubRoutine<(Op (Op $re, $im):$z), (FNeg $re), (FNeg $im) )>; + +def Conj : SubRoutine<(Op (Op $re, $im):$z), + (ArrayRet + $re, + (FNeg $im) + )>; + def CFExp : SubRoutine<(Op (Op $re, $im):$z), (ArrayRet (FMul (FExp $re):$exp, (FCos $im)), (FMul $exp, (FSin $im)) )>; + // Same function as the one being called def SameFunc { } @@ -826,9 +834,10 @@ def : CallPattern<(Op (Op $x, $y):$z), def : CallPattern<(Op (Op $x, $y):$z), ["cmplx_inv"], [ - (CFDiv (CFNeg (DiffeRet)), (CFMul $z, $z)), + // Reverse mode needs to return the conjugate + (Conj (CFDiv (CFNeg (Conj (DiffeRet))), (CFMul $z, $z))), ], - (ForwardFromSummedReverse), + (CFDiv (CFNeg (Shadow $z)), (CFMul $z, $z)), [ReadNone, NoUnwind] >; From f7a46fd41562e13a622613f63b12fa67418213bf Mon Sep 17 00:00:00 2001 From: "Oleksandr \"Alex\" Zinenko" Date: Tue, 27 Feb 2024 23:46:11 +0100 Subject: [PATCH 101/106] [mlir] move alias lattice update to relevant transfer function (#1573) --- enzyme/Enzyme/MLIR/Analysis/AliasAnalysis.cpp | 21 ++++++++++++------- .../MLIR/AliasAnalysis/func_attributes.mlir | 2 +- 2 files changed, 14 insertions(+), 9 deletions(-) diff --git a/enzyme/Enzyme/MLIR/Analysis/AliasAnalysis.cpp b/enzyme/Enzyme/MLIR/Analysis/AliasAnalysis.cpp index 767a60a3ed2b..9af864e592d6 100644 --- a/enzyme/Enzyme/MLIR/Analysis/AliasAnalysis.cpp +++ b/enzyme/Enzyme/MLIR/Analysis/AliasAnalysis.cpp @@ -366,6 +366,19 @@ void enzyme::PointsToPointerAnalysis::visitOperation(Operation *op, SmallVector effects; memory.getEffects(effects); + + // If the operation allocates fresh memory and doesn't write into it, that + // memory is known not to point to any known alias class. + if (effects.size() == 1 && + isa(effects.front().getEffect()) && + effects.front().getValue()) { + const auto *destClasses = + getOrCreateFor(op, effects.front().getValue()); + propagateIfChanged( + after, after->setPointingToEmpty(destClasses->getAliasClassesObject())); + return; + } + for (const auto &effect : effects) { if (!isa(effect.getEffect())) continue; @@ -860,14 +873,6 @@ void enzyme::AliasAnalysis::transfer( result->getPoint(), originalClasses.getOriginalClass(result->getPoint(), debugLabel)); propagateIfChanged(result, result->join(fresh)); - - // The pointer to freshly allocated memory is known not to point to - // anything. - // TODO(zinenko): this is a bit strange to update _another_ lattice - // here. - auto *pointsTo = getOrCreate(op); - propagateIfChanged(pointsTo, pointsTo->setPointingToEmpty( - fresh.getAliasClassesObject())); } } } else if (isa(effect.getEffect())) { diff --git a/enzyme/test/MLIR/AliasAnalysis/func_attributes.mlir b/enzyme/test/MLIR/AliasAnalysis/func_attributes.mlir index 87c7e4105a2b..e1f4b74a21f4 100644 --- a/enzyme/test/MLIR/AliasAnalysis/func_attributes.mlir +++ b/enzyme/test/MLIR/AliasAnalysis/func_attributes.mlir @@ -115,7 +115,7 @@ func.func private @callee(%ptr : !llvm.ptr {llvm.readonly}) attributes { } // CHECK: points-to-pointer sets -// CHECK-NEXT: +// CHECK: // CHECK-LABEL @call_other_none_arg_rw_readonly func.func @call_other_none_arg_rw_readonly(%input: !llvm.ptr {enzyme.tag = "input"}) { call @callee(%input) : (!llvm.ptr) -> () From 070601e80f26a2b30fb7742d07fb8ed91785f6ff Mon Sep 17 00:00:00 2001 From: William Moses Date: Wed, 28 Feb 2024 10:28:15 -0500 Subject: [PATCH 102/106] Fix reverse mode complex error function (#1770) --- enzyme/Enzyme/InstructionDerivatives.td | 12 ++++++------ enzyme/test/Enzyme/ReverseMode/Faddeeva_erf.ll | 12 +++++++----- enzyme/test/Enzyme/ReverseMode/Faddeeva_erfc.ll | 12 +++++++----- enzyme/test/Enzyme/ReverseMode/Faddeeva_erfi.ll | 12 +++++++----- 4 files changed, 27 insertions(+), 21 deletions(-) diff --git a/enzyme/Enzyme/InstructionDerivatives.td b/enzyme/Enzyme/InstructionDerivatives.td index 6491fa032798..e309055c34d2 100644 --- a/enzyme/Enzyme/InstructionDerivatives.td +++ b/enzyme/Enzyme/InstructionDerivatives.td @@ -655,30 +655,30 @@ def ToStruct2 : SubRoutine<(Op (Op $re, $im):$z), def : CallPattern<(Op $x, $tbd), ["Faddeeva_erf"], [ - (ToStruct2 (CFMul (DiffeRet), (CFMul (ConstantCFP<"1.1283791670955125738961589031215451716881012586580","0"> $x), (CFExp (CFNeg (CFMul $x, $x)))))), + (ToStruct2 (Conj (CFMul (Conj (DiffeRet)), (CFMul (ConstantCFP<"1.1283791670955125738961589031215451716881012586580","0"> $x), (CFExp (CFNeg (CFMul $x, $x))))))), (InactiveArg) // relerr ], - (ForwardFromSummedReverse), + (ToStruct2 (CFMul (Shadow $x), (CFMul (ConstantCFP<"1.1283791670955125738961589031215451716881012586580","0"> $x), (CFExp (CFNeg (CFMul $x, $x)))))), [ReadNone, NoUnwind] >; def : CallPattern<(Op $x, $tbd), ["Faddeeva_erfi"], [ - (ToStruct2 (CFMul (DiffeRet), (CFMul (ConstantCFP<"1.1283791670955125738961589031215451716881012586580","0"> $x), (CFExp (CFMul $x, $x))))), + (ToStruct2 (Conj (CFMul (Conj (DiffeRet)), (CFMul (ConstantCFP<"1.1283791670955125738961589031215451716881012586580","0"> $x), (CFExp (CFMul $x, $x)))))), (InactiveArg) // relerr ], - (ForwardFromSummedReverse), + (ToStruct2 (CFMul (Shadow $x), (CFMul (ConstantCFP<"1.1283791670955125738961589031215451716881012586580","0"> $x), (CFExp (CFMul $x, $x))))), [ReadNone, NoUnwind] >; def : CallPattern<(Op $x, $tbd), ["Faddeeva_erfc"], [ - (ToStruct2 (CFMul (DiffeRet), (CFMul (ConstantCFP<"-1.1283791670955125738961589031215451716881012586580","0"> $x), (CFExp (CFNeg (CFMul $x, $x)))))), + (ToStruct2 (Conj (CFMul (Conj (DiffeRet)), (CFMul (ConstantCFP<"-1.1283791670955125738961589031215451716881012586580","0"> $x), (CFExp (CFNeg (CFMul $x, $x))))))), (InactiveArg) // relerr ], - (ForwardFromSummedReverse), + (ToStruct2 (CFMul (Shadow $x), (CFMul (ConstantCFP<"-1.1283791670955125738961589031215451716881012586580","0"> $x), (CFExp (CFNeg (CFMul $x, $x)))))), [ReadNone, NoUnwind] >; diff --git a/enzyme/test/Enzyme/ReverseMode/Faddeeva_erf.ll b/enzyme/test/Enzyme/ReverseMode/Faddeeva_erf.ll index 4722ef548bd8..716cba079e61 100644 --- a/enzyme/test/Enzyme/ReverseMode/Faddeeva_erf.ll +++ b/enzyme/test/Enzyme/ReverseMode/Faddeeva_erf.ll @@ -20,6 +20,9 @@ declare { double, double } @__enzyme_autodiff({ double, double } ({ double, doub ; CHECK: define internal { { double, double } } @diffetester({ double, double } %in, { double, double } %differeturn) ; CHECK-NEXT: entry: +; CHECK-NEXT: %[[a16:.+]] = extractvalue { double, double } %differeturn, 0 +; CHECK-NEXT: %[[a17:.+]] = extractvalue { double, double } %differeturn, 1 +; CHECK-NEXT: %[[conj:.+]] = {{(fsub fast double \-0.000000e\+00,|fneg fast double)}} %[[a17]] ; CHECK-NEXT: %[[a0:.+]] = extractvalue { double, double } %in, 0 ; CHECK-NEXT: %[[a1:.+]] = extractvalue { double, double } %in, 1 ; CHECK-DAG: %[[a2:.+]] = fmul fast double %[[a1]], %[[a1]] @@ -36,16 +39,15 @@ declare { double, double } @__enzyme_autodiff({ double, double } ({ double, doub ; CHECK-NEXT: %[[a13:.+]] = fmul fast double %[[a9]], %[[a12]] ; CHECK-NEXT: %[[a14:.+]] = fmul fast double 0x3FF20DD750429B6D, %[[a11]] ; CHECK-NEXT: %[[a15:.+]] = fmul fast double 0x3FF20DD750429B6D, %[[a13]] -; CHECK-NEXT: %[[a16:.+]] = extractvalue { double, double } %differeturn, 0 -; CHECK-NEXT: %[[a17:.+]] = extractvalue { double, double } %differeturn, 1 ; CHECK-DAG: %[[a19:.+]] = fmul fast double %[[a16]], %[[a14]] -; CHECK-DAG: %[[a18:.+]] = fmul fast double %[[a17]], %[[a15]] +; CHECK-DAG: %[[a18:.+]] = fmul fast double %[[conj]], %[[a15]] ; CHECK-NEXT: %[[a20:.+]] = fsub fast double %[[a19]], %[[a18]] ; CHECK-DAG: %[[a22:.+]] = fmul fast double %[[a16]], %[[a15]] -; CHECK-DAG: %[[a21:.+]] = fmul fast double %[[a14]], %[[a17]] +; CHECK-DAG: %[[a21:.+]] = fmul fast double %[[a14]], %[[conj]] ; CHECK-NEXT: %[[a23:.+]] = fadd fast double %[[a22]], %[[a21]] +; CHECK-NEXT: %[[conj2:.+]] = {{(fsub fast double \-0.000000e\+00,|fneg fast double)}} %[[a23]] ; CHECK-NEXT: %[[insert5:.+]] = insertvalue { double, double } {{(undef|poison)}}, double %[[a20]], 0 -; CHECK-NEXT: %[[insert8:.+]] = insertvalue { double, double } %[[insert5]], double %[[a23]], 1 +; CHECK-NEXT: %[[insert8:.+]] = insertvalue { double, double } %[[insert5]], double %[[conj2]], 1 ; CHECK-NEXT: %[[a24:.+]] = insertvalue { { double, double } } undef, { double, double } %[[insert8]], 0 ; CHECK-NEXT: ret { { double, double } } %[[a24]] ; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ReverseMode/Faddeeva_erfc.ll b/enzyme/test/Enzyme/ReverseMode/Faddeeva_erfc.ll index 04c8fee4de92..dabd2674b7ba 100644 --- a/enzyme/test/Enzyme/ReverseMode/Faddeeva_erfc.ll +++ b/enzyme/test/Enzyme/ReverseMode/Faddeeva_erfc.ll @@ -20,6 +20,9 @@ declare { double, double } @__enzyme_autodiff({ double, double } ({ double, doub ; CHECK: define internal { { double, double } } @diffetester({ double, double } %in, { double, double } %differeturn) ; CHECK-NEXT: entry: +; CHECK-NEXT: %[[a16:.+]] = extractvalue { double, double } %differeturn, 0 +; CHECK-NEXT: %[[a17:.+]] = extractvalue { double, double } %differeturn, 1 +; CHECK-NEXT: %[[conj:.+]] = {{(fsub fast double \-0.000000e\+00,|fneg fast double)}} %[[a17]] ; CHECK-NEXT: %[[a0:.+]] = extractvalue { double, double } %in, 0 ; CHECK-NEXT: %[[a1:.+]] = extractvalue { double, double } %in, 1 ; CHECK-DAG: %[[a2:.+]] = fmul fast double %[[a1]], %[[a1]] @@ -36,16 +39,15 @@ declare { double, double } @__enzyme_autodiff({ double, double } ({ double, doub ; CHECK-NEXT: %[[a13:.+]] = fmul fast double %[[a9]], %[[a12]] ; CHECK-NEXT: %[[a14:.+]] = fmul fast double 0xBFF20DD750429B6D, %[[a11]] ; CHECK-NEXT: %[[a15:.+]] = fmul fast double 0xBFF20DD750429B6D, %[[a13]] -; CHECK-NEXT: %[[a16:.+]] = extractvalue { double, double } %differeturn, 0 -; CHECK-NEXT: %[[a17:.+]] = extractvalue { double, double } %differeturn, 1 ; CHECK-DAG: %[[a19:.+]] = fmul fast double %[[a16]], %[[a14]] -; CHECK-DAG: %[[a18:.+]] = fmul fast double %[[a17]], %[[a15]] +; CHECK-DAG: %[[a18:.+]] = fmul fast double %[[conj]], %[[a15]] ; CHECK-NEXT: %[[a20:.+]] = fsub fast double %[[a19]], %[[a18]] ; CHECK-DAG: %[[a22:.+]] = fmul fast double %[[a16]], %[[a15]] -; CHECK-DAG: %[[a21:.+]] = fmul fast double %[[a14]], %[[a17]] +; CHECK-DAG: %[[a21:.+]] = fmul fast double %[[a14]], %[[conj]] ; CHECK-NEXT: %[[a23:.+]] = fadd fast double %[[a22]], %[[a21]] +; CHECK-NEXT: %[[conj2:.+]] = {{(fsub fast double \-0.000000e\+00,|fneg fast double)}} %[[a23]] ; CHECK-NEXT: %[[insert5:.+]] = insertvalue { double, double } {{(undef|poison)}}, double %[[a20]], 0 -; CHECK-NEXT: %[[insert8:.+]] = insertvalue { double, double } %[[insert5]], double %[[a23]], 1 +; CHECK-NEXT: %[[insert8:.+]] = insertvalue { double, double } %[[insert5]], double %[[conj2]], 1 ; CHECK-NEXT: %[[a24:.+]] = insertvalue { { double, double } } undef, { double, double } %[[insert8]], 0 ; CHECK-NEXT: ret { { double, double } } %[[a24]] ; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ReverseMode/Faddeeva_erfi.ll b/enzyme/test/Enzyme/ReverseMode/Faddeeva_erfi.ll index 40932ffb5cf2..c5e5cfba0a90 100644 --- a/enzyme/test/Enzyme/ReverseMode/Faddeeva_erfi.ll +++ b/enzyme/test/Enzyme/ReverseMode/Faddeeva_erfi.ll @@ -20,6 +20,9 @@ declare { double, double } @__enzyme_autodiff({ double, double } ({ double, doub ; CHECK: define internal { { double, double } } @diffetester({ double, double } %in, { double, double } %differeturn) ; CHECK-NEXT: entry: +; CHECK-NEXT: %[[i16:.+]] = extractvalue { double, double } %differeturn, 0 +; CHECK-NEXT: %[[i17:.+]] = extractvalue { double, double } %differeturn, 1 +; CHECK-NEXT: %[[conj:.+]] = {{(fsub fast double \-0.000000e\+00,|fneg fast double)}} %[[i17]] ; CHECK-NEXT: %[[a0:.+]] = extractvalue { double, double } %in, 0 ; CHECK-NEXT: %[[a1:.+]] = extractvalue { double, double } %in, 1 ; CHECK-NEXT: %[[a3:.+]] = fmul fast double %[[a0]], %[[a0]] @@ -34,16 +37,15 @@ declare { double, double } @__enzyme_autodiff({ double, double } ({ double, doub ; CHECK-NEXT: %[[i13:.+]] = fmul fast double %[[i9]], %[[i12]] ; CHECK-NEXT: %[[i14:.+]] = fmul fast double 0x3FF20DD750429B6D, %[[i11]] ; CHECK-NEXT: %[[i15:.+]] = fmul fast double 0x3FF20DD750429B6D, %[[i13]] -; CHECK-NEXT: %[[i16:.+]] = extractvalue { double, double } %differeturn, 0 -; CHECK-NEXT: %[[i17:.+]] = extractvalue { double, double } %differeturn, 1 ; CHECK-NEXT: %[[i19:.+]] = fmul fast double %[[i16]], %[[i14]] -; CHECK-NEXT: %[[i18:.+]] = fmul fast double %[[i17]], %[[i15]] +; CHECK-NEXT: %[[i18:.+]] = fmul fast double %[[conj]], %[[i15]] ; CHECK-NEXT: %[[i20:.+]] = fsub fast double %[[i19]], %[[i18]] ; CHECK-NEXT: %[[i22:.+]] = fmul fast double %[[i16]], %[[i15]] -; CHECK-NEXT: %[[i21:.+]] = fmul fast double %[[i14]], %[[i17]] +; CHECK-NEXT: %[[i21:.+]] = fmul fast double %[[i14]], %[[conj]] ; CHECK-NEXT: %[[i23:.+]] = fadd fast double %[[i22]], %[[i21]] +; CHECK-NEXT: %[[conj2:.+]] = {{(fsub fast double \-0.000000e\+00,|fneg fast double)}} %[[i23]] ; CHECK-NEXT: %[[insert5:.+]] = insertvalue { double, double } {{(undef|poison)}}, double %[[i20]], 0 -; CHECK-NEXT: %[[insert8:.+]] = insertvalue { double, double } %[[insert5]], double %[[i23]], 1 +; CHECK-NEXT: %[[insert8:.+]] = insertvalue { double, double } %[[insert5]], double %[[conj2]], 1 ; CHECK-NEXT: %[[i24:.+]] = insertvalue { { double, double } } undef, { double, double } %[[insert8]], 0 ; CHECK-NEXT: ret { { double, double } } %[[i24]] ; CHECK-NEXT: } From b97aa9d5f0a4264308d03901c81e733c13042106 Mon Sep 17 00:00:00 2001 From: "Ivan R. Ivanov" Date: Wed, 28 Feb 2024 19:22:48 -0800 Subject: [PATCH 103/106] Trucation to MPFR (#1750) * WIP MPFR truncation * MPFR truncation * Fix mpfr function mangling * Mangling * MPFR Wrappers * clang-format * Make header work in C * File header * Make it compile on llvm 11 * header * fix tests * Add TODO comment * more comments * MPFR header fix * Add mpfr test * Move mpfr runtime * Add another type of include header * fix tests * clang-format * Check for MPFR * Fix older llvm vers * llvm 11 * . * WIP deps * Proper include * Dep * Switch to inline header --- .github/workflows/ccpp.yml | 1 + enzyme/CMakeLists.txt | 7 + enzyme/Enzyme/Clang/include_utils.td | 125 +++++++ enzyme/Enzyme/Enzyme.cpp | 51 ++- enzyme/Enzyme/EnzymeLogic.cpp | 308 +++++++++--------- enzyme/Enzyme/EnzymeLogic.h | 68 +++- enzyme/test/Enzyme/Truncate/cmp.ll | 15 +- enzyme/test/Enzyme/Truncate/intrinsic.ll | 129 +++++--- enzyme/test/Enzyme/Truncate/select.ll | 6 +- enzyme/test/Enzyme/Truncate/simple.ll | 29 +- .../Integration/Truncate/truncate-all.cpp | 20 +- enzyme/test/lit.site.cfg.py.in | 6 + 12 files changed, 520 insertions(+), 245 deletions(-) diff --git a/.github/workflows/ccpp.yml b/.github/workflows/ccpp.yml index 6efe084fdc07..1b5b293a24b7 100644 --- a/.github/workflows/ccpp.yml +++ b/.github/workflows/ccpp.yml @@ -27,6 +27,7 @@ jobs: - name: add llvm run: | wget -O - https://apt.llvm.org/llvm-snapshot.gpg.key|sudo apt-key add - + sudo apt-get install -y libmpfr-dev sudo apt-add-repository "deb http://apt.llvm.org/`lsb_release -c | cut -f2`/ llvm-toolchain-`lsb_release -c | cut -f2`-${{ matrix.llvm }} main" || true sudo apt-get install -y cmake gcc g++ llvm-${{ matrix.llvm }}-dev libomp-${{ matrix.llvm }}-dev lld-${{ matrix.llvm }} clang-${{ matrix.llvm }} libclang-${{ matrix.llvm }}-dev libeigen3-dev libboost-dev libzstd-dev sudo python3 -m pip install --upgrade pip lit diff --git a/enzyme/CMakeLists.txt b/enzyme/CMakeLists.txt index 077f4f3a9554..2ab67dd3d2fb 100644 --- a/enzyme/CMakeLists.txt +++ b/enzyme/CMakeLists.txt @@ -2,6 +2,8 @@ cmake_minimum_required(VERSION 3.13) project(Enzyme) include(CMakePackageConfigHelpers) +include(CheckIncludeFile) +include(CheckIncludeFileCXX) set(ENZYME_MAJOR_VERSION 0) set(ENZYME_MINOR_VERSION 0) @@ -265,6 +267,11 @@ string(REPLACE "};\n}" "};\n}}" INPUT_TEXT "${INPUT_TEXT}") string(REPLACE "const SCEV* S;\n};\n" "const SCEV* S;\n};\n}\n" INPUT_TEXT "${INPUT_TEXT}") endif() +find_library(MPFR_LIB_PATH mpfr) +CHECK_INCLUDE_FILE("mpfr.h" HAS_MPFR_H) +message("MPFR lib: " ${MPFR_LIB_PATH}) +message("MPFR header: " ${HAS_MPFR_H}) + file(WRITE "${CMAKE_CURRENT_BINARY_DIR}/include/SCEV/ScalarEvolutionExpander.h" "${INPUT_TEXT}") include_directories("${CMAKE_CURRENT_BINARY_DIR}/include") diff --git a/enzyme/Enzyme/Clang/include_utils.td b/enzyme/Enzyme/Clang/include_utils.td index 1c99d219ce69..cb7cdd839c20 100644 --- a/enzyme/Enzyme/Clang/include_utils.td +++ b/enzyme/Enzyme/Clang/include_utils.td @@ -456,3 +456,128 @@ def : Headers<"/enzymeroot/enzyme/enzyme", [{ #warning "Enzyme wrapper templates only available in C++" #endif }]>; + +def : Headers<"/enzymeroot/enzyme/mpfr", [{ +//===- EnzymeMPFR.h - MPFR wrappers ---------------------------------------===// +// +// Enzyme Project +// +// Part of the Enzyme Project, under the Apache License v2.0 with LLVM +// Exceptions. See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +// If using this code in an academic setting, please cite the following: +// @incollection{enzymeNeurips, +// title = {Instead of Rewriting Foreign Code for Machine Learning, +// Automatically Synthesize Fast Gradients}, +// author = {Moses, William S. and Churavy, Valentin}, +// booktitle = {Advances in Neural Information Processing Systems 33}, +// year = {2020}, +// note = {To appear in}, +// } +// +//===----------------------------------------------------------------------===// +// +// This file contains easy to use wrappers around MPFR functions. +// +//===----------------------------------------------------------------------===// +#ifndef __ENZYME_RUNTIME_ENZYME_MPFR__ +#define __ENZYME_RUNTIME_ENZYME_MPFR__ + +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +// TODO s +// +// (for MPFR ver. 2.1) +// +// We need to set the range of the allowed exponent using `mpfr_set_emin` and +// `mpfr_set_emax`. (This means we can also play with whether the range is +// centered around 0 (1?) or somewhere else) +// +// (also these need to be mutex'ed as the exponent change is global in mpfr and +// not float-specific) ... (mpfr seems to have thread safe mode - check if it is +// enabled or if it is enabled by default) +// +// For that we need to do this check: +// If the user changes the exponent range, it is her/his responsibility to +// check that all current floating-point variables are in the new allowed +// range (for example using mpfr_check_range), otherwise the subsequent +// behavior will be undefined, in the sense of the ISO C standard. +// +// MPFR docs state the following: +// Note: Overflow handling is still experimental and currently implemented +// partially. If an overflow occurs internally at the wrong place, anything +// can happen (crash, wrong results, etc). +// +// Which we would like to avoid somehow. +// +// MPFR also has this limitation that we need to address for accurate +// simulation: +// [...] subnormal numbers are not implemented. +// + +#define __ENZYME_MPFR_SINGOP(OP_TYPE, LLVM_OP_NAME, MPFR_FUNC_NAME, FROM_TYPE, \ + RET, MPFR_GET, ARG1, MPFR_SET_ARG1, \ + ROUNDING_MODE) \ + __attribute__((weak)) \ + RET __enzyme_mpfr_##FROM_TYPE##_##OP_TYPE##_##LLVM_OP_NAME( \ + ARG1 a, int64_t exponent, int64_t significand) { \ + mpfr_t ma, mc; \ + mpfr_init2(ma, significand); \ + mpfr_init2(mc, significand); \ + mpfr_set_##MPFR_SET_ARG1(ma, a, ROUNDING_MODE); \ + mpfr_##MPFR_FUNC_NAME(mc, ma, ROUNDING_MODE); \ + RET c = mpfr_get_##MPFR_GET(mc, ROUNDING_MODE); \ + mpfr_clear(ma); \ + mpfr_clear(mc); \ + return c; \ + } + +#define __ENZYME_MPFR_BINOP(OP_TYPE, LLVM_OP_NAME, MPFR_FUNC_NAME, FROM_TYPE, \ + RET, MPFR_GET, ARG1, MPFR_SET_ARG1, ARG2, \ + MPFR_SET_ARG2, ROUNDING_MODE) \ + __attribute__((weak)) \ + RET __enzyme_mpfr_##FROM_TYPE##_##OP_TYPE##_##LLVM_OP_NAME( \ + ARG1 a, ARG2 b, int64_t exponent, int64_t significand) { \ + mpfr_t ma, mb, mc; \ + mpfr_init2(ma, significand); \ + mpfr_init2(mb, significand); \ + mpfr_init2(mc, significand); \ + mpfr_set_##MPFR_SET_ARG1(ma, a, ROUNDING_MODE); \ + mpfr_set_##MPFR_SET_ARG1(mb, b, ROUNDING_MODE); \ + mpfr_##MPFR_FUNC_NAME(mc, ma, mb, ROUNDING_MODE); \ + RET c = mpfr_get_##MPFR_GET(mc, ROUNDING_MODE); \ + mpfr_clear(ma); \ + mpfr_clear(mb); \ + mpfr_clear(mc); \ + return c; \ + } + +#define __ENZYME_MPFR_DEFAULT_ROUNDING_MODE GMP_RNDN +#define __ENZYME_MPFR_DOUBLE_BINOP(LLVM_OP_NAME, MPFR_FUNC_NAME, \ + ROUNDING_MODE) \ + __ENZYME_MPFR_BINOP(binop, LLVM_OP_NAME, MPFR_FUNC_NAME, 64_52, double, d, \ + double, d, double, d, ROUNDING_MODE) +#define __ENZYME_MPFR_DOUBLE_BINOP_DEFAULT_ROUNDING(LLVM_OP_NAME, \ + MPFR_FUNC_NAME) \ + __ENZYME_MPFR_DOUBLE_BINOP(LLVM_OP_NAME, MPFR_FUNC_NAME, \ + __ENZYME_MPFR_DEFAULT_ROUNDING_MODE) + +__ENZYME_MPFR_DOUBLE_BINOP_DEFAULT_ROUNDING(fmul, mul) +__ENZYME_MPFR_DOUBLE_BINOP_DEFAULT_ROUNDING(fadd, add) +__ENZYME_MPFR_DOUBLE_BINOP_DEFAULT_ROUNDING(fdiv, div) + +__ENZYME_MPFR_SINGOP(func, sqrt, sqrt, 64_52, double, d, double, d, + __ENZYME_MPFR_DEFAULT_ROUNDING_MODE) + +#ifdef __cplusplus +} +#endif + +#endif // #ifndef __ENZYME_RUNTIME_ENZYME_MPFR__ +}]>; diff --git a/enzyme/Enzyme/Enzyme.cpp b/enzyme/Enzyme/Enzyme.cpp index 70f173f5734a..4ab84bd6e67e 100644 --- a/enzyme/Enzyme/Enzyme.cpp +++ b/enzyme/Enzyme/Enzyme.cpp @@ -58,6 +58,7 @@ #include "llvm/Analysis/ScalarEvolution.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" #include "llvm/Transforms/Scalar.h" #include "llvm/Analysis/BasicAliasAnalysis.h" @@ -1339,21 +1340,40 @@ class EnzymeBase { Function *F = parseFunctionParameter(CI); if (!F) return false; - if (CI->arg_size() != 3) { + unsigned ArgSize = CI->arg_size(); + if (ArgSize != 4 && ArgSize != 3) { EmitFailure("TooManyArgs", CI->getDebugLoc(), CI, "Had incorrect number of args to __enzyme_truncate_func", *CI, - " - expected 3"); + " - expected 3 or 4"); return false; } - auto Cfrom = cast(CI->getArgOperand(1)); - assert(Cfrom); - auto Cto = cast(CI->getArgOperand(2)); - assert(Cto); + FloatTruncation truncation = [&]() -> FloatTruncation { + if (ArgSize == 3) { + auto Cfrom = cast(CI->getArgOperand(1)); + assert(Cfrom); + auto Cto = cast(CI->getArgOperand(2)); + assert(Cto); + return FloatTruncation( + getDefaultFloatRepr((unsigned)Cfrom->getValue().getZExtValue()), + getDefaultFloatRepr((unsigned)Cto->getValue().getZExtValue())); + } else if (ArgSize == 4) { + auto Cfrom = cast(CI->getArgOperand(1)); + assert(Cfrom); + auto Cto_exponent = cast(CI->getArgOperand(2)); + assert(Cto_exponent); + auto Cto_significand = cast(CI->getArgOperand(3)); + assert(Cto_significand); + return FloatTruncation( + getDefaultFloatRepr((unsigned)Cfrom->getValue().getZExtValue()), + FloatRepresentation( + (unsigned)Cto_exponent->getValue().getZExtValue(), + (unsigned)Cto_significand->getValue().getZExtValue())); + } + llvm_unreachable("??"); + }(); + RequestContext context(CI, &Builder); - llvm::Value *res = Logic.CreateTruncateFunc( - context, F, - getDefaultFloatRepr((unsigned)Cfrom->getValue().getZExtValue()), - getDefaultFloatRepr((unsigned)Cto->getValue().getZExtValue()), mode); + llvm::Value *res = Logic.CreateTruncateFunc(context, F, truncation, mode); if (!res) return false; res = Builder.CreatePointerCast(res, CI->getType()); @@ -2052,14 +2072,12 @@ class EnzymeBase { } bool handleFullModuleTrunc(Function &F) { - typedef std::vector> - TruncationsTy; + typedef std::vector TruncationsTy; static TruncationsTy FullModuleTruncs = []() -> TruncationsTy { StringRef ConfigStr(EnzymeTruncateAll); auto Invalid = [=]() { // TODO emit better diagnostic - llvm::errs() << "error: invalid format for truncation config\n"; - abort(); + llvm::report_fatal_error("error: invalid format for truncation config"); }; // "64" or "11-52" @@ -2102,9 +2120,8 @@ class EnzymeBase { for (auto Truncation : FullModuleTruncs) { IRBuilder<> Builder(F.getContext()); RequestContext context(&*F.getEntryBlock().begin(), &Builder); - Function *TruncatedFunc = - Logic.CreateTruncateFunc(context, &F, Truncation.first, - Truncation.second, TruncOpFullModuleMode); + Function *TruncatedFunc = Logic.CreateTruncateFunc( + context, &F, Truncation, TruncOpFullModuleMode); ValueToValueMapTy Mapping; for (auto &&[Arg, TArg] : llvm::zip(F.args(), TruncatedFunc->args())) diff --git a/enzyme/Enzyme/EnzymeLogic.cpp b/enzyme/Enzyme/EnzymeLogic.cpp index f930d4e1375b..f8fdf3b3124a 100644 --- a/enzyme/Enzyme/EnzymeLogic.cpp +++ b/enzyme/Enzyme/EnzymeLogic.cpp @@ -32,7 +32,13 @@ #include "AdjointGenerator.h" #include "EnzymeLogic.h" #include "llvm/IR/GlobalValue.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Intrinsics.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/Support/ErrorHandling.h" +#include #if LLVM_VERSION_MAJOR >= 16 #define private public @@ -4956,30 +4962,28 @@ Function *EnzymeLogic::CreateForwardDiff( } static Value *floatValTruncate(IRBuilderBase &B, Value *v, Value *tmpBlock, - FloatRepresentation from, - FloatRepresentation to) { - Type *toTy = to.getType(B.getContext()); + FloatTruncation truncation) { + Type *toTy = truncation.getToType(B.getContext()); if (auto vty = dyn_cast(v->getType())) toTy = VectorType::get(toTy, vty->getElementCount()); return B.CreateFPTrunc(v, toTy, "enzyme_trunc"); } static Value *floatValExpand(IRBuilderBase &B, Value *v, Value *tmpBlock, - FloatRepresentation from, FloatRepresentation to) { - Type *fromTy = from.getBuiltinType(B.getContext()); + FloatTruncation truncation) { + Type *fromTy = truncation.getFromType(B.getContext()); if (auto vty = dyn_cast(v->getType())) fromTy = VectorType::get(fromTy, vty->getElementCount()); return B.CreateFPExt(v, fromTy, "enzyme_exp"); } static Value *floatMemTruncate(IRBuilderBase &B, Value *v, Value *tmpBlock, - FloatRepresentation from, - FloatRepresentation to) { + FloatTruncation truncation) { if (isa(v->getType())) report_fatal_error("vector operations not allowed in mem trunc mode"); - Type *fromTy = from.getBuiltinType(B.getContext()); - Type *toTy = to.getType(B.getContext()); + Type *fromTy = truncation.getFromType(B.getContext()); + Type *toTy = truncation.getToType(B.getContext()); if (!tmpBlock) tmpBlock = B.CreateAlloca(fromTy); B.CreateStore( @@ -4989,15 +4993,15 @@ static Value *floatMemTruncate(IRBuilderBase &B, Value *v, Value *tmpBlock, } static Value *floatMemExpand(IRBuilderBase &B, Value *v, Value *tmpBlock, - FloatRepresentation from, FloatRepresentation to) { + FloatTruncation truncation) { if (isa(v->getType())) report_fatal_error("vector operations not allowed in mem trunc mode"); - Type *fromTy = from.getBuiltinType(B.getContext()); + Type *fromTy = truncation.getFromType(B.getContext()); if (!tmpBlock) tmpBlock = B.CreateAlloca(fromTy); auto c0 = Constant::getNullValue( - llvm::Type::getIntNTy(B.getContext(), from.getTypeWidth())); + llvm::Type::getIntNTy(B.getContext(), truncation.getFromTypeWidth())); B.CreateStore( c0, B.CreatePointerCast(tmpBlock, PointerType::getUnqual(c0->getType()))); B.CreateStore( @@ -5009,8 +5013,7 @@ static Value *floatMemExpand(IRBuilderBase &B, Value *v, Value *tmpBlock, class TruncateGenerator : public llvm::InstVisitor { private: ValueToValueMapTy &originalToNewFn; - FloatRepresentation from; - FloatRepresentation to; + FloatTruncation truncation; Type *fromType; Type *toType; Function *oldFunc; @@ -5018,23 +5021,37 @@ class TruncateGenerator : public llvm::InstVisitor { AllocaInst *tmpBlock; TruncateMode mode; EnzymeLogic &Logic; + LLVMContext &ctx; public: TruncateGenerator(ValueToValueMapTy &originalToNewFn, - FloatRepresentation from, FloatRepresentation to, - Function *oldFunc, Function *newFunc, TruncateMode mode, - EnzymeLogic &Logic) - : originalToNewFn(originalToNewFn), from(from), to(to), oldFunc(oldFunc), - newFunc(newFunc), mode(mode), Logic(Logic) { + FloatTruncation truncation, Function *oldFunc, + Function *newFunc, TruncateMode mode, EnzymeLogic &Logic) + : originalToNewFn(originalToNewFn), truncation(truncation), + oldFunc(oldFunc), newFunc(newFunc), mode(mode), Logic(Logic), + ctx(newFunc->getContext()) { IRBuilder<> B(&newFunc->getEntryBlock().front()); - fromType = from.getBuiltinType(B.getContext()); - toType = to.getType(B.getContext()); + fromType = truncation.getFromType(ctx); + toType = truncation.getToType(ctx); + if (fromType == toType) + assert(truncation.isToMPFR()); if (mode == TruncMemMode) tmpBlock = B.CreateAlloca(fromType); else tmpBlock = nullptr; + + if (truncation.isToMPFR()) { + switch (mode) { + case TruncMemMode: + llvm::report_fatal_error( + "truncation to MPFR not supported in memory mode."); + case TruncOpMode: + case TruncOpFullModuleMode: + break; + } + } } void checkHandled(llvm::Instruction &inst) { @@ -5065,25 +5082,26 @@ class TruncateGenerator : public llvm::InstVisitor { Value *truncate(IRBuilder<> &B, Value *v) { switch (mode) { case TruncMemMode: - return floatMemTruncate(B, v, tmpBlock, from, to); + assert(!truncation.isToMPFR()); + return floatMemTruncate(B, v, tmpBlock, truncation); case TruncOpMode: case TruncOpFullModuleMode: - return floatValTruncate(B, v, tmpBlock, from, to); - default: - llvm_unreachable("Unknown trunc mode"); + if (truncation.isToMPFR()) + return v; + return floatValTruncate(B, v, tmpBlock, truncation); } + llvm_unreachable("Unknown trunc mode"); } Value *expand(IRBuilder<> &B, Value *v) { switch (mode) { case TruncMemMode: - return floatMemExpand(B, v, tmpBlock, from, to); + return floatMemExpand(B, v, tmpBlock, truncation); case TruncOpMode: case TruncOpFullModuleMode: - return floatValExpand(B, v, tmpBlock, from, to); - default: - llvm_unreachable("Unknown trunc mode"); + return floatValExpand(B, v, tmpBlock, truncation); } + llvm_unreachable("Unknown trunc mode"); } void todo(llvm::Instruction &I) { @@ -5129,26 +5147,35 @@ class TruncateGenerator : public llvm::InstVisitor { void visitGetElementPtrInst(llvm::GetElementPtrInst &gep) { return; } void visitPHINode(llvm::PHINode &phi) { return; } void visitCastInst(llvm::CastInst &CI) { - Value *newCI = nullptr; - auto newI = getNewFromOriginal(&CI); - std::string oldName = CI.getName().str(); - newI->setName(""); - if (CI.getSrcTy() == getFromType()) { - IRBuilder<> B(newI); - newCI = B.CreateCast(CI.getOpcode(), getNewFromOriginal(CI.getOperand(0)), - CI.getDestTy(), oldName); - } - if (CI.getDestTy() == getToType()) { + switch (mode) { + case TruncMemMode: { + Value *newCI = nullptr; auto newI = getNewFromOriginal(&CI); - IRBuilder<> B(newI); - newCI = B.CreateCast(CI.getOpcode(), getNewFromOriginal(CI.getOperand(0)), - CI.getDestTy(), oldName); + std::string oldName = CI.getName().str(); + newI->setName(""); + if (CI.getSrcTy() == getFromType()) { + IRBuilder<> B(newI); + newCI = + B.CreateCast(CI.getOpcode(), getNewFromOriginal(CI.getOperand(0)), + CI.getDestTy(), oldName); + } + if (CI.getDestTy() == getToType()) { + auto newI = getNewFromOriginal(&CI); + IRBuilder<> B(newI); + newCI = + B.CreateCast(CI.getOpcode(), getNewFromOriginal(CI.getOperand(0)), + CI.getDestTy(), oldName); + } + if (newCI) { + newI->replaceAllUsesWith(newCI); + newI->eraseFromParent(); + } + return; } - if (newCI) { - newI->replaceAllUsesWith(newCI); - newI->eraseFromParent(); + case TruncOpMode: + case TruncOpFullModuleMode: + return; } - return; } void visitSelectInst(llvm::SelectInst &SI) { switch (mode) { @@ -5168,16 +5195,61 @@ class TruncateGenerator : public llvm::InstVisitor { case TruncOpMode: case TruncOpFullModuleMode: return; - default: - llvm_unreachable(""); } + llvm_unreachable(""); } void visitExtractElementInst(llvm::ExtractElementInst &EEI) { return; } void visitInsertElementInst(llvm::InsertElementInst &EEI) { return; } void visitShuffleVectorInst(llvm::ShuffleVectorInst &EEI) { return; } void visitExtractValueInst(llvm::ExtractValueInst &EEI) { return; } void visitInsertValueInst(llvm::InsertValueInst &EEI) { return; } + CallInst *createMPFRCall(llvm::IRBuilder<> &B, llvm::Instruction &I, + llvm::Type *RetTy, + SmallVectorImpl &ArgsIn) { + std::string Name; + if (auto BO = dyn_cast(&I)) { + Name = "binop_" + std::string(BO->getOpcodeName()); + } else if (auto II = dyn_cast(&I)) { + auto FOp = II->getCalledFunction(); + assert(FOp); + Name = "intr_" + std::string(FOp->getName()); + for (auto &C : Name) + if (C == '.') + C = '_'; + } else if (auto CI = dyn_cast(&I)) { + if (auto F = CI->getCalledFunction()) + Name = "func_" + std::string(F->getName()); + else + llvm_unreachable( + "Unexpected indirect call inst for conversion to MPFR"); + } else { + llvm_unreachable("Unexpected instruction for conversion to MPFR"); + } + + std::string MangledName = + std::string("__enzyme_mpfr_") + truncation.mangleFrom() + "_" + Name; + auto F = newFunc->getParent()->getFunction(MangledName); + SmallVector Args(ArgsIn.begin(), ArgsIn.end()); + Args.push_back(B.getInt64(truncation.getTo().exponentWidth)); + Args.push_back(B.getInt64(truncation.getTo().significandWidth)); + if (!F) { + SmallVector ArgTypes; + for (auto Arg : Args) + ArgTypes.push_back(Arg->getType()); + FunctionType *FnTy = + FunctionType::get(RetTy, ArgTypes, /*is_vararg*/ false); + F = Function::Create(FnTy, Function::ExternalLinkage, MangledName, + newFunc->getParent()); + } + return cast(B.CreateCall(F, Args)); + } void visitBinaryOperator(llvm::BinaryOperator &BO) { + auto oldLHS = BO.getOperand(0); + auto oldRHS = BO.getOperand(1); + + if (oldLHS->getType() != getFromType() && + oldRHS->getType() != getFromType()) + return; switch (BO.getOpcode()) { default: @@ -5195,60 +5267,25 @@ class TruncateGenerator : public llvm::InstVisitor { case BinaryOperator::And: case BinaryOperator::Or: case BinaryOperator::Xor: + assert(0 && "Invalid binop opcode for float arg"); return; } - if (to.getBuiltinType(BO.getContext())) { - auto newI = getNewFromOriginal(&BO); - IRBuilder<> B(newI); - auto newLHS = truncate(B, getNewFromOriginal(BO.getOperand(0))); - auto newRHS = truncate(B, getNewFromOriginal(BO.getOperand(1))); - switch (BO.getOpcode()) { - default: - break; - case BinaryOperator::FMul: { - auto nres = cast(B.CreateFMul(newLHS, newRHS)); - nres->takeName(newI); - nres->copyIRFlags(newI); - newI->replaceAllUsesWith(expand(B, nres)); - newI->eraseFromParent(); - } - return; - case BinaryOperator::FAdd: { - auto nres = cast(B.CreateFAdd(newLHS, newRHS)); - nres->takeName(newI); - nres->copyIRFlags(newI); - newI->replaceAllUsesWith(expand(B, nres)); - newI->eraseFromParent(); - } - return; - case BinaryOperator::FSub: { - auto nres = cast(B.CreateFSub(newLHS, newRHS)); - nres->takeName(newI); - nres->copyIRFlags(newI); - newI->replaceAllUsesWith(expand(B, nres)); - newI->eraseFromParent(); - } - return; - case BinaryOperator::FDiv: { - auto nres = cast(B.CreateFDiv(newLHS, newRHS)); - nres->takeName(newI); - nres->copyIRFlags(newI); - newI->replaceAllUsesWith(expand(B, nres)); - newI->eraseFromParent(); - } - return; - case BinaryOperator::FRem: { - auto nres = cast(B.CreateFRem(newLHS, newRHS)); - nres->takeName(newI); - nres->copyIRFlags(newI); - newI->replaceAllUsesWith(expand(B, nres)); - newI->eraseFromParent(); - } - return; - } + auto newI = getNewFromOriginal(&BO); + IRBuilder<> B(newI); + auto newLHS = truncate(B, getNewFromOriginal(oldLHS)); + auto newRHS = truncate(B, getNewFromOriginal(oldRHS)); + Instruction *nres = nullptr; + if (truncation.isToMPFR()) { + SmallVector Args({newLHS, newRHS}); + nres = createMPFRCall(B, BO, truncation.getToType(ctx), Args); + } else { + nres = cast(B.CreateBinOp(BO.getOpcode(), newLHS, newRHS)); } - todo(BO); + nres->takeName(newI); + nres->copyIRFlags(newI); + newI->replaceAllUsesWith(expand(B, nres)); + newI->eraseFromParent(); return; } void visitMemSetInst(llvm::MemSetInst &MS) { visitMemSetCommon(MS); } @@ -5271,13 +5308,14 @@ class TruncateGenerator : public llvm::InstVisitor { void visitFenceInst(llvm::FenceInst &FI) { return; } bool handleIntrinsic(llvm::CallInst &CI, Intrinsic::ID ID) { + auto newI = cast(getNewFromOriginal(&CI)); + IRBuilder<> B(newI); + SmallVector orig_ops(CI.arg_size()); for (unsigned i = 0; i < CI.arg_size(); ++i) orig_ops[i] = CI.getOperand(i); bool hasFromType = false; - auto newI = cast(getNewFromOriginal(&CI)); - IRBuilder<> B(newI); SmallVector new_ops(CI.arg_size()); for (unsigned i = 0; i < CI.arg_size(); ++i) { if (orig_ops[i]->getType() == getFromType()) { @@ -5296,12 +5334,16 @@ class TruncateGenerator : public llvm::InstVisitor { if (!hasFromType) return false; - // TODO check that the intrinsic is overloaded - - CallInst *intr; - Value *nres = intr = - createIntrinsicCall(B, ID, retTy, new_ops, &CI, CI.getName()); - if (CI.getType() == getFromType()) + Instruction *intr = nullptr; + Value *nres = nullptr; + if (truncation.isToMPFR()) { + nres = intr = createMPFRCall(B, CI, retTy, new_ops); + } else { + // TODO check that the intrinsic is overloaded + nres = intr = + createIntrinsicCall(B, ID, retTy, new_ops, &CI, CI.getName()); + } + if (newI->getType() == getFromType()) nres = expand(B, nres); intr->copyIRFlags(newI); newI->replaceAllUsesWith(nres); @@ -5390,7 +5432,7 @@ class TruncateGenerator : public llvm::InstVisitor { Value *GetShadow(RequestContext &ctx, Value *v) { if (auto F = dyn_cast(v)) - return Logic.CreateTruncateFunc(ctx, F, from, to, mode); + return Logic.CreateTruncateFunc(ctx, F, truncation, mode); llvm::errs() << " unknown get truncated func: " << *v << "\n"; llvm_unreachable("unknown get truncated func"); return v; @@ -5457,10 +5499,11 @@ bool EnzymeLogic::CreateTruncateValue(RequestContext context, Value *v, Value *converted = nullptr; if (isTruncate) - converted = floatMemExpand(B, B.CreateFPTrunc(v, toTy), nullptr, from, to); + converted = floatMemExpand(B, B.CreateFPTrunc(v, toTy), nullptr, + FloatTruncation(from, to)); else - converted = - B.CreateFPExt(floatMemTruncate(B, v, nullptr, from, to), fromTy); + converted = B.CreateFPExt( + floatMemTruncate(B, v, nullptr, FloatTruncation(from, to)), fromTy); assert(converted); context.req->replaceAllUsesWith(converted); @@ -5471,13 +5514,9 @@ bool EnzymeLogic::CreateTruncateValue(RequestContext context, Value *v, llvm::Function *EnzymeLogic::CreateTruncateFunc(RequestContext context, llvm::Function *totrunc, - FloatRepresentation from, - FloatRepresentation to, + FloatTruncation truncation, TruncateMode mode) { - if (from == to) - return totrunc; - - TruncateCacheKey tup(totrunc, from, to, mode); + TruncateCacheKey tup(totrunc, truncation, mode); if (TruncateCachedFunctions.find(tup) != TruncateCachedFunctions.end()) { return TruncateCachedFunctions.find(tup)->second; } @@ -5492,10 +5531,9 @@ llvm::Function *EnzymeLogic::CreateTruncateFunc(RequestContext context, Type *NewTy = totrunc->getReturnType(); FunctionType *FTy = FunctionType::get(NewTy, params, totrunc->isVarArg()); - std::string truncName = std::string("__enzyme_done_truncate_") + - (mode == TruncMemMode ? "mem" : "op") + "_func_" + - from.to_string() + "_" + to.to_string() + "_" + - totrunc->getName().str(); + std::string truncName = + std::string("__enzyme_done_truncate_") + truncateModeStr(mode) + + "_func_" + truncation.mangleTruncation() + "_" + totrunc->getName().str(); Function *NewF = Function::Create(FTy, totrunc->getLinkage(), truncName, totrunc->getParent()); @@ -5530,34 +5568,6 @@ llvm::Function *EnzymeLogic::CreateTruncateFunc(RequestContext context, llvm_unreachable("attempting to truncate function without definition"); } - // TODO This is overloaded an doesnt do what it should do here - if (from < to) { - std::string s; - llvm::raw_string_ostream ss(s); - ss << "Cannot truncate into a large width\n"; - llvm::Value *toshow = totrunc; - if (context.req) { - toshow = context.req; - ss << " at context: " << *context.req; - } else { - ss << *totrunc << "\n"; - } - if (CustomErrorHandler) { - CustomErrorHandler(ss.str().c_str(), wrap(toshow), - ErrorType::NoDerivative, nullptr, wrap(totrunc), - wrap(context.ip)); - return NewF; - } - if (context.req) { - EmitFailure("NoTruncate", context.req->getDebugLoc(), context.req, - ss.str()); - return NewF; - } - llvm::errs() << "mod: " << *totrunc->getParent() << "\n"; - llvm::errs() << *totrunc << "\n"; - llvm_unreachable("attempting to truncate function without definition"); - } - ValueToValueMapTy originalToNewFn; for (auto i = totrunc->arg_begin(), j = NewF->arg_begin(); @@ -5579,7 +5589,7 @@ llvm::Function *EnzymeLogic::CreateTruncateFunc(RequestContext context, NewF->setLinkage(Function::LinkageTypes::InternalLinkage); - TruncateGenerator handle(originalToNewFn, from, to, totrunc, NewF, mode, + TruncateGenerator handle(originalToNewFn, truncation, totrunc, NewF, mode, *this); for (auto &BB : *totrunc) for (auto &I : BB) diff --git a/enzyme/Enzyme/EnzymeLogic.h b/enzyme/Enzyme/EnzymeLogic.h index 1e9bf216b6e9..4bb61e94c8ef 100644 --- a/enzyme/Enzyme/EnzymeLogic.h +++ b/enzyme/Enzyme/EnzymeLogic.h @@ -42,6 +42,7 @@ #include "llvm/Analysis/AliasAnalysis.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/Support/ErrorHandling.h" #include "ActivityAnalysis.h" #include "FunctionUtils.h" @@ -287,6 +288,17 @@ getTypeForWidth(llvm::LLVMContext &ctx, unsigned width, bool builtinFloat) { } enum TruncateMode { TruncMemMode, TruncOpMode, TruncOpFullModuleMode }; +[[maybe_unused]] static const char *truncateModeStr(TruncateMode mode) { + switch (mode) { + case TruncMemMode: + return "mem"; + case TruncOpMode: + return "op"; + case TruncOpFullModuleMode: + return "op_full_module"; + } + llvm_unreachable("Invalid truncation mode"); +} struct FloatRepresentation { // |_|__________|_________________| @@ -336,6 +348,54 @@ struct FloatRepresentation { } }; +struct FloatTruncation { +private: + FloatRepresentation from, to; + +public: + FloatTruncation(FloatRepresentation From, FloatRepresentation To) + : from(From), to(To) { + if (!From.canBeBuiltin()) + llvm::report_fatal_error("Float truncation `from` type is not builtin."); + if (From.exponentWidth < To.exponentWidth) + llvm::report_fatal_error("Float truncation `from` type must have " + "a wider exponent than `to`."); + if (From.significandWidth < To.significandWidth) + llvm::report_fatal_error("Float truncation `from` type must have " + "a wider wsignificand than `to`."); + if (From == To) + llvm::report_fatal_error( + "Float truncation `from` and `to` type must not be the same."); + } + FloatRepresentation getTo() { return to; } + unsigned getFromTypeWidth() { return from.getTypeWidth(); } + unsigned getToTypeWidth() { return to.getTypeWidth(); } + llvm::Type *getFromType(llvm::LLVMContext &ctx) { + return from.getBuiltinType(ctx); + } + bool isToMPFR() { return !to.canBeBuiltin(); } + llvm::Type *getToType(llvm::LLVMContext &ctx) { + if (to.canBeBuiltin()) { + return to.getBuiltinType(ctx); + } else { + assert(isToMPFR()); + // Currently we do not support TruncMemMode for MPFR, and we provide + // runtime wrappers around MPFR for each builtin `from` type + return from.getBuiltinType(ctx); + } + } + bool operator==(const FloatTruncation &other) const { + return from == other.from && to == other.to; + } + bool operator<(const FloatTruncation &other) const { + return std::tuple(from, to) < std::tuple(other.from, other.to); + } + std::string mangleTruncation() const { + return from.to_string() + "to" + to.to_string(); + } + std::string mangleFrom() const { return from.to_string(); } +}; + class EnzymeLogic { public: PreProcessCache PPC; @@ -583,13 +643,13 @@ class EnzymeLogic { llvm::ArrayRef arg_types, BATCH_TYPE ret_type); - using TruncateCacheKey = std::tuple; + using TruncateCacheKey = + std::tuple; std::map TruncateCachedFunctions; llvm::Function *CreateTruncateFunc(RequestContext context, llvm::Function *tobatch, - FloatRepresentation from, - FloatRepresentation to, TruncateMode mode); + FloatTruncation truncation, + TruncateMode mode); bool CreateTruncateValue(RequestContext context, llvm::Value *addr, FloatRepresentation from, FloatRepresentation to, bool isTruncate); diff --git a/enzyme/test/Enzyme/Truncate/cmp.ll b/enzyme/test/Enzyme/Truncate/cmp.ll index c96efa70660a..68f0ef473a9b 100644 --- a/enzyme/test/Enzyme/Truncate/cmp.ll +++ b/enzyme/test/Enzyme/Truncate/cmp.ll @@ -21,13 +21,14 @@ entry: %res = call i1 %ptr(double %x, double %y) ret i1 %res } +define i1 @tester_op_mpfr(double %x, double %y) { +entry: + %ptr = call i1 (double, double)* (...) @__enzyme_truncate_op_func(i1 (double, double)* @f, i64 64, i64 3, i64 7) + %res = call i1 %ptr(double %x, double %y) + ret i1 %res +} -; CHECK: define i1 @tester(double %x, double %y) { -; CHECK-NEXT: entry: -; CHECK-NEXT: %res = call i1 @__enzyme_done_truncate_mem_func_64_52_32_23_f(double %x, double %y) -; CHECK-NEXT: ret i1 %res - -; CHECK: define internal i1 @__enzyme_done_truncate_mem_func_64_52_32_23_f(double %x, double %y) { +; CHECK: define internal i1 @__enzyme_done_truncate_mem_func_64_52to32_23_f(double %x, double %y) { ; CHECK-DAG: %1 = alloca double, align 8 ; CHECK-DAG: store double %x, double* %1, align 8 ; CHECK-DAG: %2 = bitcast double* %1 to float* @@ -38,7 +39,7 @@ entry: ; CHECK-DAG: %res = fcmp olt float %3, %5 ; CHECK-DAG: ret i1 %res -; CHECK: define internal i1 @__enzyme_done_truncate_op_func_64_52_32_23_f(double %x, double %y) { +; CHECK: define internal i1 @__enzyme_done_truncate_op_func_64_52to32_23_f(double %x, double %y) { ; CHECK-DAG: %enzyme_trunc = fptrunc double %x to float ; CHECK-DAG: %enzyme_trunc1 = fptrunc double %y to float ; CHECK-DAG: %res = fcmp olt float %enzyme_trunc, %enzyme_trunc1 diff --git a/enzyme/test/Enzyme/Truncate/intrinsic.ll b/enzyme/test/Enzyme/Truncate/intrinsic.ll index 99568539c3f3..2299c9fb1ab3 100644 --- a/enzyme/test/Enzyme/Truncate/intrinsic.ll +++ b/enzyme/test/Enzyme/Truncate/intrinsic.ll @@ -1,11 +1,13 @@ ; RUN: if [ %llvmver -gt 12 ]; then if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -S | FileCheck %s; fi; fi ; RUN: if [ %llvmver -gt 12 ]; then %opt < %s %newLoadEnzyme -passes="enzyme" -S | FileCheck %s; fi +declare double @pow(double %Val, double %Power) declare double @llvm.pow.f64(double %Val, double %Power) declare double @llvm.powi.f64.i16(double %Val, i16 %power) declare void @llvm.nvvm.barrier0() define double @f(double %x, double %y) { + %res0 = call double @pow(double %x, double %y) %res1 = call double @llvm.pow.f64(double %x, double %y) %res2 = call double @llvm.powi.f64.i16(double %x, i16 2) %res = fadd double %res1, %res2 @@ -22,62 +24,93 @@ entry: %res = call double %ptr(double %x, double %y) ret double %res } -define double @tester2(double %x, double %y) { +define double @tester_op(double %x, double %y) { entry: %ptr = call double (double, double)* (...) @__enzyme_truncate_op_func(double (double, double)* @f, i64 64, i64 32) %res = call double %ptr(double %x, double %y) ret double %res } +define double @tester_op_mpfr(double %x, double %y) { +entry: + %ptr = call double (double, double)* (...) @__enzyme_truncate_op_func(double (double, double)* @f, i64 64, i64 3, i64 7) + %res = call double %ptr(double %x, double %y) + ret double %res +} -; CHECK: define internal double @__enzyme_done_truncate_mem_func_64_52_32_23_f(double %x, double %y) { -; CHECK-NEXT: %1 = alloca double, align 8 -; CHECK-NEXT: store double %x, double* %1, align 8 -; CHECK-NEXT: %2 = bitcast double* %1 to float* -; CHECK-NEXT: %3 = load float, float* %2, align 4 -; CHECK-NEXT: store double %y, double* %1, align 8 -; CHECK-NEXT: %4 = bitcast double* %1 to float* -; CHECK-NEXT: %5 = load float, float* %4, align 4 -; CHECK-NEXT: %res11 = call float @llvm.pow.f32(float %3, float %5) -; CHECK-NEXT: %6 = bitcast double* %1 to i64* -; CHECK-NEXT: store i64 0, i64* %6, align 4 -; CHECK-NEXT: %7 = bitcast double* %1 to float* -; CHECK-NEXT: store float %res11, float* %7, align 4 -; CHECK-NEXT: %8 = load double, double* %1, align 8 -; CHECK-NEXT: store double %x, double* %1, align 8 -; CHECK-NEXT: %9 = bitcast double* %1 to float* -; CHECK-NEXT: %10 = load float, float* %9, align 4 -; CHECK-NEXT: %res22 = call float @llvm.powi.f32.i16(float %10, i16 2) -; CHECK-NEXT: %11 = bitcast double* %1 to i64* -; CHECK-NEXT: store i64 0, i64* %11, align 4 -; CHECK-NEXT: %12 = bitcast double* %1 to float* -; CHECK-NEXT: store float %res22, float* %12, align 4 -; CHECK-NEXT: %13 = load double, double* %1, align 8 -; CHECK-NEXT: store double %8, double* %1, align 8 -; CHECK-NEXT: %14 = bitcast double* %1 to float* -; CHECK-NEXT: %15 = load float, float* %14, align 4 -; CHECK-NEXT: store double %13, double* %1, align 8 -; CHECK-NEXT: %16 = bitcast double* %1 to float* -; CHECK-NEXT: %17 = load float, float* %16, align 4 -; CHECK-NEXT: %res = fadd float %15, %17 -; CHECK-NEXT: %18 = bitcast double* %1 to i64* -; CHECK-NEXT: store i64 0, i64* %18, align 4 -; CHECK-NEXT: %19 = bitcast double* %1 to float* -; CHECK-NEXT: store float %res, float* %19, align 4 -; CHECK-NEXT: %20 = load double, double* %1, align 8 -; CHECK-NEXT: call void @llvm.nvvm.barrier0() -; CHECK-NEXT: ret double %20 +; CHECK: define internal double @__enzyme_done_truncate_mem_func_64_52to32_23_f(double %x, double %y) { +; CHECK-DAG: %1 = alloca double, align 8 +; CHECK-DAG: store double %x, double* %1, align 8 +; CHECK-DAG: %2 = bitcast double* %1 to float* +; CHECK-DAG: %3 = load float, float* %2, align 4 +; CHECK-DAG: store double %y, double* %1, align 8 +; CHECK-DAG: %4 = bitcast double* %1 to float* +; CHECK-DAG: %5 = load float, float* %4, align 4 +; CHECK-DAG: %res01 = call float @llvm.pow.f32(float %3, float %5) +; CHECK-DAG: %6 = bitcast double* %1 to i64* +; CHECK-DAG: store i64 0, i64* %6, align 4 +; CHECK-DAG: %7 = bitcast double* %1 to float* +; CHECK-DAG: store float %res01, float* %7, align 4 +; CHECK-DAG: %8 = load double, double* %1, align 8 +; CHECK-DAG: store double %x, double* %1, align 8 +; CHECK-DAG: %9 = bitcast double* %1 to float* +; CHECK-DAG: %10 = load float, float* %9, align 4 +; CHECK-DAG: store double %y, double* %1, align 8 +; CHECK-DAG: %11 = bitcast double* %1 to float* +; CHECK-DAG: %12 = load float, float* %11, align 4 +; CHECK-DAG: %res12 = call float @llvm.pow.f32(float %10, float %12) +; CHECK-DAG: %13 = bitcast double* %1 to i64* +; CHECK-DAG: store i64 0, i64* %13, align 4 +; CHECK-DAG: %14 = bitcast double* %1 to float* +; CHECK-DAG: store float %res12, float* %14, align 4 +; CHECK-DAG: %15 = load double, double* %1, align 8 +; CHECK-DAG: store double %x, double* %1, align 8 +; CHECK-DAG: %16 = bitcast double* %1 to float* +; CHECK-DAG: %17 = load float, float* %16, align 4 +; CHECK-DAG: %res23 = call float @llvm.powi.f32.i16(float %17, i16 2) +; CHECK-DAG: %18 = bitcast double* %1 to i64* +; CHECK-DAG: store i64 0, i64* %18, align 4 +; CHECK-DAG: %19 = bitcast double* %1 to float* +; CHECK-DAG: store float %res23, float* %19, align 4 +; CHECK-DAG: %20 = load double, double* %1, align 8 +; CHECK-DAG: store double %15, double* %1, align 8 +; CHECK-DAG: %21 = bitcast double* %1 to float* +; CHECK-DAG: %22 = load float, float* %21, align 4 +; CHECK-DAG: store double %20, double* %1, align 8 +; CHECK-DAG: %23 = bitcast double* %1 to float* +; CHECK-DAG: %24 = load float, float* %23, align 4 +; CHECK-DAG: %res = fadd float %22, %24 +; CHECK-DAG: %25 = bitcast double* %1 to i64* +; CHECK-DAG: store i64 0, i64* %25, align 4 +; CHECK-DAG: %26 = bitcast double* %1 to float* +; CHECK-DAG: store float %res, float* %26, align 4 +; CHECK-DAG: %27 = load double, double* %1, align 8 +; CHECK-DAG: call void @llvm.nvvm.barrier0() +; CHECK-DAG: ret double %27 -; CHECK: define internal double @__enzyme_done_truncate_op_func_64_52_32_23_f(double %x, double %y) { +; CHECK: define internal double @__enzyme_done_truncate_op_func_64_52to32_23_f(double %x, double %y) { ; CHECK-DAG: %enzyme_trunc = fptrunc double %x to float ; CHECK-DAG: %enzyme_trunc1 = fptrunc double %y to float -; CHECK-DAG: %res12 = call float @llvm.pow.f32(float %enzyme_trunc, float %enzyme_trunc1) -; CHECK-DAG: %enzyme_exp = fpext float %res12 to double +; CHECK-DAG: %res02 = call float @llvm.pow.f32(float %enzyme_trunc, float %enzyme_trunc1) +; CHECK-DAG: %enzyme_exp = fpext float %res02 to double ; CHECK-DAG: %enzyme_trunc3 = fptrunc double %x to float -; CHECK-DAG: %res24 = call float @llvm.powi.f32.i16(float %enzyme_trunc3, i16 2) -; CHECK-DAG: %enzyme_exp5 = fpext float %res24 to double -; CHECK-DAG: %enzyme_trunc6 = fptrunc double %enzyme_exp to float -; CHECK-DAG: %enzyme_trunc7 = fptrunc double %enzyme_exp5 to float -; CHECK-DAG: %res = fadd float %enzyme_trunc6, %enzyme_trunc7 -; CHECK-DAG: %enzyme_exp8 = fpext float %res to double +; CHECK-DAG: %enzyme_trunc4 = fptrunc double %y to float +; CHECK-DAG: %res15 = call float @llvm.pow.f32(float %enzyme_trunc3, float %enzyme_trunc4) +; CHECK-DAG: %enzyme_exp6 = fpext float %res15 to double +; CHECK-DAG: %enzyme_trunc7 = fptrunc double %x to float +; CHECK-DAG: %res28 = call float @llvm.powi.f32.i16(float %enzyme_trunc7, i16 2) +; CHECK-DAG: %enzyme_exp9 = fpext float %res28 to double +; CHECK-DAG: %enzyme_trunc10 = fptrunc double %enzyme_exp6 to float +; CHECK-DAG: %enzyme_trunc11 = fptrunc double %enzyme_exp9 to float +; CHECK-DAG: %res = fadd float %enzyme_trunc10, %enzyme_trunc11 +; CHECK-DAG: %enzyme_exp12 = fpext float %res to double +; CHECK-DAG: call void @llvm.nvvm.barrier0() +; CHECK-DAG: ret double %enzyme_exp12 + +; CHECK: define internal double @__enzyme_done_truncate_op_func_64_52to11_7_f(double %x, double %y) { +; CHECK-DAG: %1 = call double @__enzyme_mpfr_64_52_func_pow(double %x, double %y, i64 3, i64 7) +; CHECK-DAG: %2 = call double @__enzyme_mpfr_64_52_intr_llvm_pow_f64(double %x, double %y, i64 3, i64 7) +; CHECK-DAG: %3 = call double @__enzyme_mpfr_64_52_intr_llvm_powi_f64_i16(double %x, i16 2, i64 3, i64 7) +; CHECK-DAG: %res = call double @__enzyme_mpfr_64_52_binop_fadd(double %2, double %3, i64 3, i64 7) ; CHECK-DAG: call void @llvm.nvvm.barrier0() -; CHECK-DAG: ret double %enzyme_exp8 +; CHECK-DAG: ret double %res +; CHECK-DAG: } diff --git a/enzyme/test/Enzyme/Truncate/select.ll b/enzyme/test/Enzyme/Truncate/select.ll index 365d21ab5913..afc41219fed8 100644 --- a/enzyme/test/Enzyme/Truncate/select.ll +++ b/enzyme/test/Enzyme/Truncate/select.ll @@ -25,10 +25,10 @@ entry: ; CHECK: define double @tester(double %x, double %y, i1 %cond) { ; CHECK-NEXT: entry: -; CHECK-NEXT: %res = call double @__enzyme_done_truncate_mem_func_64_52_32_23_f(double %x, double %y, i1 %cond) +; CHECK-NEXT: %res = call double @__enzyme_done_truncate_mem_func_64_52to32_23_f(double %x, double %y, i1 %cond) ; CHECK-NEXT: ret double %res -; CHECK: define internal double @__enzyme_done_truncate_mem_func_64_52_32_23_f(double %x, double %y, i1 %cond) { +; CHECK: define internal double @__enzyme_done_truncate_mem_func_64_52to32_23_f(double %x, double %y, i1 %cond) { ; CHECK-DAG: %1 = alloca double, align 8 ; CHECK-DAG: store double %x, double* %1, align 8 ; CHECK-DAG: %2 = bitcast double* %1 to float* @@ -44,6 +44,6 @@ entry: ; CHECK-DAG: %8 = load double, double* %1, align 8 ; CHECK-DAG: ret double %8 -; CHECK: define internal double @__enzyme_done_truncate_op_func_64_52_32_23_f(double %x, double %y, i1 %cond) { +; CHECK: define internal double @__enzyme_done_truncate_op_func_64_52to32_23_f(double %x, double %y, i1 %cond) { ; CHECK-DAG: %res = select i1 %cond, double %x, double %y ; CHECK-DAG: ret double %res diff --git a/enzyme/test/Enzyme/Truncate/simple.ll b/enzyme/test/Enzyme/Truncate/simple.ll index 19d6cf1f3a23..a57f33fcdfdb 100644 --- a/enzyme/test/Enzyme/Truncate/simple.ll +++ b/enzyme/test/Enzyme/Truncate/simple.ll @@ -17,25 +17,20 @@ entry: call void %ptr(double* %data) ret void } - -define void @tester2(double* %data) { +define void @tester_op(double* %data) { entry: %ptr = call void (double*)* (...) @__enzyme_truncate_op_func(void (double*)* @f, i64 64, i64 32) call void %ptr(double* %data) ret void } +define void @tester_op_mpfr(double* %data) { +entry: + %ptr = call void (double*)* (...) @__enzyme_truncate_op_func(void (double*)* @f, i64 64, i64 3, i64 7) + call void %ptr(double* %data) + ret void +} -; CHECK: define void @tester(double* %data) -; CHECK-NEXT: entry: -; CHECK-NEXT: call void @__enzyme_done_truncate_mem_func_64_52_32_23_f(double* %data) -; CHECK-NEXT: ret void - -; CHECK: define void @tester2(double* %data) { -; CHECK-NEXT: entry: -; CHECK-NEXT: call void @__enzyme_done_truncate_op_func_64_52_32_23_f(double* %data) -; CHECK-NEXT: ret void - -; CHECK: define internal void @__enzyme_done_truncate_mem_func_64_52_32_23_f(double* %x) +; CHECK: define internal void @__enzyme_done_truncate_mem_func_64_52to32_23_f(double* %x) ; CHECK-DAG: %1 = alloca double, align 8 ; CHECK-DAG: %y = load double, double* %x, align 8 ; CHECK-DAG: store double %y, double* %1, align 8 @@ -53,7 +48,7 @@ entry: ; CHECK-DAG: store double %8, double* %x, align 8 ; CHECK-DAG: ret void -; CHECK: define internal void @__enzyme_done_truncate_op_func_64_52_32_23_f(double* %x) { +; CHECK: define internal void @__enzyme_done_truncate_op_func_64_52to32_23_f(double* %x) { ; CHECK-DAG: %y = load double, double* %x, align 8 ; CHECK-DAG: %enzyme_trunc = fptrunc double %y to float ; CHECK-DAG: %enzyme_trunc1 = fptrunc double %y to float @@ -61,3 +56,9 @@ entry: ; CHECK-DAG: %enzyme_exp = fpext float %m to double ; CHECK-DAG: store double %enzyme_exp, double* %x, align 8 ; CHECK-DAG: ret void + +; CHECK: define internal void @__enzyme_done_truncate_op_func_64_52to11_7_f(double* %x) { +; CHECK-DAG: %y = load double, double* %x, align 8 +; CHECK-DAG: %m = call double @__enzyme_mpfr_64_52_binop_fmul(double %y, double %y, i64 3, i64 7) +; CHECK-DAG: store double %m, double* %x, align 8 +; CHECK-DAG: ret void diff --git a/enzyme/test/Integration/Truncate/truncate-all.cpp b/enzyme/test/Integration/Truncate/truncate-all.cpp index ad5df438842f..39e5965bda0d 100644 --- a/enzyme/test/Integration/Truncate/truncate-all.cpp +++ b/enzyme/test/Integration/Truncate/truncate-all.cpp @@ -1,12 +1,26 @@ // Baseline -// RUN: if [ %llvmver -ge 12 ]; then [ "$(%clang -O3 %s -S -emit-llvm -o - %newLoadClangEnzyme -S -mllvm --enzyme-truncate-all="" | %lli -)" == "900000000.560000" ] ; fi + +// RUN: if [ %llvmver -ge 12 ]; then %clang -O3 %s -S -emit-llvm -o - %newLoadClangEnzyme -S -mllvm --enzyme-truncate-all="" | %lli - | FileCheck --check-prefix BASELINE %s; fi +// BASELINE: 900000000.560000 + // Truncated -// RUN: if [ %llvmver -ge 12 ]; then [ "$(%clang -O3 %s -S -emit-llvm -o - %newLoadClangEnzyme -S -mllvm --enzyme-truncate-all="64to32" | %lli -)" == "900000000.000000" ] ; fi -// RUN: if [ %llvmver -ge 12 ]; then [ "$(%clang -O3 %s -S -emit-llvm -o - %newLoadClangEnzyme -S -mllvm --enzyme-truncate-all="11-52to8-23" | %lli -)" == "900000000.000000" ] ; fi + +// RUN: if [ %llvmver -ge 12 ]; then %clang -O3 %s -S -emit-llvm -o - %newLoadClangEnzyme -S -mllvm --enzyme-truncate-all="64to32" | %lli - | FileCheck --check-prefix TO_32 %s; fi +// TO_32: 900000000.000000 + +// RUN: if [ %llvmver -ge 12 ]; then %clang -O3 %s -S -emit-llvm -o - %newLoadClangEnzyme -S -mllvm --enzyme-truncate-all="11-52to8-23" | %lli - | FileCheck --check-prefix TO_28_23 %s; fi +// TO_28_23: 900000000.000000 + +// RUN: if [ %llvmver -ge 12 ] && [ %hasMPFR == "yes" ] ; then %clang -DENZYME_TEST_TO_MPFR -O3 %s -o %s.a.out %newLoadClangEnzyme -mllvm --enzyme-truncate-all="11-52to3-7" -lmpfr; %s.a.out | FileCheck --check-prefix TO_3_7 %s; fi +// TO_3_7: 897581056.000000 #include +#ifdef ENZYME_TEST_TO_MPFR +#include +#endif + #include "../test_utils.h" #define N 10 diff --git a/enzyme/test/lit.site.cfg.py.in b/enzyme/test/lit.site.cfg.py.in index 0b8a0f831d6e..0cc5e6f28f38 100644 --- a/enzyme/test/lit.site.cfg.py.in +++ b/enzyme/test/lit.site.cfg.py.in @@ -16,6 +16,10 @@ config.llvm_shlib_ext = "@LLVM_SHLIBEXT@" config.targets_to_build = "@TARGETS_TO_BUILD@" +has_mpfr_h = "@HAS_MPFR_H@" +mpfr_lib_path = "@MPFR_LIB_PATH@" +has_mpfr = "yes" if mpfr_lib_path != "MPFR_LIB_PATH-NOTFOUND" and has_mpfr_h == "1" else "no" + ## Check the current platform with regex import re EAT_ERR_ON_X86 = ' ' @@ -112,6 +116,8 @@ if len("@ENZYME_BINARY_DIR@") == 0: config.substitutions.append(('%loadClangEnzyme', oldPM if int(config.llvm_ver) < 15 else newPM)) config.substitutions.append(('%newLoadClangEnzyme', newPM)) +config.substitutions.append(('%hasMPFR', has_mpfr)) + # Let the main config do the real work. cfgfile = "@ENZYME_SOURCE_DIR@/test/lit.cfg.py" if len("@ENZYME_SOURCE_DIR@") == 0: From b96c4439ee0be92cea74d35d3dd33ac53bf9ca5d Mon Sep 17 00:00:00 2001 From: William Moses Date: Thu, 29 Feb 2024 20:49:35 -0800 Subject: [PATCH 104/106] Cleanup and Fixup MLIR reverse mode (#1771) --- enzyme/BUILD | 18 + enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td | 42 +- .../BuiltinAutoDiffTypeInterfaceImpl.cpp | 6 +- .../MLIR/Implementations/CFDerivatives.td | 1 + .../MLIR/Implementations/CMakeLists.txt | 7 + enzyme/Enzyme/MLIR/Implementations/Common.td | 5 + .../CoreDialectsAutoDiffImplementations.cpp | 28 +- .../CoreDialectsAutoDiffImplementations.h | 55 +++ .../FuncAutoDiffOpInterfaceImpl.cpp | 37 ++ .../MLIR/Implementations/FuncDerivatives.td | 3 + .../LLVMAutoDiffOpInterfaceImpl.cpp | 2 +- .../LinalgAutoDiffOpInterfaceImpl.cpp | 18 +- .../MemRefAutoDiffOpInterfaceImpl.cpp | 59 +-- .../SCFAutoDiffOpInterfaceImpl.cpp | 138 +++++-- .../MLIR/Interfaces/AutoDiffTypeInterface.td | 4 +- enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.h | 15 +- .../MLIR/Interfaces/EnzymeLogicReverse.cpp | 320 +++++---------- .../Enzyme/MLIR/Interfaces/GradientUtils.cpp | 142 +++---- enzyme/Enzyme/MLIR/Interfaces/GradientUtils.h | 25 +- .../MLIR/Interfaces/GradientUtilsReverse.cpp | 233 +---------- .../MLIR/Interfaces/GradientUtilsReverse.h | 37 +- enzyme/Enzyme/MLIR/Passes/CMakeLists.txt | 2 +- enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp | 29 +- enzyme/Enzyme/MLIR/Passes/Passes.h | 2 +- enzyme/Enzyme/MLIR/Passes/Passes.td | 6 +- .../MLIR/Passes/RemoveUnusedEnzymeOps.cpp | 372 ++++++++++++------ .../MLIR/Passes/ShadowedGradientToCache.cpp | 64 --- enzyme/Enzyme/MLIR/Passes/SimplifyMath.cpp | 88 +++++ enzyme/test/MLIR/Passes/dualpush.mlir | 48 +++ enzyme/test/MLIR/ReverseMode/bbarg-order.mlir | 49 ++- enzyme/test/MLIR/ReverseMode/pow.mlir | 69 ++-- enzyme/test/MLIR/ReverseMode/square.mlir | 75 ++++ enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp | 10 +- 33 files changed, 1101 insertions(+), 908 deletions(-) create mode 100644 enzyme/Enzyme/MLIR/Implementations/FuncAutoDiffOpInterfaceImpl.cpp create mode 100644 enzyme/Enzyme/MLIR/Implementations/FuncDerivatives.td delete mode 100644 enzyme/Enzyme/MLIR/Passes/ShadowedGradientToCache.cpp create mode 100644 enzyme/Enzyme/MLIR/Passes/SimplifyMath.cpp create mode 100644 enzyme/test/MLIR/Passes/dualpush.mlir create mode 100644 enzyme/test/MLIR/ReverseMode/square.mlir diff --git a/enzyme/BUILD b/enzyme/BUILD index dc36cb4ad0c5..bfb41ed3f9ff 100644 --- a/enzyme/BUILD +++ b/enzyme/BUILD @@ -547,6 +547,23 @@ gentbl( ], ) +gentbl( + name = "func-derivatives", + tbl_outs = [( + "-gen-mlir-derivatives", + "Enzyme/MLIR/Implementations/FuncDerivatives.inc", + )], + tblgen = ":enzyme-tblgen", + td_file = "Enzyme/MLIR/Implementations/FuncDerivatives.td", + td_srcs = [ + "Enzyme/MLIR/Implementations/FuncDerivatives.td", + "Enzyme/MLIR/Implementations/Common.td", + ], + deps = [ + ":enzyme-tblgen", + ], +) + cc_library( name = "EnzymeMLIR", srcs = glob([ @@ -582,6 +599,7 @@ cc_library( ":arith-derivatives", ":cf-derivatives", ":llvm-derivatives", + ":func-derivatives", ":math-derivatives", ":memref-derivatives", ":nvvm-derivatives", diff --git a/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td b/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td index de51913400a6..7d74fdafbdbf 100644 --- a/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td +++ b/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td @@ -83,12 +83,6 @@ def PopOp : Enzyme_Op<"pop"> { let results = (outs AnyType:$output); } -def ClearOp : Enzyme_Op<"clear"> { - let summary = "Remove top element from ShadowedGradient"; - let arguments = (ins AnyType : $cache); - let results = (outs ); -} - def InitOp : Enzyme_Op<"init"> { let summary = "Creat enzyme.gradient and enzyme.cache"; let arguments = (ins ); @@ -105,36 +99,28 @@ def Cache : Enzyme_Type<"Cache"> { let assemblyFormat = "`<` $type `>`"; } -def SetOp : Enzyme_Op<"set"> { - let summary = "Write to gradient"; - let arguments = (ins AnyType : $gradient, AnyType : $value); - let results = (outs ); -} - -def GetOp : Enzyme_Op<"get"> { - let summary = "Load value of gradient"; - let arguments = (ins AnyType : $gradient); - let results = (outs AnyType); -} - def Gradient : Enzyme_Type<"Gradient"> { - let summary = "Stores gradient if it cant be stroed in a value."; + let summary = "Mutable storage for accumulating gradients"; let description = [{ - "Cache for reverse pass" + Mutable storage for accumulating derivatives of immutable types (e.g. adding all the partial derivatives from users of a float64) }]; let parameters = (ins "Type":$basetype); let mnemonic = "Gradient"; let assemblyFormat = "`<` $basetype `>`"; } -def ShadowedGradient : Enzyme_Type<"ShadowedGradient"> { - let summary = "Stores gradients which need to be initialized with shadow values from the forward pass."; - let description = [{ - "Cache for reverse pass" - }]; - let parameters = (ins "Type":$basetype); - let mnemonic = "ShadowedGradient"; - let assemblyFormat = "`<` $basetype `>`"; +def SetOp : Enzyme_Op<"set"> { + let summary = "Store the current value of the gradient"; + let arguments = (ins Arg:$gradient, AnyType : $value); + let results = (outs ); +} + +def GetOp : Enzyme_Op<"get"> { + let summary = "Load current value of gradient"; + let arguments = (ins Arg:$gradient); + let results = (outs AnyType); } def AddToOp : Enzyme_Op<"addTo", [Pure, Terminator, ReturnLike]>, diff --git a/enzyme/Enzyme/MLIR/Implementations/BuiltinAutoDiffTypeInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/BuiltinAutoDiffTypeInterfaceImpl.cpp index a15b3283e10d..058f49e87ac9 100644 --- a/enzyme/Enzyme/MLIR/Implementations/BuiltinAutoDiffTypeInterfaceImpl.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/BuiltinAutoDiffTypeInterfaceImpl.cpp @@ -44,7 +44,7 @@ class FloatTypeInterface return self; } - bool requiresShadow(Type self) const { return false; } + bool isMutable(Type self) const { return false; } LogicalResult zeroInPlace(Type self, OpBuilder &builder, Location loc, Value val) const { return failure(); @@ -77,7 +77,7 @@ class TensorTypeInterface return self; } - bool requiresShadow(Type self) const { return false; } + bool isMutable(Type self) const { return false; } LogicalResult zeroInPlace(Type self, OpBuilder &builder, Location loc, Value val) const { return failure(); @@ -105,7 +105,7 @@ class IntegerTypeInterface return self; } - bool requiresShadow(Type self) const { return false; } + bool isMutable(Type self) const { return false; } LogicalResult zeroInPlace(Type self, OpBuilder &builder, Location loc, Value val) const { return failure(); diff --git a/enzyme/Enzyme/MLIR/Implementations/CFDerivatives.td b/enzyme/Enzyme/MLIR/Implementations/CFDerivatives.td index 8b4f41696a9d..0b522e72ccf2 100644 --- a/enzyme/Enzyme/MLIR/Implementations/CFDerivatives.td +++ b/enzyme/Enzyme/MLIR/Implementations/CFDerivatives.td @@ -1,4 +1,5 @@ include "Common.td" +def : BranchOp<"cf", "CondBranchOp">; def : BranchOp<"cf", "BranchOp">; def : BranchOp<"cf", "SwitchOp">; diff --git a/enzyme/Enzyme/MLIR/Implementations/CMakeLists.txt b/enzyme/Enzyme/MLIR/Implementations/CMakeLists.txt index ab156fa1b36f..3508e71d9adc 100644 --- a/enzyme/Enzyme/MLIR/Implementations/CMakeLists.txt +++ b/enzyme/Enzyme/MLIR/Implementations/CMakeLists.txt @@ -30,6 +30,10 @@ set(LLVM_TARGET_DEFINITIONS MathDerivatives.td) enzyme_tablegen(MathDerivatives.inc -gen-mlir-derivatives) add_public_tablegen_target(MathDerivativesIncGen) +set(LLVM_TARGET_DEFINITIONS FuncDerivatives.td) +enzyme_tablegen(FuncDerivatives.inc -gen-mlir-derivatives) +add_public_tablegen_target(FuncDerivativesIncGen) + add_mlir_library(MLIREnzymeImplementations AffineAutoDiffOpInterfaceImpl.cpp ArithAutoDiffOpInterfaceImpl.cpp @@ -37,6 +41,7 @@ add_mlir_library(MLIREnzymeImplementations LLVMAutoDiffOpInterfaceImpl.cpp NVVMAutoDiffOpInterfaceImpl.cpp MemRefAutoDiffOpInterfaceImpl.cpp + FuncAutoDiffOpInterfaceImpl.cpp LinalgAutoDiffOpInterfaceImpl.cpp BuiltinAutoDiffTypeInterfaceImpl.cpp SCFAutoDiffOpInterfaceImpl.cpp @@ -48,6 +53,7 @@ add_mlir_library(MLIREnzymeImplementations AffineDerivativesIncGen ArithDerivativesIncGen LLVMDerivativesIncGen + FuncDerivativesIncGen NVVMDerivativesIncGen SCFDerivativesIncGen CFDerivativesIncGen @@ -56,6 +62,7 @@ add_mlir_library(MLIREnzymeImplementations LINK_LIBS PUBLIC MLIRArithDialect + MLIRFuncDialect MLIRLLVMDialect MLIRMemRefDialect MLIREnzymeAutoDiffInterface diff --git a/enzyme/Enzyme/MLIR/Implementations/Common.td b/enzyme/Enzyme/MLIR/Implementations/Common.td index 33d9f12c2f37..3924f4527b00 100644 --- a/enzyme/Enzyme/MLIR/Implementations/Common.td +++ b/enzyme/Enzyme/MLIR/Implementations/Common.td @@ -26,6 +26,11 @@ class MemoryIdentityOp ptrargs_, list class ReadOnlyIdentityOp ptrargs_> : MemoryIdentityOp; +class ReturnOp { + string dialect = dialect_; + string opName = opName_; +} + class BranchOp { string dialect = dialect_; string opName = opName_; diff --git a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp index f9b2d61b0dce..ade1be7e6406 100644 --- a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp @@ -15,6 +15,7 @@ #include "Interfaces/AutoDiffOpInterface.h" #include "Interfaces/AutoDiffTypeInterface.h" #include "Interfaces/GradientUtils.h" +#include "Interfaces/GradientUtilsReverse.h" using namespace mlir; using namespace mlir::enzyme; @@ -143,8 +144,7 @@ LogicalResult mlir::enzyme::detail::memoryIdentityForwardHandler( if (contains(storedVals, operand.getOperandNumber())) { if (auto iface = dyn_cast(operand.get().getType())) { - if (!iface.requiresShadow()) { - // TODO only do if mutable + if (!iface.isMutable()) { Type retTy = iface.getShadowType(); auto toret = retTy.cast().createNullValue( builder, operand.get().getLoc()); @@ -201,6 +201,29 @@ LogicalResult mlir::enzyme::detail::allocationForwardHandler( return success(); } +void mlir::enzyme::detail::returnReverseHandler(Operation *op, + OpBuilder &builder, + MGradientUtilsReverse *gutils) { + size_t num_out = 0; + for (auto act : gutils->RetDiffeTypes) { + if (act == DIFFE_TYPE::OUT_DIFF) + num_out++; + } + + size_t idx = 0; + auto args = gutils->newFunc->getRegions().begin()->begin()->getArguments(); + + for (auto &&[op, act] : llvm::zip(op->getOperands(), gutils->RetDiffeTypes)) { + if (act == DIFFE_TYPE::OUT_DIFF) { + if (!gutils->isConstantValue(op)) { + auto d_out = args[args.size() - num_out + idx]; + gutils->addToDiffe(op, d_out, builder); + } + idx++; + } + } +} + void mlir::enzyme::detail::regionTerminatorForwardHandler( Operation *origTerminator, OpBuilder &builder, MGradientUtils *gutils) { auto parentOp = origTerminator->getParentOp(); @@ -401,4 +424,5 @@ void mlir::enzyme::registerCoreDialectAutodiffInterfaces( enzyme::registerSCFDialectAutoDiffInterface(registry); enzyme::registerCFDialectAutoDiffInterface(registry); enzyme::registerLinalgDialectAutoDiffInterface(registry); + enzyme::registerFuncDialectAutoDiffInterface(registry); } diff --git a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h index 7ee0be2adb70..cbad734656b1 100644 --- a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h +++ b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h @@ -52,6 +52,10 @@ void branchingForwardHandler(Operation *op, OpBuilder &builder, void regionTerminatorForwardHandler(Operation *op, OpBuilder &builder, MGradientUtils *gutils); +// Implements reverse-mode differentiation of return operations. +void returnReverseHandler(Operation *op, OpBuilder &builder, + MGradientUtilsReverse *gutils); + // Implements forward-mode differentiation of read-only (including read-none) // operations which do not perform computation LogicalResult memoryIdentityForwardHandler(Operation *op, OpBuilder &builder, @@ -104,6 +108,44 @@ class AutoDiffUsingRegionTerminator } }; +template +class NoopRevAutoDiffInterface + : public ReverseAutoDiffOpInterface::ExternalModel< + NoopRevAutoDiffInterface, OpTy> { +public: + void createReverseModeAdjoint(Operation *op, OpBuilder &builder, + MGradientUtilsReverse *gutils, + SmallVector caches) const {} + + SmallVector cacheValues(Operation *op, + MGradientUtilsReverse *gutils) const { + return SmallVector(); + } + + void createShadowValues(Operation *op, OpBuilder &builder, + MGradientUtilsReverse *gutils) const {} +}; + +template +class ReturnRevAutoDiffInterface + : public ReverseAutoDiffOpInterface::ExternalModel< + ReturnRevAutoDiffInterface, OpTy> { +public: + void createReverseModeAdjoint(Operation *op, OpBuilder &builder, + MGradientUtilsReverse *gutils, + SmallVector caches) const { + returnReverseHandler(op, builder, gutils); + } + + SmallVector cacheValues(Operation *op, + MGradientUtilsReverse *gutils) const { + return SmallVector(); + } + + void createShadowValues(Operation *op, OpBuilder &builder, + MGradientUtilsReverse *gutils) const {} +}; + // Implements the forward autodiff interface for operations which are // read only and identity like (aka not computing sin of mem read). template @@ -166,12 +208,24 @@ void registerAutoDiffUsingControlFlowInterface(MLIRContext &context) { template void registerAutoDiffUsingBranchInterface(MLIRContext &context) { OpTy::template attachInterface>(context); + OpTy::template attachInterface>( + context); } // Registers AutoDiffUsingRegionTerminator for the given op. template void registerAutoDiffUsingRegionTerminatorInterface(MLIRContext &context) { OpTy::template attachInterface>( context); + OpTy::template attachInterface>( + context); +} +// Registers registerAutoDiffUsingReturnInterface for the given op. +template +void registerAutoDiffUsingReturnInterface(MLIRContext &context) { + OpTy::template attachInterface>( + context); + OpTy::template attachInterface>( + context); } // Registers AutoDiffUsingMemoryIdentity for the given op. template @@ -199,6 +253,7 @@ void registerSCFDialectAutoDiffInterface(DialectRegistry ®istry); void registerCFDialectAutoDiffInterface(DialectRegistry ®istry); void registerLinalgDialectAutoDiffInterface(DialectRegistry ®istry); void registerMathDialectAutoDiffInterface(DialectRegistry ®istry); +void registerFuncDialectAutoDiffInterface(DialectRegistry ®istry); void registerCoreDialectAutodiffInterfaces(DialectRegistry ®istry); diff --git a/enzyme/Enzyme/MLIR/Implementations/FuncAutoDiffOpInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/FuncAutoDiffOpInterfaceImpl.cpp new file mode 100644 index 000000000000..dddce795adfd --- /dev/null +++ b/enzyme/Enzyme/MLIR/Implementations/FuncAutoDiffOpInterfaceImpl.cpp @@ -0,0 +1,37 @@ +//===- FuncAutoDiffOpInterfaceImpl.cpp - Interface external model --------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains the external model implementation of the automatic +// differentiation op interfaces for the upstream MLIR arithmetic dialect. +// +//===----------------------------------------------------------------------===// + +#include "Implementations/CoreDialectsAutoDiffImplementations.h" +#include "Interfaces/AutoDiffOpInterface.h" +#include "Interfaces/GradientUtils.h" +#include "Interfaces/GradientUtilsReverse.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/DialectRegistry.h" +#include "mlir/Support/LogicalResult.h" + +#include "Dialect/Ops.h" +#include "mlir/IR/TypeSupport.h" + +using namespace mlir; +using namespace mlir::enzyme; + +namespace { +#include "Implementations/FuncDerivatives.inc" +} // namespace + +void mlir::enzyme::registerFuncDialectAutoDiffInterface( + DialectRegistry ®istry) { + registry.addExtension(+[](MLIRContext *context, func::FuncDialect *) { + registerInterfaces(context); + }); +} diff --git a/enzyme/Enzyme/MLIR/Implementations/FuncDerivatives.td b/enzyme/Enzyme/MLIR/Implementations/FuncDerivatives.td new file mode 100644 index 000000000000..005246887fdf --- /dev/null +++ b/enzyme/Enzyme/MLIR/Implementations/FuncDerivatives.td @@ -0,0 +1,3 @@ +include "Common.td" + +def : ReturnOp<"func", "ReturnOp">; \ No newline at end of file diff --git a/enzyme/Enzyme/MLIR/Implementations/LLVMAutoDiffOpInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/LLVMAutoDiffOpInterfaceImpl.cpp index 472cd0a43d61..264278e97ef9 100644 --- a/enzyme/Enzyme/MLIR/Implementations/LLVMAutoDiffOpInterfaceImpl.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/LLVMAutoDiffOpInterfaceImpl.cpp @@ -55,7 +55,7 @@ class PointerTypeInterface return self; } - bool requiresShadow(Type self) const { return true; } + bool isMutable(Type self) const { return true; } LogicalResult zeroInPlace(Type self, OpBuilder &builder, Location loc, Value val) const { diff --git a/enzyme/Enzyme/MLIR/Implementations/LinalgAutoDiffOpInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/LinalgAutoDiffOpInterfaceImpl.cpp index 8ecb851af90b..db9f4b08d6e1 100644 --- a/enzyme/Enzyme/MLIR/Implementations/LinalgAutoDiffOpInterfaceImpl.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/LinalgAutoDiffOpInterfaceImpl.cpp @@ -133,7 +133,7 @@ struct GenericOpInterfaceReverse linalgOp.getNumLoops(), utils::IteratorType::parallel}; for (OpOperand &output : linalgOp.getDpsInitsMutable()) { - if (!gutils->hasInvertPointer(output.get())) { + if (gutils->isConstantValue(output.get())) { continue; } indexingMaps.push_back(linalgOp.getMatchingIndexingMap(&output)); @@ -143,7 +143,7 @@ struct GenericOpInterfaceReverse } for (OpOperand *input : linalgOp.getDpsInputOperands()) { - if (!gutils->hasInvertPointer(input->get())) { + if (gutils->isConstantValue(input->get())) { continue; } indexingMaps.push_back(linalgOp.getMatchingIndexingMap(input)); @@ -165,8 +165,13 @@ struct GenericOpInterfaceReverse StringAttr()); int numInputs = inputs.size(); - auto buildFuncReturnOp = [numInputs](OpBuilder &builder, Location loc, - SmallVector retargs) { + auto buildFuncReturnOp = [&gutils, numInputs](OpBuilder &builder, + Block *oBB) { + auto loc = oBB->rbegin()->getLoc(); + SmallVector retargs; + for (auto arg : oBB->getArguments()) { + retargs.push_back(gutils->invertPointerM(arg, builder)); + } builder.create( loc, ValueRange{retargs}.take_front(numInputs)); return; @@ -192,9 +197,8 @@ struct GenericOpInterfaceReverse return std::make_pair(pushCache, popCache); }; - gutils->Logic.differentiate( - gutils, *linalgOp.getBlock()->getParent(), adjoint.getRegion(), - /*parentRegion=*/false, buildFuncReturnOp, hook); + gutils->Logic.differentiate(gutils, *linalgOp.getBlock()->getParent(), + adjoint.getRegion(), buildFuncReturnOp, hook); auto newOpYield = cast( cast(newOp).getBodyRegion().front().getTerminator()); diff --git a/enzyme/Enzyme/MLIR/Implementations/MemRefAutoDiffOpInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/MemRefAutoDiffOpInterfaceImpl.cpp index 1b811caf1b81..cd21c5c548b9 100644 --- a/enzyme/Enzyme/MLIR/Implementations/MemRefAutoDiffOpInterfaceImpl.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/MemRefAutoDiffOpInterfaceImpl.cpp @@ -41,8 +41,8 @@ struct LoadOpInterfaceReverse Value memref = loadOp.getMemref(); if (auto iface = dyn_cast(loadOp.getType())) { - if (gutils->hasInvertPointer(loadOp) && - gutils->hasInvertPointer(memref)) { + if (!gutils->isConstantValue(loadOp) && + !gutils->isConstantValue(memref)) { Value gradient = gutils->invertPointerM(loadOp, builder); Value memrefGradient = gutils->invertPointerM(memref, builder); @@ -70,8 +70,8 @@ struct LoadOpInterfaceReverse Value memref = loadOp.getMemref(); ValueRange indices = loadOp.getIndices(); if (auto iface = dyn_cast(loadOp.getType())) { - if (gutils->hasInvertPointer(loadOp) && - gutils->hasInvertPointer(memref)) { + if (!gutils->isConstantValue(loadOp) && + !gutils->isConstantValue(memref)) { OpBuilder cacheBuilder(gutils->getNewFromOriginal(op)); SmallVector caches; for (Value v : indices) { @@ -104,28 +104,33 @@ struct StoreOpInterfaceReverse Value memref = storeOp.getMemref(); // ValueRange indices = storeOp.getIndices(); - if (auto iface = dyn_cast(val.getType())) { - if (gutils->hasInvertPointer(memref)) { - OpBuilder cacheBuilder(gutils->getNewFromOriginal(op)); + auto iface = cast(val.getType()); - Value memrefGradient = gutils->invertPointerM(memref, builder); + if (!gutils->isConstantValue(memref)) { + OpBuilder cacheBuilder(gutils->getNewFromOriginal(op)); - SmallVector retrievedArguments; - for (Value cache : caches) { - Value retrievedValue = gutils->popCache(cache, builder); - retrievedArguments.push_back(retrievedValue); - } + Value memrefGradient = gutils->invertPointerM(memref, builder); - Value loadedGradient = - builder.create(storeOp.getLoc(), memrefGradient, - ArrayRef(retrievedArguments)); - Value addedGradient = loadedGradient; - if (gutils->hasInvertPointer(val)) { - Value gradient = gutils->invertPointerM(val, builder); - addedGradient = iface.createAddOp(builder, storeOp.getLoc(), gradient, - loadedGradient); + SmallVector retrievedArguments; + for (Value cache : caches) { + Value retrievedValue = gutils->popCache(cache, builder); + retrievedArguments.push_back(retrievedValue); + } + + if (!iface.isMutable()) { + if (!gutils->isConstantValue(val)) { + Value loadedGradient = builder.create( + storeOp.getLoc(), memrefGradient, + ArrayRef(retrievedArguments)); + gutils->addToDiffe(val, loadedGradient, builder); } - gutils->mapInvertPointer(val, addedGradient, builder); + + auto zero = + cast(gutils->getShadowType(val.getType())) + .createNullValue(builder, op->getLoc()); + + builder.create(storeOp.getLoc(), zero, memrefGradient, + ArrayRef(retrievedArguments)); } } } @@ -137,7 +142,7 @@ struct StoreOpInterfaceReverse ValueRange indices = storeOp.getIndices(); Value val = storeOp.getValue(); if (auto iface = dyn_cast(val.getType())) { - if (gutils->hasInvertPointer(memref)) { + if (!gutils->isConstantValue(memref)) { OpBuilder cacheBuilder(gutils->getNewFromOriginal(op)); SmallVector caches; for (Value v : indices) { @@ -175,13 +180,13 @@ struct SubViewOpInterfaceReverse MGradientUtilsReverse *gutils) const { auto subviewOp = cast(op); auto newSubviewOp = cast(gutils->getNewFromOriginal(op)); - if (gutils->hasInvertPointer(subviewOp.getSource())) { + if (!gutils->isConstantValue(subviewOp.getSource())) { Value shadow = builder.create( op->getLoc(), newSubviewOp.getType(), gutils->invertPointerM(subviewOp.getSource(), builder), newSubviewOp.getMixedOffsets(), newSubviewOp.getMixedSizes(), newSubviewOp.getMixedStrides()); - gutils->mapShadowValue(subviewOp, shadow, builder); + gutils->setDiffe(subviewOp, shadow, builder); } } }; @@ -205,13 +210,13 @@ class MemRefTypeInterface return self; } - bool requiresShadow(Type self) const { return true; } + bool isMutable(Type self) const { return true; } LogicalResult zeroInPlace(Type self, OpBuilder &builder, Location loc, Value val) const { auto MT = cast(self); if (auto iface = dyn_cast(MT.getElementType())) { - if (!iface.requiresShadow()) { + if (!iface.isMutable()) { Value zero = iface.createNullValue(builder, loc); builder.create(loc, zero, val); } diff --git a/enzyme/Enzyme/MLIR/Implementations/SCFAutoDiffOpInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/SCFAutoDiffOpInterfaceImpl.cpp index bab415ff247f..72fbe2106a53 100644 --- a/enzyme/Enzyme/MLIR/Implementations/SCFAutoDiffOpInterfaceImpl.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/SCFAutoDiffOpInterfaceImpl.cpp @@ -40,45 +40,113 @@ struct ForOpInterfaceReverse SmallVector caches) const { auto forOp = cast(op); - SmallVector nArgs; - for (Value v : forOp.getResults()) { - if (auto iface = dyn_cast(v.getType())) { - if (gutils->hasInvertPointer(v)) { - nArgs.push_back(gutils->invertPointerM(v, builder)); - } else { - nArgs.push_back(iface.createNullValue(builder, v.getLoc())); + // Begin Perform d(yielded value[i]) += d(result[i]); d(result[i]) = 0 + SmallVector resDiffes; + for (OpResult v : forOp.getResults()) { + if (!gutils->isConstantValue(v)) { + auto autoDiffType = cast(v.getType()); + if (!autoDiffType.isMutable()) { + auto prev = gutils->diffe(v, builder); + gutils->zeroDiffe(v, builder); + resDiffes.push_back(prev); + continue; } } + resDiffes.push_back(nullptr); } - auto repFor = builder.create( - forOp.getLoc(), gutils->popCache(caches[0], builder), - gutils->popCache(caches[1], builder), - gutils->popCache(caches[2], builder), nArgs); // TODO - repFor.getRegion().begin()->erase(); - - auto buildFuncReturnOp = [](OpBuilder &builder, Location loc, - SmallVector retargs) { - builder.create(loc, retargs); - return; - }; - - gutils->Logic.differentiate(gutils, forOp.getRegion(), repFor.getRegion(), - /*parentRegion=*/false, buildFuncReturnOp, - nullptr); - - // Insert the index which is carried by the scf for op. - Type indexType = IndexType::get(builder.getContext()); - repFor.getRegion().insertArgument((unsigned)0, indexType, forOp.getLoc()); - - for (const auto &[iterOperand, adjResult] : - llvm::zip(forOp.getInitArgs(), repFor.getResults())) { - if (gutils->hasInvertPointer(iterOperand)) { - auto autoDiffType = cast(iterOperand.getType()); - Value before = gutils->invertPointerM(iterOperand, builder); - Value after = autoDiffType.createAddOp(builder, forOp.getLoc(), before, - adjResult); - gutils->mapInvertPointer(iterOperand, after, builder); + for (auto ® : op->getRegions()) { + auto termIface = + cast(reg.begin()->getTerminator()); + + SmallVector successors; + termIface.getSuccessorRegions( + SmallVector(termIface->getNumOperands(), Attribute()), + successors); + + for (auto &successor : successors) { + if (!successor.isParent()) + continue; + OperandRange operandRange = termIface.getSuccessorOperands(successor); + assert(operandRange.size() == resDiffes.size()); + + // There is an assumption here that there is only regions that branch to + // the successor. Specifically, otherwise we would need to + // gutils->addToDiffe select (if came from that result) + for (auto &&[prev, post] : llvm::zip(operandRange, resDiffes)) { + if (!post) + continue; + if (!gutils->isConstantValue(prev)) + gutils->addToDiffe(prev, post, builder); + } + } + } + // End Perform d(yielded value[i]) += d(result[i]); d(result[i]) = 0 + + auto start = gutils->popCache(caches[0], builder); + auto end = gutils->popCache(caches[1], builder); + auto step = gutils->popCache(caches[2], builder); + + auto repFor = builder.create(forOp.getLoc(), start, end, step, + ArrayRef()); + // erase scf yield + repFor.getBody()->begin()->erase(); + + for (auto &&[oldReg, newReg] : + llvm::zip(op->getRegions(), repFor->getRegions())) { + + // This code assumes at most one terminating block for each region (lest + // the append happen multiple times) + auto buildFuncReturnOp = [&](OpBuilder &builder, Block *oBB) { + auto loc = oBB->rbegin()->getLoc(); + + auto idx = repFor.getInductionVar(); + + auto lhs = builder.create(loc, idx, step); + + // This needs to know a condition describing which predecessor this will + // return to, to select the right value Here we use the condition i + + // step >= end to determine the last iteration + + auto condition = builder.create( + loc, arith::CmpIPredicate::sge, lhs, end); + + for (auto [arg, init_arg] : + llvm::zip(oBB->getArguments().slice(1), forOp.getInitArgs())) { + if (!gutils->isConstantValue(arg) && + !cast(arg.getType()).isMutable()) { + auto diffe = gutils->diffe(arg, builder); + gutils->zeroDiffe(arg, builder); + + auto zero = cast(diffe.getType()) + .createNullValue(builder, loc); + auto outside = + builder.create(loc, condition, diffe, zero); + auto inside = + builder.create(loc, condition, zero, diffe); + + // For each predecessor, if we came from that predecessor += the + // shadow of the arg [after zero'ing] + if (!gutils->isConstantValue(init_arg)) { + gutils->addToDiffe(init_arg, outside, builder); + } + + if (!gutils->isConstantValue(arg)) { + gutils->addToDiffe(arg, inside, builder); + } + } + } + builder.create(loc); + }; + + for (auto &&[oBB, revBB] : llvm::zip(oldReg, newReg)) { + gutils->mapReverseModeBlocks.map(&oBB, &revBB); + } + for (auto &&[oBB, revBB] : llvm::zip(oldReg, newReg)) { + gutils->Logic.visitChildren(&oBB, &revBB, gutils); + Block *newBB = gutils->getNewFromOriginal(&oBB); + gutils->Logic.handlePredecessors(&oBB, newBB, &revBB, gutils, + buildFuncReturnOp); } } } diff --git a/enzyme/Enzyme/MLIR/Interfaces/AutoDiffTypeInterface.td b/enzyme/Enzyme/MLIR/Interfaces/AutoDiffTypeInterface.td index 956f616751dc..0f6d38b4ffa4 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/AutoDiffTypeInterface.td +++ b/enzyme/Enzyme/MLIR/Interfaces/AutoDiffTypeInterface.td @@ -61,10 +61,10 @@ def AutoDiffTypeInterface : TypeInterface<"AutoDiffTypeInterface"> { >, InterfaceMethod< /*desc=*/[{ - Returns if the Type needs to be cleared. + Returns whether the type is mutable in place or not. }], /*retTy=*/"bool", - /*methodName=*/"requiresShadow", + /*methodName=*/"isMutable", /*args=*/(ins ) > ]; diff --git a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.h b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.h index 223f604eaabb..dc98a5c4c6a6 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.h +++ b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.h @@ -10,8 +10,7 @@ namespace mlir { namespace enzyme { -typedef void(buildReturnFunction)(OpBuilder &, Location, - SmallVector); +typedef void(buildReturnFunction)(OpBuilder &, mlir::Block *); class MGradientUtilsReverse; @@ -125,24 +124,22 @@ class MEnzymeLogic { void initializeShadowValues(SmallVector &dominatorToposortBlocks, MGradientUtilsReverse *gutils); - void handlePredecessors(Block *oBB, Block *newBB, Block *reverseBB, - MGradientUtilsReverse *gutils, - llvm::function_ref buildReturnOp, - bool parentRegion); + void + handlePredecessors(Block *oBB, Block *newBB, Block *reverseBB, + MGradientUtilsReverse *gutils, + llvm::function_ref buildReturnOp); void visitChildren(Block *oBB, Block *reverseBB, MGradientUtilsReverse *gutils); void visitChild(Operation *op, OpBuilder &builder, MGradientUtilsReverse *gutils); bool visitChildCustom(Operation *op, OpBuilder &builder, MGradientUtilsReverse *gutils); - void handleReturns(Block *oBB, Block *newBB, Block *reverseBB, - MGradientUtilsReverse *gutils, bool parentRegion); void mapInvertArguments(Block *oBB, Block *reverseBB, MGradientUtilsReverse *gutils); SmallVector getDominatorToposort(MGradientUtilsReverse *gutils, Region ®ion); void differentiate(MGradientUtilsReverse *gutils, Region &oldRegion, - Region &newRegion, bool parentRegion, + Region &newRegion, llvm::function_ref buildFuncRetrunOp, std::function(Type)> cacheCreator); }; diff --git a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp index 39de939bc27a..f39766a0a639 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp +++ b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp @@ -11,8 +11,6 @@ #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/IR/Dominance.h" -#include "llvm/ADT/BreadthFirstIterator.h" #include "EnzymeLogic.h" #include "Interfaces/GradientUtils.h" @@ -22,71 +20,18 @@ using namespace mlir; using namespace mlir::enzyme; -SmallVector -MEnzymeLogic::getDominatorToposort(MGradientUtilsReverse *gutils, - Region ®ion) { - SmallVector dominatorToposortBlocks; - if (region.hasOneBlock()) { - dominatorToposortBlocks.push_back(&*(region.begin())); - } else { - auto dInfo = mlir::detail::DominanceInfoBase(nullptr); - llvm::DominatorTreeBase &dt = - dInfo.getDomTree(&(gutils->oldFunc.getFunctionBody())); - auto root = dt.getNode(&*(region.begin())); +void handleReturns(Block *oBB, Block *newBB, Block *reverseBB, + MGradientUtilsReverse *gutils) { + if (oBB->getNumSuccessors() == 0) { + Operation *returnStatement = newBB->getTerminator(); + gutils->erase(returnStatement); - for (llvm::DomTreeNodeBase *node : llvm::breadth_first(root)) { - dominatorToposortBlocks.push_back(node->getBlock()); - } - } - return dominatorToposortBlocks; -} + OpBuilder forwardToBackwardBuilder(newBB, newBB->end()); -void MEnzymeLogic::mapInvertArguments(Block *oBB, Block *reverseBB, - MGradientUtilsReverse *gutils) { - OpBuilder builder(reverseBB, reverseBB->begin()); - for (int i = 0; i < (int)gutils->mapBlockArguments[oBB].size(); i++) { - auto x = gutils->mapBlockArguments[oBB][i]; - if (auto iface = x.second.getType().dyn_cast()) { - Value added = reverseBB->getArgument(i); - if (gutils->hasInvertPointer(x.second)) { - added = iface.createAddOp(builder, x.second.getLoc(), added, - gutils->invertPointerM(x.second, builder)); - } - gutils->mapInvertPointer(x.second, added, builder); - } - } -} + Operation *newBranchOp = forwardToBackwardBuilder.create( + oBB->getTerminator()->getLoc(), reverseBB); -void MEnzymeLogic::handleReturns(Block *oBB, Block *newBB, Block *reverseBB, - MGradientUtilsReverse *gutils, - bool parentRegion) { - if (oBB->getNumSuccessors() == 0) { - if (parentRegion) { - Operation *returnStatement = newBB->getTerminator(); - gutils->erase(returnStatement); - - OpBuilder forwardToBackwardBuilder(newBB, newBB->end()); - gutils->mapInvertPointer( - oBB->getTerminator()->getOperand(0), - gutils->newFunc.getArgument(gutils->newFunc.getNumArguments() - 1), - forwardToBackwardBuilder); // TODO handle multiple return values - Operation *newBranchOp = forwardToBackwardBuilder.create( - oBB->getTerminator()->getLoc(), reverseBB); - - gutils->originalToNewFnOps[oBB->getTerminator()] = newBranchOp; - } else { - Operation *terminator = oBB->getTerminator(); - OpBuilder builder(reverseBB, reverseBB->begin()); - - int i = 0; - for (OpOperand &operand : terminator->getOpOperands()) { - Value val = operand.get(); - if (auto iface = val.getType().dyn_cast()) { - gutils->mapInvertPointer(val, reverseBB->getArgument(i), builder); - i++; - } - } - } + gutils->originalToNewFnOps[oBB->getTerminator()] = newBranchOp; } } @@ -135,7 +80,7 @@ bool MEnzymeLogic::visitChildCustom(Operation *op, OpBuilder &builder, SmallVector args; for (Value opResult : op->getResults()) { - if (gutils->hasInvertPointer(opResult)) { + if (!gutils->isConstantValue(opResult)) { Value invertValue = gutils->invertPointerM(opResult, builder); args.push_back(invertValue); } @@ -152,7 +97,7 @@ bool MEnzymeLogic::visitChildCustom(Operation *op, OpBuilder &builder, func::CallOp dCI = builder.create(op->getLoc(), srDiffe, resultTypes, args); for (int i = 0; i < (int)op->getNumOperands(); i++) { - gutils->mapInvertPointer(op->getOperand(i), dCI.getResult(i), builder); + gutils->setDiffe(op->getOperand(i), dCI.getResult(i), builder); } return true; @@ -165,7 +110,8 @@ Create reverse mode adjoint for an operation. */ void MEnzymeLogic::visitChild(Operation *op, OpBuilder &builder, MGradientUtilsReverse *gutils) { - if (llvm::all_of(op->getResults(), + if ((op->getBlock()->getTerminator() != op) && + llvm::all_of(op->getResults(), [gutils](Value v) { return gutils->isConstantValue(v); }) && gutils->isConstantInstruction(op)) { return; @@ -173,12 +119,14 @@ void MEnzymeLogic::visitChild(Operation *op, OpBuilder &builder, if (auto ifaceOp = dyn_cast(op)) { SmallVector caches = ifaceOp.cacheValues(gutils); ifaceOp.createReverseModeAdjoint(builder, gutils, caches); - + return; + /* for (int indexResult = 0; indexResult < (int)op->getNumResults(); indexResult++) { Value result = op->getResult(indexResult); gutils->clearValue(result, builder); } + */ } op->emitError() << "could not compute the adjoint for this operation " << *op; } @@ -201,176 +149,107 @@ void MEnzymeLogic::visitChildren(Block *oBB, Block *reverseBB, void MEnzymeLogic::handlePredecessors( Block *oBB, Block *newBB, Block *reverseBB, MGradientUtilsReverse *gutils, - llvm::function_ref buildReturnOp, bool parentRegion) { + llvm::function_ref buildReturnOp) { OpBuilder revBuilder(reverseBB, reverseBB->end()); if (oBB->hasNoPredecessors()) { - SmallVector retargs; - // We need different handling on the top level due to - // the presence of duplicated args since we don't yet have activity analysis - if (parentRegion) { - assert(gutils->ArgDiffeTypes.size() == - gutils->oldFunc.getNumArguments() && - "Mismatch of activity array size vs # original function args"); - for (const auto &[diffeType, oldArg] : - llvm::zip(gutils->ArgDiffeTypes, oBB->getArguments())) { - if (diffeType == DIFFE_TYPE::OUT_DIFF) { - retargs.push_back(gutils->invertPointerM(oldArg, revBuilder)); - } - } - } else { - for (auto arg : oBB->getArguments()) { - if (gutils->hasInvertPointer(arg)) { - retargs.push_back(gutils->invertPointerM(arg, revBuilder)); - } - } - } - buildReturnOp(revBuilder, oBB->rbegin()->getLoc(), retargs); + buildReturnOp(revBuilder, oBB); } else { + Location loc = oBB->rbegin()->getLoc(); + // TODO remove dependency on CF dialect + + Value cache = gutils->insertInit(gutils->getIndexCacheType()); + + Value flag = + revBuilder.create(loc, gutils->getIndexType(), cache); + + Block *defaultBlock = nullptr; + SmallVector blocks; SmallVector indices; - SmallVector> arguments; - SmallVector defaultArguments; - Block *defaultBlock = nullptr; - for (auto pair : llvm::enumerate(oBB->getPredecessors())) { - auto predecessor = pair.value(); - auto idx = pair.index(); - Block *predecessorRevMode = - gutils->mapReverseModeBlocks.lookupOrNull(predecessor); - - SmallVector operands; - auto argumentsIt = gutils->mapBlockArguments.find(predecessor); - if (argumentsIt != gutils->mapBlockArguments.end()) { - for (auto operandOld : argumentsIt->second) { - if (oBB == operandOld.first.getParentBlock() && - gutils->hasInvertPointer(operandOld.first)) { - operands.push_back( - gutils->invertPointerM(operandOld.first, revBuilder)); - } else { - if (auto iface = operandOld.first.getType() - .dyn_cast()) { - Value nullValue = - iface.createNullValue(revBuilder, oBB->rbegin()->getLoc()); - operands.push_back(nullValue); - } else { - llvm_unreachable("no canonial null value found"); - } - } - } - } - if (idx != 0) { - blocks.push_back(predecessorRevMode); - indices.push_back(APInt(32, idx - 1)); - arguments.emplace_back(std::move(operands)); - } else { - defaultBlock = predecessorRevMode; - defaultArguments = operands; + + OpBuilder newBuilder(newBB, newBB->begin()); + + SmallVector diffes; + for (auto arg : oBB->getArguments()) { + if (!gutils->isConstantValue(arg) && + !cast(arg.getType()).isMutable()) { + diffes.push_back(gutils->diffe(arg, revBuilder)); + gutils->zeroDiffe(arg, revBuilder); + continue; } + diffes.push_back(nullptr); } - // Clear invert pointers of all arguments with gradient - for (auto argument : oBB->getArguments()) { - if (gutils->hasInvertPointer(argument)) { - auto iface = argument.getType().cast(); - Value nullValue = iface.createNullValue(revBuilder, argument.getLoc()); - gutils->mapInvertPointer(argument, nullValue, revBuilder); + for (auto [idx, pred] : llvm::enumerate(oBB->getPredecessors())) { + auto reversePred = gutils->mapReverseModeBlocks.lookupOrNull(pred); + + Block *newPred = gutils->getNewFromOriginal(pred); + + OpBuilder predecessorBuilder(newPred->getTerminator()); + + Value pred_idx_c = + predecessorBuilder.create(loc, idx - 1, 32); + predecessorBuilder.create(loc, cache, pred_idx_c); + + if (idx == 0) { + defaultBlock = reversePred; + + } else { + indices.push_back(APInt(32, idx - 1)); + blocks.push_back(reversePred); } - } - Location loc = oBB->rbegin()->getLoc(); - // Remove Dependency to CF dialect - if (std::next(oBB->getPredecessors().begin()) == - oBB->getPredecessors().end()) { - // If there is only one block we can directly create a branch for - // simplicity sake - revBuilder.create(loc, defaultBlock, defaultArguments); - } else { - Value cache = gutils->insertInit(gutils->getIndexCacheType()); - Value flag = - revBuilder.create(loc, gutils->getIndexType(), cache); - - SmallVector argumentRanges; - for (const auto &a : arguments) - argumentRanges.emplace_back(a); - revBuilder.create( - loc, flag, defaultBlock, defaultArguments, ArrayRef(indices), - ArrayRef(blocks), argumentRanges); - - Value origin = newBB->addArgument(gutils->getIndexType(), loc); - - OpBuilder newBuilder(newBB, newBB->begin()); - newBuilder.create(loc, cache, origin); - - int j = 0; - for (Block *predecessor : oBB->getPredecessors()) { - Block *newPredecessor = gutils->getNewFromOriginal(predecessor); - - OpBuilder predecessorBuilder(newPredecessor, - std::prev(newPredecessor->end())); - Value indicator = - predecessorBuilder.create(loc, j++, 32); - - Operation *terminator = newPredecessor->getTerminator(); - if (auto binst = dyn_cast(terminator)) { - for (unsigned i = 0; i < terminator->getNumSuccessors(); i++) { - if (terminator->getSuccessor(i) == newBB) { - SuccessorOperands sOps = binst.getSuccessorOperands(i); - sOps.append(indicator); + auto term = pred->getTerminator(); + if (auto iface = dyn_cast(term)) { + for (auto &op : term->getOpOperands()) + if (auto blk_idx = + iface.getSuccessorBlockArgument(op.getOperandNumber())) + if ((*blk_idx).getOwner() == oBB) { + auto idx = (*blk_idx).getArgNumber(); + if (diffes[idx]) { + + Value rev_idx_c = + revBuilder.create(loc, idx - 1, 32); + + auto to_prop = revBuilder.create( + loc, + revBuilder.create( + loc, arith::CmpIPredicate::eq, flag, rev_idx_c), + diffes[idx], + cast(diffes[idx].getType()) + .createNullValue(revBuilder, loc)); + + gutils->addToDiffe(op.get(), to_prop, revBuilder); + } } - } - } else { - llvm_unreachable("invalid terminator"); - } + } else { + assert(0 && "predecessor did not implement branch op interface"); } } - } -} -void MEnzymeLogic::initializeShadowValues( - SmallVector &dominatorToposortBlocks, - MGradientUtilsReverse *gutils) { - for (auto it = dominatorToposortBlocks.begin(); - it != dominatorToposortBlocks.end(); ++it) { - Block *oBB = *it; - - if (!oBB->empty()) { - for (auto it = oBB->begin(); it != oBB->end(); ++it) { - Operation *op = &*it; - Operation *newOp = gutils->getNewFromOriginal(op); - - if (auto ifaceOp = dyn_cast(op)) { - OpBuilder builder(newOp); - ifaceOp.createShadowValues(builder, gutils); - } - } - } + revBuilder.create( + loc, flag, defaultBlock, ArrayRef(), ArrayRef(indices), + ArrayRef(blocks), + SmallVector(indices.size(), ValueRange())); } } void MEnzymeLogic::differentiate( MGradientUtilsReverse *gutils, Region &oldRegion, Region &newRegion, - bool parentRegion, llvm::function_ref buildFuncReturnOp, std::function(Type)> cacheCreator) { gutils->registerCacheCreatorHook(cacheCreator); auto scope = llvm::make_scope_exit( [&]() { gutils->deregisterCacheCreatorHook(cacheCreator); }); - gutils->createReverseModeBlocks(oldRegion, newRegion, parentRegion); - - SmallVector dominatorToposortBlocks = - getDominatorToposort(gutils, oldRegion); - initializeShadowValues(dominatorToposortBlocks, gutils); - - for (auto it = dominatorToposortBlocks.rbegin(); - it != dominatorToposortBlocks.rend(); ++it) { - Block *oBB = *it; - Block *newBB = gutils->getNewFromOriginal(oBB); - Block *reverseBB = gutils->mapReverseModeBlocks.lookupOrNull(oBB); - mapInvertArguments(oBB, reverseBB, gutils); - handleReturns(oBB, newBB, reverseBB, gutils, parentRegion); - visitChildren(oBB, reverseBB, gutils); - handlePredecessors(oBB, newBB, reverseBB, gutils, buildFuncReturnOp, - parentRegion); + gutils->createReverseModeBlocks(oldRegion, newRegion); + + for (auto &oBB : oldRegion) { + Block *newBB = gutils->getNewFromOriginal(&oBB); + Block *reverseBB = gutils->mapReverseModeBlocks.lookupOrNull(&oBB); + handleReturns(&oBB, newBB, reverseBB, gutils); + visitChildren(&oBB, reverseBB, gutils); + handlePredecessors(&oBB, newBB, reverseBB, gutils, buildFuncReturnOp); } } @@ -394,13 +273,20 @@ FunctionOpInterface MEnzymeLogic::CreateReverseDiff( Region &oldRegion = gutils->oldFunc.getFunctionBody(); Region &newRegion = gutils->newFunc.getFunctionBody(); - auto buildFuncReturnOp = [](OpBuilder &builder, Location loc, - SmallVector retargs) { - builder.create(loc, retargs); + auto buildFuncReturnOp = [&](OpBuilder &builder, Block *oBB) { + SmallVector retargs; + for (auto [arg, cv] : llvm::zip(oBB->getArguments(), constants)) { + if (cv == DIFFE_TYPE::OUT_DIFF) { + retargs.push_back(gutils->diffe(arg, builder)); + } + } + builder.create(oBB->rbegin()->getLoc(), retargs); return; }; - differentiate(gutils, oldRegion, newRegion, true, buildFuncReturnOp, nullptr); + gutils->forceAugmentedReturns(); + + differentiate(gutils, oldRegion, newRegion, buildFuncReturnOp, nullptr); auto nf = gutils->newFunc; diff --git a/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp b/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp index 4598394b9e44..ebae44c9efa1 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp +++ b/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp @@ -41,7 +41,8 @@ mlir::enzyme::MGradientUtils::MGradientUtils( originalToNewFnOps(originalToNewFnOps_), blocksNotForAnalysis(), activityAnalyzer(std::make_unique( blocksNotForAnalysis, constantvalues_, activevals_, ReturnActivity)), - TA(TA_), TR(TR_), omp(omp), width(width), ArgDiffeTypes(ArgDiffeTypes_) { + TA(TA_), TR(TR_), omp(omp), width(width), ArgDiffeTypes(ArgDiffeTypes_), + RetDiffeTypes(1, ReturnActivity) { /* for (BasicBlock &BB : *oldFunc) { @@ -170,6 +171,55 @@ mlir::Value mlir::enzyme::MGradientUtils::invertPointerM(mlir::Value v, llvm_unreachable("could not invert pointer"); } +mlir::Value +mlir::enzyme::MDiffeGradientUtils::getDifferential(mlir::Value oval) { + auto found = differentials.lookupOrNull(oval); + if (found != nullptr) + return found; + + auto shadowty = getShadowType(oval.getType()); + OpBuilder builder(oval.getContext()); + builder.setInsertionPointToStart(initializationBlock); + + auto shadow = builder.create( + oval.getLoc(), enzyme::GradientType::get(oval.getContext(), shadowty)); + auto toset = cast(shadowty).createNullValue( + builder, oval.getLoc()); + builder.create(oval.getLoc(), shadow, toset); + + differentials.map(oval, shadow); + return shadow; +} + +void mlir::enzyme::MDiffeGradientUtils::setDiffe(mlir::Value oval, + mlir::Value toset, + OpBuilder &BuilderM) { + assert(!isConstantValue(oval)); + auto iface = oval.getType().cast(); + if (!iface.isMutable()) { + auto shadow = getDifferential(oval); + BuilderM.create(oval.getLoc(), shadow, toset); + } else { + MGradientUtils::setDiffe(oval, toset, BuilderM); + } +} + +void mlir::enzyme::MDiffeGradientUtils::zeroDiffe(mlir::Value oval, + OpBuilder &BuilderM) { + assert(!isConstantValue(oval)); + auto iface = getShadowType(oval.getType()).cast(); + assert(!iface.isMutable()); + setDiffe(oval, iface.createNullValue(BuilderM, oval.getLoc()), BuilderM); +} + +mlir::Value mlir::enzyme::MDiffeGradientUtils::diffe(mlir::Value oval, + OpBuilder &BuilderM) { + + auto shadow = getDifferential(oval); + return BuilderM.create(oval.getLoc(), + getShadowType(oval.getType()), shadow); +} + void mlir::enzyme::MGradientUtils::setDiffe(mlir::Value val, mlir::Value toset, OpBuilder &BuilderM) { /* @@ -226,81 +276,39 @@ void mlir::enzyme::MGradientUtils::forceAugmentedReturns() { if (isConstantValue(val)) continue; auto i = val.getArgNumber(); - mlir::Value dval; - if (i == blk->getArguments().size() - 1) - dval = nblk->addArgument(getShadowType(val.getType()), val.getLoc()); - else - dval = nblk->insertArgument(nblk->args_begin() + i + 1, - getShadowType(val.getType()), val.getLoc()); - - invertedPointers.map(val, dval); + if (mode == DerivativeMode::ForwardMode || + mode == DerivativeMode::ForwardModeSplit || + cast(val.getType()).isMutable()) { + mlir::Value dval; + if (i == blk->getArguments().size() - 1) + dval = nblk->addArgument(getShadowType(val.getType()), val.getLoc()); + else + dval = + nblk->insertArgument(nblk->args_begin() + i + 1, + getShadowType(val.getType()), val.getLoc()); + + invertedPointers.map(val, dval); + } } }); oldFunc.walk([&](Operation *inst) { if (inst == oldFunc) return; - if (mode == DerivativeMode::ForwardMode || - mode == DerivativeMode::ForwardModeSplit) { - OpBuilder BuilderZ(getNewFromOriginal(inst)); - for (auto res : inst->getResults()) { - if (!isConstantValue(res)) { - mlir::Type antiTy = getShadowType(res.getType()); - auto anti = - BuilderZ.create(res.getLoc(), antiTy); - invertedPointers.map(res, anti); - } - } - return; - } - /* - - if (inst->getType()->isFPOrFPVectorTy()) - continue; //! op->getType()->isPointerTy() && - //! !op->getType()->isIntegerTy()) { - - if (!TR.query(inst)[{-1}].isPossiblePointer()) - continue; - - if (isa(inst)) { - IRBuilder<> BuilderZ(inst); - getForwardBuilder(BuilderZ); - Type *antiTy = getShadowType(inst->getType()); - PHINode *anti = - BuilderZ.CreatePHI(antiTy, 1, inst->getName() + "'il_phi"); - invertedPointers.insert(std::make_pair( - (const Value *)inst, InvertedPointerVH(this, anti))); - continue; - } - - if (!isa(inst)) { - continue; - } - - if (isa(inst)) { - continue; - } - - if (isConstantValue(inst)) { - continue; - } - - CallInst *op = cast(inst); - Function *called = op->getCalledFunction(); - IRBuilder<> BuilderZ(inst); - getForwardBuilder(BuilderZ); - Type *antiTy = getShadowType(inst->getType()); - - PHINode *anti = - BuilderZ.CreatePHI(antiTy, 1, op->getName() + "'ip_phi"); - invertedPointers.insert( - std::make_pair((const Value *)inst, InvertedPointerVH(this, anti))); + OpBuilder BuilderZ(getNewFromOriginal(inst)); + for (auto res : inst->getResults()) { + if (isConstantValue(res)) + continue; - if (called && isAllocationFunction(called->getName(), TLI)) { - anti->setName(op->getName() + "'mi"); + if (!(mode == DerivativeMode::ForwardMode || + mode == DerivativeMode::ForwardModeSplit || + cast(res.getType()).isMutable())) + continue; + mlir::Type antiTy = getShadowType(res.getType()); + auto anti = BuilderZ.create(res.getLoc(), antiTy); + invertedPointers.map(res, anti); } - */ }); } diff --git a/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.h b/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.h index 7ef99a1014fd..32a7fe068d07 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.h +++ b/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.h @@ -38,6 +38,7 @@ class MGradientUtils { bool omp; unsigned width; + SmallVector RetDiffeTypes; ArrayRef ArgDiffeTypes; mlir::Value getNewFromOriginal(const mlir::Value originst) const; @@ -68,16 +69,35 @@ class MGradientUtils { bool isConstantInstruction(mlir::Operation *v) const; bool isConstantValue(mlir::Value v) const; mlir::Value invertPointerM(mlir::Value v, OpBuilder &Builder2); - void setDiffe(mlir::Value val, mlir::Value toset, OpBuilder &BuilderM); void forceAugmentedReturns(); Operation *cloneWithNewOperands(OpBuilder &B, Operation *op); LogicalResult visitChild(Operation *op); + + void setDiffe(mlir::Value origv, mlir::Value newv, mlir::OpBuilder &builder); + + mlir::Type getShadowType(mlir::Type T) { + auto iface = cast(T); + return iface.getShadowType(width); + } }; class MDiffeGradientUtils : public MGradientUtils { +protected: + IRMapping differentials; + + Block *initializationBlock; + public: + mlir::Value getDifferential(mlir::Value origv); + + void setDiffe(mlir::Value origv, mlir::Value newv, mlir::OpBuilder &builder); + + void zeroDiffe(mlir::Value origv, mlir::OpBuilder &builder); + + mlir::Value diffe(mlir::Value origv, mlir::OpBuilder &builder); + MDiffeGradientUtils(MEnzymeLogic &Logic, FunctionOpInterface newFunc_, FunctionOpInterface oldFunc_, MTypeAnalysis &TA, MTypeResults TR, IRMapping &invertedPointers_, @@ -91,7 +111,8 @@ class MDiffeGradientUtils : public MGradientUtils { : MGradientUtils(Logic, newFunc_, oldFunc_, TA, TR, invertedPointers_, constantvalues_, activevals_, ActiveReturn, constant_values, origToNew_, origToNewOps_, mode, width, - omp) {} + omp), + initializationBlock(&*(newFunc.getFunctionBody().begin())) {} // Technically diffe constructor static MDiffeGradientUtils * diff --git a/enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.cpp b/enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.cpp index 65d0349f1476..20d34247fb3a 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.cpp +++ b/enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.cpp @@ -40,16 +40,7 @@ mlir::enzyme::MGradientUtilsReverse::MGradientUtilsReverse( invertedPointers_, constantvalues_, activevals_, ReturnActivity, ArgDiffeTypes_, originalToNewFn_, originalToNewFnOps_, mode_, width, /*omp*/ false), - symbolTable(symbolTable_) { - - initInitializationBlock(invertedPointers_, ArgDiffeTypes_); -} - -// for(auto x : v.getUsers()){x->dump();} DEBUG - -bool MGradientUtilsReverse::onlyUsedInParentBlock(Value v) { - return !v.isUsedOutsideOfBlock(v.getParentBlock()); -} + symbolTable(symbolTable_) {} Type mlir::enzyme::MGradientUtilsReverse::getIndexCacheType() { Type indexType = getIndexType(); @@ -109,34 +100,6 @@ Value MGradientUtilsReverse::popCache(Value cache, OpBuilder &builder) { cache); } -// Gradient -Type mlir::enzyme::MGradientUtilsReverse::getGradientType(Value v) { - Type valueType = v.getType(); - return GradientType::get(v.getContext(), valueType); -} - -Value mlir::enzyme::MGradientUtilsReverse::insertInitGradient( - mlir::Value v, OpBuilder &builder) { - Type gradientType = getGradientType(v); - OpBuilder initBuilder(initializationBlock, initializationBlock->begin()); - Value gradient = initBuilder.create(v.getLoc(), gradientType); - return gradient; -} - -// Shadow Gradient -Type mlir::enzyme::MGradientUtilsReverse::getShadowedGradientType(Value v) { - Type valueType = v.getType(); - return ShadowedGradientType::get(v.getContext(), valueType); -} - -Value mlir::enzyme::MGradientUtilsReverse::insertInitShadowedGradient( - mlir::Value v, OpBuilder &builder) { - Type gradientType = getShadowedGradientType(v); - OpBuilder initBuilder(initializationBlock, initializationBlock->begin()); - Value gradient = initBuilder.create(v.getLoc(), gradientType); - return gradient; -} - Operation * mlir::enzyme::MGradientUtilsReverse::cloneWithNewOperands(OpBuilder &B, Operation *op) { @@ -146,205 +109,25 @@ mlir::enzyme::MGradientUtilsReverse::cloneWithNewOperands(OpBuilder &B, return B.clone(*op, map); } -bool mlir::enzyme::MGradientUtilsReverse::requiresShadow(Type t) { - if (auto iface = dyn_cast(t)) { - return iface.requiresShadow(); - } - return false; -} - void mlir::enzyme::MGradientUtilsReverse::addToDiffe(Value oldGradient, Value addedGradient, OpBuilder &builder) { - // TODO - Value gradient = addedGradient; - if (hasInvertPointer(oldGradient)) { - Value operandGradient = invertPointerM(oldGradient, builder); - auto iface = cast(addedGradient.getType()); - gradient = iface.createAddOp(builder, oldGradient.getLoc(), operandGradient, + assert(!isConstantValue(oldGradient)); + Value operandGradient = diffe(oldGradient, builder); + auto iface = cast(addedGradient.getType()); + auto added = iface.createAddOp(builder, oldGradient.getLoc(), operandGradient, addedGradient); - } - mapInvertPointer(oldGradient, gradient, builder); -} - -Value mlir::enzyme::MGradientUtilsReverse::diffe(Value v, OpBuilder &builder) { - return invertPointerM(v, builder); -} - -/* -The value v must have an invert pointer -*/ -Value mlir::enzyme::MGradientUtilsReverse::invertPointerM(Value v, - OpBuilder &builder) { - if (invertedPointersGlobal.contains(v)) { - Value gradient = invertedPointersGlobal.lookupOrNull(v); - Type type = gradient.getType(); - - if (GradientType gType = dyn_cast(type)) { - Value ret = builder.create(v.getLoc(), gType.getBasetype(), - gradient); - return ret; - } else { - llvm_unreachable("found invalid type"); - } - } else if (invertedPointersShadow.contains(v)) { - Value gradient = invertedPointersShadow.lookupOrNull(v); - Type type = gradient.getType(); - - if (ShadowedGradientType gType = - dyn_cast(type)) { - Value ret = builder.create(v.getLoc(), gType.getBasetype(), - gradient); - return ret; - } else { - llvm_unreachable("found invalid type"); - } - } - - llvm::errs() << " could not invert pointer v " << v << "\n"; - llvm_unreachable("could not invert pointer"); -} - -void mlir::enzyme::MGradientUtilsReverse::mapInvertPointer( - mlir::Value v, mlir::Value invertValue, OpBuilder &builder) { - if (!invertedPointersGlobal.contains(v)) { - Value g = insertInitGradient(v, builder); - invertedPointersGlobal.map(v, g); - } - Value gradient = invertedPointersGlobal.lookupOrNull(v); - builder.create(v.getLoc(), gradient, invertValue); -} - -Value mlir::enzyme::MGradientUtilsReverse::getShadowValue(mlir::Value v) { - return shadowValues.lookupOrNull(v); -} - -void mlir::enzyme::MGradientUtilsReverse::mapShadowValue(mlir::Value v, - mlir::Value shadow, - OpBuilder &builder) { - assert(!invertedPointersShadow.contains( - v)); // Shadow Values must only be mapped exactly once - - Value cache = insertInitShadowedGradient(v, builder); - invertedPointersShadow.map(v, cache); - - builder.create(v.getLoc(), cache, shadow); - - shadowValues.map(v, shadow); -} - -void mlir::enzyme::MGradientUtilsReverse::clearValue(mlir::Value v, - OpBuilder &builder) { - if (invertedPointersGlobal.contains(v)) { - if (!onlyUsedInParentBlock(v)) { // TODO is this necessary? - Value gradient = invertedPointersGlobal.lookupOrNull(v); - Type type = cast(gradient.getType()).getBasetype(); - if (auto iface = dyn_cast(type)) { - Value zero = iface.createNullValue(builder, v.getLoc()); - builder.create(v.getLoc(), gradient, zero); - } else { - llvm_unreachable( - "Type does not have an associated AutoDiffTypeInterface"); - } - } - } else if (invertedPointersShadow.contains(v)) { - Value gradient = invertedPointersShadow.lookupOrNull(v); - builder.create(v.getLoc(), gradient); - } -} - -bool mlir::enzyme::MGradientUtilsReverse::hasInvertPointer(mlir::Value v) { - return (invertedPointersGlobal.contains(v)) || - (invertedPointersShadow.contains(v)); -} - -void MGradientUtilsReverse::initInitializationBlock( - IRMapping invertedPointers_, ArrayRef argDiffeTypes) { - initializationBlock = &*(this->newFunc.getFunctionBody().begin()); - - OpBuilder initializationBuilder( - &*(this->newFunc.getFunctionBody().begin()), - this->newFunc.getFunctionBody().begin()->begin()); - - for (const auto &[val, diffe_type] : llvm::zip( - this->oldFunc.getFunctionBody().getArguments(), argDiffeTypes)) { - if (diffe_type != DIFFE_TYPE::OUT_DIFF) { - continue; - } - auto iface = dyn_cast(val.getType()); - if (!iface) { - llvm_unreachable( - "Type does not have an associated AutoDiffTypeInterface"); - } - Value zero = iface.createNullValue(initializationBuilder, val.getLoc()); - mapInvertPointer(val, zero, initializationBuilder); - } - for (auto const &x : invertedPointers_.getValueMap()) { - if (auto iface = dyn_cast(x.first.getType())) { - if (iface.requiresShadow()) { - mapShadowValue(x.first, x.second, - initializationBuilder); // This may create an unnecessary - // ShadowedGradient which could - // be avoidable TODO - } else { - mapInvertPointer(x.first, x.second, initializationBuilder); - } - } else { - llvm_unreachable("TODO not implemented"); - } - } + setDiffe(oldGradient, added, builder); } void MGradientUtilsReverse::createReverseModeBlocks(Region &oldFunc, - Region &newFunc, - bool isParentRegion) { + Region &newFunc) { for (auto it = oldFunc.getBlocks().rbegin(); it != oldFunc.getBlocks().rend(); ++it) { Block *block = &*it; Block *reverseBlock = new Block(); - - SmallVector> - reverseModeArguments; // Argument, Assigned value (2. is technically not - // necessary but simplifies code a lot) - - // Add reverse mode Arguments to Block - Operation *term = block->getTerminator(); - mlir::BranchOpInterface brOp = dyn_cast(term); - bool returnLike = term->hasTrait(); - if (brOp) { - for (int i = 0; i < (int)term->getNumSuccessors(); i++) { - SuccessorOperands sOps = brOp.getSuccessorOperands(i); - Block *successorBlock = term->getSuccessor(i); - - assert(successorBlock->getNumArguments() == sOps.size()); - for (int j = 0; j < (int)sOps.size(); j++) { - // Check if the argument needs a gradient - if (auto iface = successorBlock->getArgument(j) - .getType() - .dyn_cast()) { - reverseModeArguments.push_back(std::pair( - successorBlock->getArgument(j), sOps[j])); - } - } - } - for (auto it : reverseModeArguments) { - reverseBlock->addArgument(it.second.getType(), it.second.getLoc()); - } - - mapBlockArguments[block] = reverseModeArguments; - } else if (returnLike) { - if (!isParentRegion) { - for (OpOperand &operand : term->getOpOperands()) { - Value val = operand.get(); - if (auto iface = val.getType().dyn_cast()) { - reverseBlock->addArgument(val.getType(), val.getLoc()); - } - } - } - } - - mapReverseModeBlocks.map(block, reverseBlock); newFunc.getBlocks().insert(newFunc.end(), reverseBlock); + mapReverseModeBlocks.map(block, reverseBlock); } } diff --git a/enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.h b/enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.h index 3d5d4d147269..d3b2e818391f 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.h +++ b/enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.h @@ -34,32 +34,13 @@ class MGradientUtilsReverse : public MDiffeGradientUtils { DerivativeMode mode_, unsigned width, SymbolTableCollection &symbolTable_); - IRMapping invertedPointersGlobal; - IRMapping invertedPointersShadow; - IRMapping shadowValues; - Block *initializationBlock; - IRMapping mapReverseModeBlocks; - DenseMap>> mapBlockArguments; SymbolTableCollection &symbolTable; - bool hasInvertPointer(mlir::Value v); - mlir::Value invertPointerM(mlir::Value v, OpBuilder &builder); - mlir::Value diffe(mlir::Value v, OpBuilder &builder); - void addToDiffe(mlir::Value oldGradient, mlir::Value addedGradient, OpBuilder &builder); - void mapInvertPointer(mlir::Value v, mlir::Value invertValue, - OpBuilder &builder); - - mlir::Value getShadowValue(mlir::Value v); - void mapShadowValue(mlir::Value v, mlir::Value invertValue, - OpBuilder &builder); - void clearValue(mlir::Value v, OpBuilder &builder); - - void setDiffe(mlir::Value val, mlir::Value toset, OpBuilder &BuilderM); Type getIndexType(); Value insertInit(Type t); @@ -75,27 +56,11 @@ class MGradientUtilsReverse : public MDiffeGradientUtils { Type getIndexCacheType(); Value initAndPushCache(Value v, OpBuilder &builder); - // Gradient - Type getGradientType(Value t); - Value insertInitGradient(mlir::Value v, OpBuilder &builder); - - // ShadowedGradient - Type getShadowedGradientType(Value t); - Value insertInitShadowedGradient(mlir::Value v, OpBuilder &builder); - - bool requiresShadow(Type t); - - void initInitializationBlock(IRMapping invertedPointers_, - ArrayRef argDiffeTypes); - - bool onlyUsedInParentBlock(Value v); - Operation *cloneWithNewOperands(OpBuilder &B, Operation *op); Value popCache(Value cache, OpBuilder &builder); - void createReverseModeBlocks(Region &oldFunc, Region &newFunc, - bool isParentRegion = false); + void createReverseModeBlocks(Region &oldFunc, Region &newFunc); static MGradientUtilsReverse * CreateFromClone(MEnzymeLogic &Logic, DerivativeMode mode_, unsigned width, diff --git a/enzyme/Enzyme/MLIR/Passes/CMakeLists.txt b/enzyme/Enzyme/MLIR/Passes/CMakeLists.txt index 8a18dc33d195..90e9506d2bb1 100644 --- a/enzyme/Enzyme/MLIR/Passes/CMakeLists.txt +++ b/enzyme/Enzyme/MLIR/Passes/CMakeLists.txt @@ -9,7 +9,7 @@ add_mlir_dialect_library(MLIREnzymeTransforms PrintActivityAnalysis.cpp PrintAliasAnalysis.cpp EnzymeToMemRef.cpp - ShadowedGradientToCache.cpp + SimplifyMath.cpp AddToOpToIndexAndLoad.cpp AddToOpToSplit.cpp RemoveUnusedEnzymeOps.cpp diff --git a/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp b/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp index e43db26b21a2..de7328dcb186 100644 --- a/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp +++ b/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp @@ -31,6 +31,17 @@ struct DifferentiatePass : public DifferentiatePassBase { void runOnOperation() override; + static DIFFE_TYPE mode_from_fn(FunctionOpInterface fn, DerivativeMode mode) { + DIFFE_TYPE retType = DIFFE_TYPE::CONSTANT; + if (fn.getNumResults() != 0) { + if (mode == DerivativeMode::ReverseModeCombined) + retType = DIFFE_TYPE::OUT_DIFF; + else + retType = DIFFE_TYPE::DUP_ARG; + } + return retType; + } + template LogicalResult HandleAutoDiff(SymbolTableCollection &symbolTable, T CI) { std::vector constants; @@ -60,12 +71,11 @@ struct DifferentiatePass : public DifferentiatePassBase { auto *symbolOp = symbolTable.lookupNearestSymbolFrom(CI, CI.getFnAttr()); auto fn = cast(symbolOp); - DIFFE_TYPE retType = - fn.getNumResults() == 0 ? DIFFE_TYPE::CONSTANT : DIFFE_TYPE::DUP_ARG; + auto mode = DerivativeMode::ForwardMode; + DIFFE_TYPE retType = mode_from_fn(fn, mode); MTypeAnalysis TA; auto type_args = TA.getAnalyzedTypeInfo(fn); - auto mode = DerivativeMode::ForwardMode; bool freeMemory = true; size_t width = 1; @@ -118,19 +128,18 @@ struct DifferentiatePass : public DifferentiatePassBase { truei++; } - // Add the return gradient - mlir::Value res = CI.getInputs()[CI.getInputs().size() - 1]; - args.push_back(res); - auto *symbolOp = symbolTable.lookupNearestSymbolFrom(CI, CI.getFnAttr()); auto fn = cast(symbolOp); - DIFFE_TYPE retType = - fn.getNumResults() == 0 ? DIFFE_TYPE::CONSTANT : DIFFE_TYPE::DUP_ARG; + auto mode = DerivativeMode::ReverseModeCombined; + DIFFE_TYPE retType = mode_from_fn(fn, mode); + + // Add the return gradient + mlir::Value res = CI.getInputs()[CI.getInputs().size() - 1]; + args.push_back(res); MTypeAnalysis TA; auto type_args = TA.getAnalyzedTypeInfo(fn); - auto mode = DerivativeMode::ReverseModeGradient; bool freeMemory = true; size_t width = 1; diff --git a/enzyme/Enzyme/MLIR/Passes/Passes.h b/enzyme/Enzyme/MLIR/Passes/Passes.h index 80c88373090c..c4eff488c332 100644 --- a/enzyme/Enzyme/MLIR/Passes/Passes.h +++ b/enzyme/Enzyme/MLIR/Passes/Passes.h @@ -30,7 +30,7 @@ std::unique_ptr createPrintAliasAnalysisPass(); std::unique_ptr createEnzymeToMemRefPass(); -std::unique_ptr createShadowedGradientToCachePass(); +std::unique_ptr createMathematicSimplificationPass(); std::unique_ptr createAddToOpToIndexAndLoadPass(); diff --git a/enzyme/Enzyme/MLIR/Passes/Passes.td b/enzyme/Enzyme/MLIR/Passes/Passes.td index d2a04fc37412..c1617eaf4af0 100644 --- a/enzyme/Enzyme/MLIR/Passes/Passes.td +++ b/enzyme/Enzyme/MLIR/Passes/Passes.td @@ -141,9 +141,9 @@ def EnzymeOpsToMemRefPass : Pass<"convert-enzyme-to-memref"> { let constructor = "mlir::enzyme::createEnzymeToMemRefPass()"; } -def ShadowedGradientToCachePass : Pass<"convert-enzyme-shadowed-gradient-to-cache"> { - let summary = "Convert Enzyme Shadowed Gradient to Cache Ops"; - let constructor = "mlir::enzyme::createShadowedGradientToCachePass()"; +def MathematicSimplificationPass : Pass<"enzyme-simplify-math"> { + let summary = "Simplify basic mathematical operations"; + let constructor = "mlir::enzyme::createMathematicSimplificationPass()"; } def AddToOpToIndexAndLoadPass : Pass<"add-to-op-to-index-and-load"> { diff --git a/enzyme/Enzyme/MLIR/Passes/RemoveUnusedEnzymeOps.cpp b/enzyme/Enzyme/MLIR/Passes/RemoveUnusedEnzymeOps.cpp index e93334164dd0..5ced93e86512 100644 --- a/enzyme/Enzyme/MLIR/Passes/RemoveUnusedEnzymeOps.cpp +++ b/enzyme/Enzyme/MLIR/Passes/RemoveUnusedEnzymeOps.cpp @@ -22,168 +22,284 @@ #include "mlir/Transforms/DialectConversion.h" #include "mlir/Rewrite/PatternApplicator.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/IR/Dominance.h" #include "llvm/Support/raw_ostream.h" using namespace mlir; using namespace enzyme; -using llvm::errs; namespace { -// TODO: Expand to region branches?? -bool reachable(Operation *a, Operation *b) { - Block *aBlock = a->getBlock(); - Block *bBlock = b->getBlock(); - if (aBlock == bBlock) { - if (a->isBeforeInBlock(b)) { - return true; - } - } +// Starting at the beginning of blk, is there a path that can execute +// check before end. +bool mayExecuteBefore(Block *blk, Operation *check, Operation *end) { + auto reg = blk->getParent(); + assert(reg->isAncestor(end->getParentRegion())); + DenseSet visitedBlocks; + SmallVector blocksToVisit; + for (auto succ : blk->getSuccessors()) { + blocksToVisit.push_back(succ); + } - blocksToVisit.push_back(aBlock); while (!blocksToVisit.empty()) { - Block *processedBlock = blocksToVisit[blocksToVisit.size() - 1]; - blocksToVisit.pop_back(); + Block *cur = blocksToVisit.pop_back_val(); + + if (visitedBlocks.contains(cur)) + continue; + + visitedBlocks.insert(cur); + + bool seenEnd = false; + for (auto &op : *cur) { + + // If we've seen the thing to check with, it may execute before + if (op.isAncestor(check)) { + // The sole exception to this is if they are in the same sub region, + // which is known to execute only once. TODO this later + /* + if (op.isAncestor(end)) { + + for (auto reg2 : op.getRegions()) { + + } + } + */ - for (Block *successor : processedBlock->getSuccessors()) { - if (!visitedBlocks.contains(successor)) { - visitedBlocks.insert(successor); - blocksToVisit.push_back(successor); + return true; + } - if (successor == bBlock) - return true; + // Otherwise if we've seen the end op, this path is over as the route we + // found here didn't first find a check. + if (op.isAncestor(end)) { + seenEnd = true; + break; } } + + if (seenEnd) + continue; + + // If we didn't find the end, try all successors + for (auto succ : cur->getSuccessors()) { + blocksToVisit.push_back(succ); + } } + return false; } -template -Operation *findNearestDominatingOpByUse(Operation *op, Value v) { +bool mayExecuteBetween(Operation *start, Operation *check, Operation *end) { + + for (auto op = start->getNextNode(); op != nullptr; op = op->getNextNode()) { + // This check op has been found after start in its block + if (op->isAncestor(check)) { + return true; + } + if (op->isAncestor(end)) { + return false; + } + } + + Block *blk = start->getBlock(); + + auto reg = blk->getParent(); + if (reg->isAncestor(end->getParentRegion())) { + return mayExecuteBefore(blk, check, end); + } + + // If the check is in the parent op, but the end is not, assume + // we may execute that parent op part before going to any later ops + if (reg->isAncestor(check->getParentRegion())) { + return true; + } + + return mayExecuteBetween(start->getParentOp(), check, end); +} + +// TODO this isn't necessarily correct. This is because there could be a +// non dominating use bewteen the dominating one and the op, causing +// correctness issues when not seen. In interim, be conservative and only +// succeed if these have the same parent block, and no other ops in path +template +T findNearestDominatingOpByUse(Operation *op, Value v) { DominanceInfo dInfo; + PostDominanceInfo pdInfo; - Operation *closestSetOp = nullptr; + SmallVector options; + SmallVector conflicts; for (Operation *userSet : v.getUsers()) { if (auto setOp = dyn_cast(userSet)) { - if (dInfo.dominates(userSet, op)) { - if (closestSetOp == nullptr) { - closestSetOp = userSet; - } else if (dInfo.dominates(closestSetOp, userSet)) { - closestSetOp = userSet; - } + options.push_back(setOp); + conflicts.push_back(setOp); + continue; + } + if (auto setOp = dyn_cast(userSet)) { + conflicts.push_back(setOp); + continue; + } + } + + for (auto opt : options) { + if (!dInfo.dominates(opt, op)) + continue; + bool conflict = false; + for (auto opt2 : conflicts) { + if (opt == opt2) + continue; + if (opt2 == op) + continue; + + if (!mayExecuteBetween(opt, opt2, op)) { + continue; } + + conflict = true; + } + if (!conflict) { + return opt; } } - return closestSetOp; + + return nullptr; } +struct PopSimplify : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(enzyme::PopOp pop, + PatternRewriter &rewriter) const final { + + auto init = pop.getCache().getDefiningOp(); + if (!init) + return failure(); + + SmallVector pops; + SmallVector pushes; + for (Operation *userSet : init.getResult().getUsers()) { + if (auto push = dyn_cast(userSet)) { + pushes.push_back(push); + continue; + } + if (auto pop = dyn_cast(userSet)) { + pops.push_back(pop); + continue; + } + return failure(); + } + + if (auto push = findNearestDominatingOpByUse( + pop, init)) { + // Do the block check to conservatively avoid multi execute push/pop + if (pop->getBlock() == push->getBlock()) { + rewriter.replaceOp(pop, push.getValue()); + rewriter.eraseOp(push); + return success(); + } + } + + return failure(); + } +}; + +struct GetSimplify : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(enzyme::GetOp get, + PatternRewriter &rewriter) const final { + + auto init = get.getGradient().getDefiningOp(); + if (!init) + return failure(); + + for (Operation *userSet : init.getResult().getUsers()) { + if (isa(userSet)) + continue; + if (isa(userSet)) + continue; + return failure(); + } + + if (auto set = findNearestDominatingOpByUse(get, init)) { + rewriter.replaceOp(get, set.getValue()); + return success(); + } + return failure(); + } +}; + +struct SetSimplify : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(enzyme::SetOp get, + PatternRewriter &rewriter) const final { + + auto init = get.getGradient().getDefiningOp(); + if (!init) + return failure(); + + for (Operation *userSet : init.getResult().getUsers()) { + if (isa(userSet)) + continue; + return failure(); + } + + rewriter.eraseOp(get); + return success(); + } +}; + +struct PushSimplify : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(enzyme::PushOp get, + PatternRewriter &rewriter) const final { + + auto init = get.getCache().getDefiningOp(); + if (!init) + return failure(); + + for (Operation *userSet : init.getResult().getUsers()) { + if (isa(userSet)) + continue; + return failure(); + } + + rewriter.eraseOp(get); + return success(); + } +}; + +struct InitSimplify : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(enzyme::InitOp get, + PatternRewriter &rewriter) const final { + + if (get.use_empty()) { + rewriter.eraseOp(get); + return success(); + } + return failure(); + } +}; + struct RemoveUnusedEnzymeOpsPass : public enzyme::RemoveUnusedEnzymeOpsPassBase { void runOnOperation() override { - getOperation()->walk([&](Operation *op) { - DominanceInfo dInfo; - if (auto initOp = dyn_cast(op)) { - Value v = initOp; - if (auto type = dyn_cast(initOp.getType())) { - bool replaceable = true; - for (Operation *userSet : v.getUsers()) { - if (auto setOp = dyn_cast(userSet)) { - for (Operation *userGet : v.getUsers()) { - if (auto getOp = dyn_cast(userGet)) { - // We can safely delete an enzyme.gradient op if each pair of - // enzyme.set and enzyme.get ops are either not reachable or - // are reachable and do not exist inside a loop - bool relatedButNotInLoop = - dInfo.dominates(userSet, userGet) && - !reachable(getOp, setOp); - bool unrelated = !reachable(setOp, getOp); - if (!(relatedButNotInLoop || unrelated)) { - replaceable = false; - } - } - } - } - } - if (replaceable) { - // Do replacing - for (Operation *userGet : v.getUsers()) { - if (auto getOp = dyn_cast(userGet)) { - Operation *closestSetOp = - findNearestDominatingOpByUse(userGet, v); - auto setOp = dyn_cast(closestSetOp); - getOp.replaceAllUsesWith(setOp.getValue()); - } - } - for (Operation *userGet : v.getUsers()) { - userGet->erase(); - } - op->erase(); - } - } else if (auto type = dyn_cast(initOp.getType())) { - bool replaceable = true; - for (Operation *userPush : v.getUsers()) { - if (auto pushOp = dyn_cast(userPush)) { - // There should only be exactly one push per pop - if (reachable(userPush, userPush)) { - replaceable = false; - } - int numAssociatedPops = 0; - for (Operation *user : v.getUsers()) { - if (auto popOp = dyn_cast(user)) { - if (reachable(userPush, user)) { - // Pops always need to be dominated by the push - if (dInfo.dominates(userPush, user)) { - numAssociatedPops++; - } else { - replaceable = false; - } - } - } - if (auto getOp = dyn_cast(user)) { - if (reachable(userPush, user)) { - // Gets always need to be dominated by the push - if (!dInfo.dominates(userPush, user)) { - replaceable = false; - } - } - } - } - // There should only be one pop per push - if (numAssociatedPops > 1) { - replaceable = false; - } - } - } - if (replaceable) { - // Do replacing - for (Operation *user : v.getUsers()) { - if (auto popOp = dyn_cast(user)) { - Operation *closestPushOp = - findNearestDominatingOpByUse(user, v); - auto pushOp = dyn_cast(closestPushOp); - popOp.replaceAllUsesWith(pushOp.getValue()); - } - if (auto getOp = dyn_cast(user)) { - Operation *closestPushOp = - findNearestDominatingOpByUse(user, v); - auto pushOp = dyn_cast(closestPushOp); - getOp.replaceAllUsesWith(pushOp.getValue()); - } - } - for (Operation *user : v.getUsers()) { - user->erase(); - } - op->erase(); - } - } - } - }); - }; + RewritePatternSet patterns(&getContext()); + patterns.insert(&getContext()); + + GreedyRewriteConfig config; + (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), + config); + } }; + } // end anonymous namespace namespace mlir { diff --git a/enzyme/Enzyme/MLIR/Passes/ShadowedGradientToCache.cpp b/enzyme/Enzyme/MLIR/Passes/ShadowedGradientToCache.cpp deleted file mode 100644 index d4e374d6c68f..000000000000 --- a/enzyme/Enzyme/MLIR/Passes/ShadowedGradientToCache.cpp +++ /dev/null @@ -1,64 +0,0 @@ -//===- ShadowedGradientToCache.cpp - Lower Shadowed Gradient ops -//------------------ // -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file implements a pass to lower custom ops generated by the Enzyme AD -// procedure to the MemRef dialect. -//===----------------------------------------------------------------------===// - -#include "Dialect/Ops.h" -#include "PassDetails.h" -#include "Passes/Passes.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Transforms/DialectConversion.h" - -#include "llvm/Support/raw_ostream.h" - -using namespace mlir; -using namespace enzyme; -using llvm::errs; -namespace { -struct ShadowedGradientToCachePass - : public enzyme::ShadowedGradientToCachePassBase< - ShadowedGradientToCachePass> { - void runOnOperation() override { - getOperation()->walk([&](Operation *op) { - if (auto initOp = dyn_cast(op)) { - if (auto type = - dyn_cast(initOp.getType())) { - Type cacheType = CacheType::get(op->getContext(), type.getBasetype()); - - OpBuilder builder(op); - Value buffer = builder.create(op->getLoc(), cacheType); - - initOp.replaceAllUsesWith(buffer); - initOp->erase(); - } - } - if (auto clearOp = dyn_cast(op)) { - if (auto type = - dyn_cast(clearOp.getCache().getType())) { - OpBuilder builder(op); - builder.create(op->getLoc(), type.getType(), - clearOp.getCache()); - - clearOp->erase(); - } - } - }); - }; -}; -} // end anonymous namespace - -namespace mlir { -namespace enzyme { -std::unique_ptr createShadowedGradientToCachePass() { - return std::make_unique(); -} -} // namespace enzyme -} // namespace mlir diff --git a/enzyme/Enzyme/MLIR/Passes/SimplifyMath.cpp b/enzyme/Enzyme/MLIR/Passes/SimplifyMath.cpp new file mode 100644 index 000000000000..de04163fea8f --- /dev/null +++ b/enzyme/Enzyme/MLIR/Passes/SimplifyMath.cpp @@ -0,0 +1,88 @@ +//===- SimpliyMath.cpp - Simplify Mathematical operations ------------------ // +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a pass to lower custom ops generated by the Enzyme AD +// procedure to the MemRef dialect. +//===----------------------------------------------------------------------===// + +#include "Dialect/Ops.h" +#include "PassDetails.h" +#include "Passes/Passes.h" +#include "mlir/Dialect/Arith/IR/Arith.h" + +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +using namespace mlir; +using namespace enzyme; +using llvm::errs; +namespace { + +struct AddSimplify : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(arith::AddFOp op, + PatternRewriter &rewriter) const final { + + if (matchPattern(op.getLhs(), m_AnyZeroFloat())) { + rewriter.replaceOp(op, op.getRhs()); + return success(); + } + + if (matchPattern(op.getRhs(), m_AnyZeroFloat())) { + rewriter.replaceOp(op, op.getLhs()); + return success(); + } + + return failure(); + } +}; + +struct SubSimplify : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(arith::SubFOp op, + PatternRewriter &rewriter) const final { + + if (matchPattern(op.getRhs(), m_AnyZeroFloat())) { + rewriter.replaceOp(op, op.getLhs()); + return success(); + } + + if (matchPattern(op.getLhs(), m_AnyZeroFloat())) { + rewriter.replaceOpWithNewOp(op, op.getRhs()); + return success(); + } + + return failure(); + } +}; + +struct MathematicSimplification + : public enzyme::MathematicSimplificationPassBase< + MathematicSimplification> { + void runOnOperation() override { + + RewritePatternSet patterns(&getContext()); + patterns.insert(&getContext()); + + GreedyRewriteConfig config; + (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), + config); + }; +}; +} // end anonymous namespace + +namespace mlir { +namespace enzyme { +std::unique_ptr createMathematicSimplificationPass() { + return std::make_unique(); +} +} // namespace enzyme +} // namespace mlir diff --git a/enzyme/test/MLIR/Passes/dualpush.mlir b/enzyme/test/MLIR/Passes/dualpush.mlir new file mode 100644 index 000000000000..582feddeabff --- /dev/null +++ b/enzyme/test/MLIR/Passes/dualpush.mlir @@ -0,0 +1,48 @@ +// RUN: %eopt -remove-unnecessary-enzyme-ops %s | FileCheck %s + +// This pop cannot be removed even though we know the first popped value with be -1 +// the other pops will be conditional + +module { + func.func private @diffebbargs(%arg0: f64) { + %c0_i32 = arith.constant 0 : i32 + %c-1_i32 = arith.constant -1 : i32 + %cst = arith.constant 0.000000e+00 : f64 + %3 = "enzyme.init"() : () -> !enzyme.Cache + "enzyme.push"(%3, %c0_i32) : (!enzyme.Cache, i32) -> () + cf.br ^bb1(%arg0 : f64) + ^bb1(%7: f64): // 2 preds: ^bb0, ^bb1 + %8 = arith.cmpf ult, %7, %cst : f64 + "enzyme.push"(%3, %c-1_i32) : (!enzyme.Cache, i32) -> () + cf.cond_br %8, ^bb1(%7 : f64), ^bb4 + ^bb4: // 2 preds: ^bb3, ^bb4 + %18 = "enzyme.pop"(%3) : (!enzyme.Cache) -> i32 + cf.switch %18 : i32, [ + default: ^bb4, + 0: ^bb5 + ] + ^bb5: // pred: ^bb4 + return + } +} + +// CHECK: func.func private @diffebbargs(%arg0: f64) { +// CHECK-NEXT: %c0_i32 = arith.constant 0 : i32 +// CHECK-NEXT: %c-1_i32 = arith.constant -1 : i32 +// CHECK-NEXT: %cst = arith.constant 0.000000e+00 : f64 +// CHECK-NEXT: %0 = "enzyme.init"() : () -> !enzyme.Cache +// CHECK-NEXT: "enzyme.push"(%0, %c0_i32) : (!enzyme.Cache, i32) -> () +// CHECK-NEXT: cf.br ^bb1(%arg0 : f64) +// CHECK-NEXT: ^bb1(%1: f64): // 2 preds: ^bb0, ^bb1 +// CHECK-NEXT: %2 = arith.cmpf ult, %1, %cst : f64 +// CHECK-NEXT: "enzyme.push"(%0, %c-1_i32) : (!enzyme.Cache, i32) -> () +// CHECK-NEXT: cf.cond_br %2, ^bb1(%1 : f64), ^bb2 +// CHECK-NEXT: ^bb2: // 2 preds: ^bb1, ^bb2 +// CHECK-NEXT: %3 = "enzyme.pop"(%0) : (!enzyme.Cache) -> i32 +// CHECK-NEXT: cf.switch %3 : i32, [ +// CHECK-NEXT: default: ^bb2, +// CHECK-NEXT: 0: ^bb3 +// CHECK-NEXT: ] +// CHECK-NEXT: ^bb3: // pred: ^bb2 +// CHECK-NEXT: return +// CHECK-NEXT: } \ No newline at end of file diff --git a/enzyme/test/MLIR/ReverseMode/bbarg-order.mlir b/enzyme/test/MLIR/ReverseMode/bbarg-order.mlir index e5bb39eea040..141ff46aaade 100644 --- a/enzyme/test/MLIR/ReverseMode/bbarg-order.mlir +++ b/enzyme/test/MLIR/ReverseMode/bbarg-order.mlir @@ -1,4 +1,4 @@ -// RUN: %eopt --enzyme %s | FileCheck %s +// RUN: %eopt --enzyme -canonicalize --remove-unnecessary-enzyme-ops -canonicalize %s | FileCheck %s module { func.func @bbargs(%x: f64) -> f64 { @@ -19,15 +19,38 @@ module { } } -// CHECK: func.func @diff(%[[arg0:.+]]: f64, %[[arg1:.+]]: f64) -> f64 { -// CHECK-NEXT: %[[i0:.+]] = call @diffebbargs(%[[arg0]], %[[arg1]]) : (f64, f64) -> f64 -// CHECK-NEXT: return %[[i0:.+]] -// CHECK: func.func private @diffebbargs(%[[arg0:.+]]: f64, %[[arg1:.+]]: f64) -> f64 { - -// There should be exactly one block with two f64 args, and their values should be accumulated -// in the shadow. -// CHECK: ^[[BBMULTI:.+]](%[[fst:.+]]: f64, %[[snd:.+]]: f64): -// CHECK-NEXT: "enzyme.set"(%[[shadow:.+]], %[[fst]]) -// CHECK-NEXT: %[[before:.+]] = "enzyme.get"(%[[shadow]]) -// CHECK-NEXT: %[[after:.+]] = arith.addf %[[snd]], %[[before]] -// CHECK-NEXT: "enzyme.set"(%[[shadow]], %[[after]]) +// CHECK: func.func private @diffebbargs(%arg0: f64, %arg1: f64) -> f64 { +// CHECK-NEXT: %c-1_i32 = arith.constant -1 : i32 +// CHECK-NEXT: %c0_i32 = arith.constant 0 : i32 +// CHECK-NEXT: %cst = arith.constant 1.000000e+00 : f64 +// CHECK-NEXT: %cst_0 = arith.constant 0.000000e+00 : f64 +// CHECK-NEXT: %0 = "enzyme.init"() : () -> !enzyme.Cache +// CHECK-NEXT: %1 = "enzyme.init"() : () -> !enzyme.Cache +// CHECK-NEXT: %2 = arith.addf %arg0, %cst : f64 +// CHECK-NEXT: "enzyme.push"(%1, %c0_i32) : (!enzyme.Cache, i32) -> () +// CHECK-NEXT: cf.br ^bb1(%2 : f64) +// CHECK-NEXT: ^bb1(%3: f64): // 2 preds: ^bb0, ^bb1 +// CHECK-NEXT: %4 = arith.cmpf ult, %3, %cst_0 : f64 +// CHECK-NEXT: "enzyme.push"(%1, %c-1_i32) : (!enzyme.Cache, i32) -> () +// CHECK-NEXT: "enzyme.push"(%0, %c-1_i32) : (!enzyme.Cache, i32) -> () +// CHECK-NEXT: cf.cond_br %4, ^bb1(%3 : f64), ^bb2 +// CHECK-NEXT: ^bb2: // pred: ^bb1 +// CHECK-NEXT: %5 = arith.addf %arg1, %cst_0 : f64 +// CHECK-NEXT: %6 = "enzyme.pop"(%0) : (!enzyme.Cache) -> i32 +// CHECK-NEXT: %7 = arith.cmpi eq, %6, %c-1_i32 : i32 +// CHECK-NEXT: %8 = arith.select %7, %5, %cst_0 : f64 +// CHECK-NEXT: %9 = arith.addf %8, %cst_0 : f64 +// CHECK-NEXT: cf.br ^bb3 +// CHECK-NEXT: ^bb3: // 2 preds: ^bb2, ^bb3 +// CHECK-NEXT: %10 = "enzyme.pop"(%1) : (!enzyme.Cache) -> i32 +// CHECK-NEXT: %11 = arith.cmpi eq, %10, %c-1_i32 : i32 +// CHECK-NEXT: %12 = arith.select %11, %9, %cst_0 : f64 +// CHECK-NEXT: %13 = arith.addf %12, %cst_0 : f64 +// CHECK-NEXT: cf.switch %10 : i32, [ +// CHECK-NEXT: default: ^bb3, +// CHECK-NEXT: 0: ^bb4 +// CHECK-NEXT: ] +// CHECK-NEXT: ^bb4: // pred: ^bb3 +// CHECK-NEXT: %14 = arith.addf %13, %cst_0 : f64 +// CHECK-NEXT: return %14 : f64 +// CHECK-NEXT: } diff --git a/enzyme/test/MLIR/ReverseMode/pow.mlir b/enzyme/test/MLIR/ReverseMode/pow.mlir index 5c5596ec389e..9934152def61 100644 --- a/enzyme/test/MLIR/ReverseMode/pow.mlir +++ b/enzyme/test/MLIR/ReverseMode/pow.mlir @@ -1,4 +1,4 @@ -// RUN: %eopt --enzyme %s | FileCheck %s +// RUN: %eopt --enzyme -canonicalize --remove-unnecessary-enzyme-ops -enzyme-simplify-math -canonicalize %s | FileCheck %s module { func.func @ppow(%x: f64) -> f64 { @@ -19,29 +19,46 @@ module { } } -// CHECK: func.func private @diffeppow(%[[x:.+]]: f64, %[[dr:.+]]: f64) -> f64 +// CHECK: func.func private @diffeppow(%[[x:.+]]: f64, %[[dr:.+]]: f64) -> f64 { +// CHECK-NEXT: %c10 = arith.constant 10 : index +// CHECK-NEXT: %c1 = arith.constant 1 : index +// CHECK-NEXT: %c0 = arith.constant 0 : index +// CHECK-NEXT: %[[one:.+]] = arith.constant 1.0 +// CHECK-NEXT: %[[zero:.+]] = arith.constant 0.000000e+00 : f64 +// CHECK-NEXT: %[[xshadow:.+]] = "enzyme.init"() : () -> !enzyme.Gradient +// CHECK-NEXT: "enzyme.set"(%[[xshadow]], %[[zero]]) : (!enzyme.Gradient, f64) -> () +// CHECK-NEXT: %[[itshadow:.+]] = "enzyme.init"() : () -> !enzyme.Gradient +// CHECK-NEXT: "enzyme.set"(%[[itshadow]], %[[zero]]) : (!enzyme.Gradient, f64) -> () +// CHECK-NEXT: %[[xcache:.+]] = "enzyme.init"() : () -> !enzyme.Cache +// CHECK-NEXT: %[[rcache:.+]] = "enzyme.init"() : () -> !enzyme.Cache +// CHECK-NEXT: %[[rshadow:.+]] = "enzyme.init"() : () -> !enzyme.Gradient +// CHECK-NEXT: "enzyme.set"(%[[rshadow]], %[[zero]]) : (!enzyme.Gradient, f64) -> () -// Make sure the right values are being cached in the primal -// CHECK: %[[one:.+]] = arith.constant 1.0 -// CHECK: scf.for %[[iv:.+]] = %c0 to %c10 step %c1 iter_args(%[[r_it:.+]] = %[[one]]) -// CHECK-NEXT: "enzyme.push"(%[[rcache:.+]], %[[r_it]]) -// CHECK-NEXT: "enzyme.push"(%[[xcache:.+]], %[[x]]) - -// Ensure the right value is yielded in the adjoint -// CHECK: "enzyme.set"(%[[rshadow:.+]], %[[dr]]) -// CHECK: %[[dr:.+]] = "enzyme.get"(%[[rshadow]]) -// CHECK: scf.for %[[iv:.+]] = %[[lb:.+]] to %[[ub:.+]] step %[[step:.+]] iter_args(%[[dr_it:.+]] = %[[dr]]) -// CHECK-NEXT: "enzyme.set"(%[[rshadow:.+]], %[[dr_it]]) -// CHECK-NEXT: %[[dr_it:.+]] = "enzyme.get"(%[[rshadow]]) -// CHECK-NEXT: %[[r_cached:.+]] = "enzyme.pop"(%[[rcache]]) -// CHECK-NEXT: %[[x:.+]] = "enzyme.pop"(%[[xcache]]) -// CHECK-NEXT: %[[dr_next:.+]] = arith.mulf %[[dr_it]], %[[x]] -// CHECK-NEXT: "enzyme.set"(%[[rshadow:.+]], %[[dr_next]]) -// CHECK-NEXT: %[[dx_next:.+]] = arith.mulf %[[dr_it]], %[[r_cached]] -// CHECK-NEXT: %[[dx0:.+]] = "enzyme.get"(%[[xshadow:.+]]) : -// CHECK-NEXT: %[[dx1:.+]] = arith.addf %[[dx0]], %[[dx_next]] -// CHECK-NEXT: "enzyme.set"(%[[xshadow]], %[[dx1]]) -// CHECK-NEXT: %[[dr_next:.+]] = "enzyme.get"(%[[rshadow]]) -// CHECK-NEXT: scf.yield %[[dr_next]] -// CHECK: %[[final:.+]] = "enzyme.get"(%[[xshadow]]) -// CHECK-NEXT: return %[[final]] +// CHECK-NEXT: %{{.+}} = scf.for %[[iv:.+]] = %c0 to %c10 step %c1 iter_args(%[[r_it:.+]] = %[[one]]) -> (f64) { +// CHECK-NEXT: "enzyme.push"(%[[rcache]], %[[r_it]]) : (!enzyme.Cache, f64) -> () +// CHECK-NEXT: "enzyme.push"(%[[xcache]], %[[x]]) : (!enzyme.Cache, f64) -> () +// CHECK-NEXT: %[[fwd:.+]] = arith.mulf %[[r_it]], %[[x]] : f64 +// CHECK-NEXT: scf.yield %[[fwd]] : f64 +// CHECK-NEXT: } +// CHECK-NEXT: "enzyme.set"(%[[rshadow]], %[[dr]]) : (!enzyme.Gradient, f64) -> () +// CHECK-NEXT: scf.for %[[div:.+]] = %c0 to %c10 step %c1 { +// CHECK-NEXT: %[[dr_it:.+]] = "enzyme.get"(%[[rshadow]]) : (!enzyme.Gradient) -> f64 +// CHECK-NEXT: "enzyme.set"(%[[rshadow]], %[[zero]]) : (!enzyme.Gradient, f64) -> () +// CHECK-NEXT: %[[r_cached:.+]] = "enzyme.pop"(%[[rcache]]) : (!enzyme.Cache) -> f64 +// CHECK-NEXT: %[[x_cached:.+]] = "enzyme.pop"(%[[xcache]]) : (!enzyme.Cache) -> f64 +// CHECK-NEXT: %[[dr_next:.+]] = arith.mulf %[[dr_it]], %[[x_cached]] +// CHECK-NEXT: %[[previts:.+]] = "enzyme.get"(%[[itshadow]]) : (!enzyme.Gradient) -> f64 +// CHECK-NEXT: %[[postits:.+]] = arith.addf %[[previts]], %[[dr_next]] : f64 +// CHECK-NEXT: "enzyme.set"(%[[itshadow]], %[[postits]]) : (!enzyme.Gradient, f64) -> () +// CHECK-NEXT: %[[dx_next:.+]] = arith.mulf %[[dr_it]], %[[r_cached]] : f64 +// CHECK-NEXT: %[[dx0:.+]] = "enzyme.get"(%[[xshadow]]) : +// CHECK-NEXT: %[[dx1:.+]] = arith.addf %[[dx0]], %[[dx_next]] +// CHECK-NEXT: "enzyme.set"(%[[xshadow]], %[[dx1]]) : (!enzyme.Gradient, f64) -> () +// CHECK-NEXT: %[[divp1:.+]] = arith.addi %[[div]], %c1 : index +// CHECK-NEXT: %[[last:.+]] = arith.cmpi sge, %[[divp1]], %c10 : index +// CHECK-NEXT: "enzyme.set"(%[[itshadow]], %[[zero]]) : (!enzyme.Gradient, f64) -> () +// CHECK-NEXT: %[[sel:.+]] = arith.select %[[last]], %[[zero]], %12 : f64 +// CHECK-NEXT: "enzyme.set"(%[[itshadow]], %[[sel]]) : (!enzyme.Gradient, f64) -> () +// CHECK-NEXT: } +// CHECK-NEXT: %[[final:.+]] = "enzyme.get"(%[[xshadow]]) +// CHECK-NEXT: return %[[final]] \ No newline at end of file diff --git a/enzyme/test/MLIR/ReverseMode/square.mlir b/enzyme/test/MLIR/ReverseMode/square.mlir new file mode 100644 index 000000000000..4bcae3bb8000 --- /dev/null +++ b/enzyme/test/MLIR/ReverseMode/square.mlir @@ -0,0 +1,75 @@ +// RUN: %eopt --enzyme %s | FileCheck %s +// RUN: %eopt --enzyme --canonicalize --remove-unnecessary-enzyme-ops %s | FileCheck %s --check-prefix=REM +// RUN: %eopt --enzyme --canonicalize --remove-unnecessary-enzyme-ops --canonicalize --enzyme-simplify-math --cse %s | FileCheck %s --check-prefix=FIN + +module { + func.func @square(%x: f64) -> f64 { + %next = arith.mulf %x, %x : f64 + return %next : f64 + } + + func.func @dsquare(%x: f64, %dr: f64) -> f64 { + %r = enzyme.autodiff @square(%x, %dr) { activity=[#enzyme] } : (f64, f64) -> f64 + return %r : f64 + } +} + + +// CHECK: func.func @dsquare(%arg0: f64, %arg1: f64) -> f64 { +// CHECK-NEXT: %0 = call @diffesquare(%arg0, %arg1) : (f64, f64) -> f64 +// CHECK-NEXT: return %0 : f64 +// CHECK-NEXT: } + +// CHECK: func.func private @diffesquare(%arg0: f64, %arg1: f64) -> f64 { +// CHECK-NEXT: %[[dx:.+]] = "enzyme.init"() : () -> !enzyme.Gradient +// CHECK-NEXT: %[[c0:.+]] = arith.constant 0.000000e+00 : f64 +// CHECK-NEXT: "enzyme.set"(%[[dx]], %[[c0]]) : (!enzyme.Gradient, f64) -> () + +// CHECK-NEXT: %[[lhscache:.+]] = "enzyme.init"() : () -> !enzyme.Cache +// CHECK-NEXT: %[[rhscache:.+]] = "enzyme.init"() : () -> !enzyme.Cache + +// CHECK-NEXT: %[[dy:.+]] = "enzyme.init"() : () -> !enzyme.Gradient +// CHECK-NEXT: %[[c1:.+]] = arith.constant 0.000000e+00 : f64 +// CHECK-NEXT: "enzyme.set"(%[[dy]], %[[c1]]) : (!enzyme.Gradient, f64) -> () +// CHECK-NEXT: "enzyme.push"(%[[rhscache]], %arg0) : (!enzyme.Cache, f64) -> () +// CHECK-NEXT: "enzyme.push"(%[[lhscache]], %arg0) : (!enzyme.Cache, f64) -> () +// CHECK-NEXT: %[[mul:.+]] = arith.mulf %arg0, %arg0 : f64 +// CHECK-NEXT: cf.br ^bb1 + +// CHECK: ^bb1: // pred: ^bb0 +// CHECK-NEXT: %[[prevdret0:.+]] = "enzyme.get"(%[[dy]]) : (!enzyme.Gradient) -> f64 +// CHECK-NEXT: %[[postdret0:.+]] = arith.addf %[[prevdret0]], %arg1 : f64 +// CHECK-NEXT: "enzyme.set"(%[[dy]], %[[postdret0]]) : (!enzyme.Gradient, f64) -> () +// CHECK-NEXT: %[[prevdret:.+]] = "enzyme.get"(%[[dy]]) : (!enzyme.Gradient) -> f64 +// CHECK-NEXT: %[[c2:.+]] = arith.constant 0.000000e+00 : f64 +// CHECK-NEXT: "enzyme.set"(%[[dy]], %[[c2]]) : (!enzyme.Gradient, f64) -> () +// CHECK-NEXT: %[[postlhs:.+]] = "enzyme.pop"(%[[rhscache]]) : (!enzyme.Cache) -> f64 +// CHECK-NEXT: %[[postrhs:.+]] = "enzyme.pop"(%[[lhscache]]) : (!enzyme.Cache) -> f64 +// CHECK-NEXT: %[[dlhs:.+]] = arith.mulf %[[prevdret]], %[[postrhs]] : f64 +// CHECK-NEXT: %[[prevdx1:.+]] = "enzyme.get"(%[[dx]]) : (!enzyme.Gradient) -> f64 +// CHECK-NEXT: %[[postdx1:.+]] = arith.addf %[[prevdx1]], %[[dlhs]] : f64 +// CHECK-NEXT: "enzyme.set"(%[[dx]], %[[postdx1]]) : (!enzyme.Gradient, f64) -> () +// CHECK-NEXT: %[[drhs:.+]] = arith.mulf %[[prevdret]], %[[postlhs]] : f64 +// CHECK-NEXT: %[[prevdx2:.+]] = "enzyme.get"(%[[dx]]) : (!enzyme.Gradient) -> f64 +// CHECK-NEXT: %[[postdx2:.+]] = arith.addf %[[prevdx2]], %[[drhs]] : f64 +// CHECK-NEXT: "enzyme.set"(%[[dx]], %[[postdx2]]) : (!enzyme.Gradient, f64) -> () +// CHECK-NEXT: %[[res:.+]] = "enzyme.get"(%[[dx]]) : (!enzyme.Gradient) -> f64 +// CHECK-NEXT: return %[[res]] : f64 +// CHECK-NEXT: } + + +// REM: func.func private @diffesquare(%arg0: f64, %arg1: f64) -> f64 { +// REM-NEXT: %[[cst:.+]] = arith.constant 0.000000e+00 : f64 +// REM-NEXT: %[[a1:.+]] = arith.addf %arg1, %[[cst]] : f64 +// REM-NEXT: %[[a2:.+]] = arith.mulf %[[a1]], %arg0 : f64 +// REM-NEXT: %[[a3:.+]] = arith.addf %[[a2]], %[[cst]] : f64 +// REM-NEXT: %[[a4:.+]] = arith.mulf %[[a1]], %arg0 : f64 +// REM-NEXT: %[[a5:.+]] = arith.addf %[[a3]], %[[a4]] : f64 +// REM-NEXT: return %[[a5]] : f64 +// REM-NEXT: } + +// FIN: func.func private @diffesquare(%arg0: f64, %arg1: f64) -> f64 { +// FIN-NEXT: %0 = arith.mulf %arg1, %arg0 : f64 +// FIN-NEXT: %1 = arith.addf %0, %0 : f64 +// FIN-NEXT: return %1 : f64 +// FIN-NEXT: } \ No newline at end of file diff --git a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp index 143c85ea684e..5f456c8755c6 100644 --- a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp +++ b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp @@ -1811,7 +1811,7 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os, if (hasDiffeRet(resultTree)) { if (intrinsic == MLIRDerivatives) { os << " dif = gutils->diffe(" << origName << ", builder);\n"; - os << " gutils->clearValue(" << origName << ", builder);\n"; + os << " gutils->zeroDiffe(" << origName << ", builder);\n"; } else { os << " dif = diffe(&" << origName << ", Builder2);\n"; os << " setDiffe(&" << origName @@ -2040,6 +2040,8 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os, const auto &brpatterns = recordKeeper.getAllDerivedDefinitions("BranchOp"); + const auto &retpatterns = recordKeeper.getAllDerivedDefinitions("ReturnOp"); + const auto ®tpatterns = recordKeeper.getAllDerivedDefinitions("RegionTerminatorOp"); @@ -2092,6 +2094,12 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os, os << " registerAutoDiffUsingRegionTerminatorInterface<" << dialect << "::" << opName << ">(*context);\n"; } + for (Record *pattern : retpatterns) { + auto opName = pattern->getValueAsString("opName"); + auto dialect = pattern->getValueAsString("dialect"); + os << " registerAutoDiffUsingReturnInterface<" << dialect + << "::" << opName << ">(*context);\n"; + } for (Record *pattern : allocpatterns) { auto opName = pattern->getValueAsString("opName"); auto dialect = pattern->getValueAsString("dialect"); From 456cf5eddf25fda4d25e499e2f621df5f0ac28d1 Mon Sep 17 00:00:00 2001 From: William Moses Date: Fri, 1 Mar 2024 12:25:11 -0800 Subject: [PATCH 105/106] [MLIR] Add read-only reverse mode arg (#1774) --- .../MLIR/Implementations/ArithDerivatives.td | 6 + enzyme/Enzyme/MLIR/Implementations/Common.td | 14 +- enzyme/test/MLIR/ForwardMode/trunc.mlir | 18 + enzyme/test/MLIR/ReverseMode/trunc.mlir | 18 + enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp | 596 ++++++++++-------- 5 files changed, 370 insertions(+), 282 deletions(-) create mode 100644 enzyme/test/MLIR/ForwardMode/trunc.mlir create mode 100644 enzyme/test/MLIR/ReverseMode/trunc.mlir diff --git a/enzyme/Enzyme/MLIR/Implementations/ArithDerivatives.td b/enzyme/Enzyme/MLIR/Implementations/ArithDerivatives.td index 3d53793be3af..eb0294b4d24d 100644 --- a/enzyme/Enzyme/MLIR/Implementations/ArithDerivatives.td +++ b/enzyme/Enzyme/MLIR/Implementations/ArithDerivatives.td @@ -31,3 +31,9 @@ def : MLIRDerivative<"arith", "DivFOp", (Op $x, $y), ], (CheckedDivF (SubF (SelectIfActive $x, (MulF (Shadow $x), $y), (ConstantFP<"0","arith", "ConstantOp"> $x)), (SelectIfActive $y, (MulF (Shadow $y), $x), (ConstantFP<"0","arith","ConstantOp"> $y))), (MulF $y, $y)) >; + +def ExtF : ArithInst<"ExtFOp">; +def TruncF : ArithInst<"TruncFOp">; + +def : ReadOnlyIdentityOp<"arith", "TruncFOp", [0], (Op $x), [(ExtF (TypeOf $x), (DiffeRet))]>; +def : ReadOnlyIdentityOp<"arith", "ExtFOp", [0], (Op $x), [(TruncF (TypeOf $x), (DiffeRet))]>; diff --git a/enzyme/Enzyme/MLIR/Implementations/Common.td b/enzyme/Enzyme/MLIR/Implementations/Common.td index 3924f4527b00..099e614b8bcd 100644 --- a/enzyme/Enzyme/MLIR/Implementations/Common.td +++ b/enzyme/Enzyme/MLIR/Implementations/Common.td @@ -17,14 +17,21 @@ class ControlFlowOp { string impl = impl_; } -class MemoryIdentityOp ptrargs_, list storedargs_ = []> { + +def Unimplemented { + +} + +class MemoryIdentityOp ptrargs_, list storedargs_ = [], dag patternToMatch=(Unimplemented), list reverse_ = []> { string dialect = dialect_; string opName = opName_; + dag PatternToMatch = patternToMatch; list ptrargs = ptrargs_; list storedargs = storedargs_; + list reverse = reverse_; } -class ReadOnlyIdentityOp ptrargs_> : MemoryIdentityOp; +class ReadOnlyIdentityOp ptrargs_, dag patternToMatch=(Unimplemented), list reverse_ = []> : MemoryIdentityOp; class ReturnOp { string dialect = dialect_; @@ -94,6 +101,9 @@ class ConstantFP : Ope def ResultTypes : GlobalExprgetResultTypes()">; +def TypeOf : Operation { +} + class ArithInst : Inst; class MathInst : Inst; diff --git a/enzyme/test/MLIR/ForwardMode/trunc.mlir b/enzyme/test/MLIR/ForwardMode/trunc.mlir new file mode 100644 index 000000000000..8f3918add4ac --- /dev/null +++ b/enzyme/test/MLIR/ForwardMode/trunc.mlir @@ -0,0 +1,18 @@ +// RUN: %eopt --enzyme %s | FileCheck %s + +module { + func.func @f(%x : f64) -> f32 { + %y = arith.truncf %x : f64 to f32 + return %y : f32 + } + func.func @dsq(%x : f64, %dx : f64) -> f32 { + %r = enzyme.fwddiff @f(%x, %dx) { activity=[#enzyme] } : (f64, f64) -> (f32) + return %r : f32 + } +} + +// CHECK: func.func private @fwddiffef(%[[arg0:.+]]: f64, %[[arg1:.+]]: f64) -> f32 { +// CHECK-NEXT: %[[dy:.+]] = arith.truncf %[[arg1]] : f64 to f32 +// CHECK-NEXT: %[[y:.+]] = arith.truncf %[[arg0]] : f64 to f32 +// CHECK-NEXT: return %[[dy]] : f32 +// CHECK-NEXT: } diff --git a/enzyme/test/MLIR/ReverseMode/trunc.mlir b/enzyme/test/MLIR/ReverseMode/trunc.mlir new file mode 100644 index 000000000000..c078cb0634ae --- /dev/null +++ b/enzyme/test/MLIR/ReverseMode/trunc.mlir @@ -0,0 +1,18 @@ +// RUN: %eopt --enzyme --canonicalize --remove-unnecessary-enzyme-ops --canonicalize --enzyme-simplify-math --cse %s | FileCheck %s --check-prefix=FIN + +module { + func.func @f(%x: f64) -> f32 { + %next = arith.truncf %x : f64 to f32 + return %next : f32 + } + + func.func @dsquare(%x: f64, %dr: f32) -> f64 { + %r = enzyme.autodiff @f(%x, %dr) { activity=[#enzyme] } : (f64, f32) -> f64 + return %r : f64 + } +} + +// FIN: func.func private @diffef(%[[x:.+]]: f64, %[[dx:.+]]: f32) -> f64 { +// FIN-NEXT: %[[res:.+]] = arith.extf %[[dx]] : f32 to f64 +// FIN-NEXT: return %[[res]] : f64 +// FIN-NEXT: } diff --git a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp index 5f456c8755c6..6eff540b6622 100644 --- a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp +++ b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp @@ -354,7 +354,10 @@ bool handle(const Twine &curIndent, const Twine &argPattern, raw_ostream &os, PrintFatalError(pattern->getLoc(), Twine("unknown named operand in typeof") + resultTree->getAsString()); - os << "->getType()"; + if (intrinsic == MLIRDerivatives) + os << ".getType()"; + else + os << "->getType()"; return false; } else if (opName == "VectorSize" || Def->isSubClassOf("VectorSize")) { if (resultRoot->getNumArgs() != 1) @@ -1268,6 +1271,298 @@ static void emitHeaderIncludes(const RecordKeeper &recordKeeper, os << "};\n"; } +static void emitMLIRReverse(raw_ostream &os, Record *pattern, DagInit *tree, + ActionType intrinsic, StringRef origName, + ListInit *argOps) { + + auto opName = pattern->getValueAsString("opName"); + auto dialect = pattern->getValueAsString("dialect"); + os << "struct " << opName << "RevDerivative : \n"; + os << " public " + "ReverseAutoDiffOpInterface::ExternalModel<" + << opName << "RevDerivative, " << dialect << "::" << opName << "> {\n"; + os << " SmallVector cachedArguments(Operation *op,\n"; + os << " MGradientUtilsReverse *gutils) " + "const {\n"; + os << " SmallVector toret(op->getNumOperands(), false);\n"; + StringMap> varNameToCondition; + + std::function)> insert = + [&](DagInit *ptree, ArrayRef prev) { + for (auto treeEn : llvm::enumerate(ptree->getArgs())) { + auto tree = treeEn.value(); + auto name = ptree->getArgNameStr(treeEn.index()); + SmallVector next(prev.begin(), prev.end()); + next.push_back(treeEn.index()); + if (auto dg = dyn_cast(tree)) + insert(dg, next); + + if (name.size()) { + varNameToCondition[name] = std::make_tuple( + "idx == " + std::to_string(treeEn.index()), "", false); + } + } + }; + + insert(tree, {}); + + if (tree->getNameStr().size()) + varNameToCondition[tree->getNameStr()] = + std::make_tuple("ILLEGAL", "ILLEGAL", false); + + os << " for (size_t idx=0; idxgetNumOperands(); idx++) {\n"; + os << " bool used = false;\n"; + printDiffUse(os, " ", argOps, origName, intrinsic, tree, + varNameToCondition); + os << " toret[idx] = used;\n"; + os << " }\n"; + os << " return toret;\n"; + os << " }\n"; + + os << " SmallVector cacheValues(Operation *op,\n"; + os << " MGradientUtilsReverse *gutils) " + "const {\n"; + os << " if (gutils->isConstantInstruction(op) || " + "gutils->isConstantValue(op->getResult(0))) return {};\n"; + os << " auto neededArgs = cachedArguments(op, gutils);\n"; + os << " SmallVector toret;\n"; + os << " OpBuilder builder(gutils->getNewFromOriginal(op));\n"; + os << " for (auto en : llvm::enumerate(neededArgs))\n"; + os << " if (en.value()) {\n"; + os << " Value cache = " + "gutils->initAndPushCache(gutils->getNewFromOriginal(op->" + "getOperand(en.index())), builder);\n"; + os << " toret.push_back(cache);\n"; + os << " }\n"; + os << " return toret;\n"; + os << " }\n"; + os << "\n"; + os << " void createShadowValues(Operation *op, OpBuilder &builder,\n"; + os << " MGradientUtilsReverse *gutils) const " + "{}\n"; + + os << " void createReverseModeAdjoint(Operation *op0, OpBuilder " + "&builder,\n"; + os << " MGradientUtilsReverse *gutils,\n"; + os << " SmallVector caches) const {\n"; + os << " auto op = cast<" << dialect << "::" << opName << ">(op0);\n"; + os << " mlir::Value dif = nullptr;\n"; +} + +static VariableSetting parseVariables(DagInit *tree, ActionType intrinsic, + StringRef origName) { + VariableSetting nameToOrdinal; + std::function)> insert = + [&](DagInit *ptree, ArrayRef prev) { + unsigned i = 0; + for (auto tree : ptree->getArgs()) { + SmallVector next(prev.begin(), prev.end()); + next.push_back(i); + if (auto dg = dyn_cast(tree)) + insert(dg, next); + + if (ptree->getArgNameStr(i).size()) { + std::string op; + if (intrinsic != MLIRDerivatives) + op = (origName + ".getOperand(" + Twine(next[0]) + ")").str(); + else + op = (origName + "->getOperand(" + Twine(next[0]) + ")").str(); + if (prev.size() > 0) { + op = "gutils->extractMeta(Builder2, " + op + + ", ArrayRef({"; + bool first = true; + for (unsigned i = 1; i < next.size(); i++) { + if (!first) + op += ", "; + op += std::to_string(next[i]); + } + op += "}))"; + } + nameToOrdinal.insert(ptree->getArgNameStr(i), op, false); + } + i++; + } + }; + + insert(tree, {}); + + if (tree->getNameStr().size()) + nameToOrdinal.insert(tree->getNameStr(), + (Twine("(&") + origName + ")").str(), false); + return nameToOrdinal; +} + +static void emitReverseCommon(raw_ostream &os, Record *pattern, DagInit *tree, + ActionType intrinsic, StringRef origName, + ListInit *argOps) { + auto nameToOrdinal = parseVariables(tree, intrinsic, origName); + + bool seen = false; + for (auto argOpEn : enumerate(*argOps)) { + size_t argIdx = argOpEn.index(); + if (DagInit *resultRoot = dyn_cast(argOpEn.value())) { + auto opName = resultRoot->getOperator()->getAsString(); + auto Def = cast(resultRoot->getOperator())->getDef(); + if (opName == "InactiveArgSpec" || Def->isSubClassOf("InactiveArgSpec")) { + if (Def->getValueAsBit("asserting")) + os << " assert(gutils->isConstantValue(" << origName << ".getOperand(" + << argIdx << ")));\n"; + continue; + } + } + + os << " "; + if (seen) + os << "} else "; + seen = true; + if (intrinsic == MLIRDerivatives) { + os << "if (!dif && !gutils->isConstantValue(" << origName + << "->getOperand(" << argIdx << "))) {\n"; + } else { + os << "if (!dif && !gutils->isConstantValue(" << origName + << ".getOperand(" << argIdx << "))) {\n"; + } + DagInit *resultTree = cast(argOpEn.value()); + if (hasDiffeRet(resultTree)) { + if (intrinsic == MLIRDerivatives) { + os << " dif = gutils->diffe(" << origName << ", builder);\n"; + os << " gutils->zeroDiffe(" << origName << ", builder);\n"; + } else { + os << " dif = diffe(&" << origName << ", Builder2);\n"; + os << " setDiffe(&" << origName + << ", " + "Constant::getNullValue(gutils->getShadowType(" + << origName + << ".getType())), " + "Builder2);\n"; + } + } + } + if (seen) + os << " }\n"; + + if (intrinsic == MLIRDerivatives) { + os << " SmallVector operands(op->getNumOperands(), nullptr);\n"; + os << " auto neededArgs = cachedArguments(op, gutils);\n"; + os << " size_t count = 0;\n"; + os << " for (auto en : llvm::enumerate(neededArgs))\n"; + os << " if (en.value()) {\n"; + os << " operands[en.index()] = " + "gutils->popCache(caches[count], builder);\n"; + os << " count++;\n"; + os << " }\n"; + } + + std::function, Init *)> revres = + [&](size_t argIdx, ArrayRef idx, Init *ival) { + if (DagInit *resultTree = dyn_cast(ival)) { + auto Def = cast(resultTree->getOperator())->getDef(); + if (Def->isSubClassOf("MultiReturn")) { + unsigned i = 0; + for (auto r : resultTree->getArgs()) { + SmallVector next(idx.begin(), idx.end()); + next.push_back(i); + revres(argIdx, next, r); + i++; + } + return; + } + if (Def->isSubClassOf("InactiveArgSpec")) { + return; + } + const char *curIndent = " "; + os << curIndent << "{\n"; + if (intrinsic == MLIRDerivatives) + os << curIndent << INDENT << "mlir::Value tmp = "; + else + os << curIndent << INDENT << "Value *tmp = "; + bool vectorValued = handle( + Twine(curIndent) + INDENT, "revarg", os, pattern, resultTree, + (intrinsic == MLIRDerivatives) ? "builder" : "Builder2", + nameToOrdinal, /*lookup*/ true, idx, origName, + /*newFromOriginal*/ true, intrinsic); + os << ";\n"; + + if (intrinsic == MLIRDerivatives) { + os << "assert(toadd == nullptr); toadd = tmp;\n"; + } else { + os << curIndent << INDENT + << "Value *out = " + "UndefValue::get(gutils->getShadowType(" + << origName << ".getOperand(" << argIdx << ")->getType()));\n"; + + os << curIndent << INDENT + << "for(unsigned int idx=0, W=gutils->getWidth(); " + "idxgetWidth() == " + "1 ? toadd : gutils->extractMeta(Builder2, toadd, idx)) : " + "nullptr;\n"; + os << curIndent << INDENT << INDENT << "Value *next = tmp;\n"; + if (vectorValued) + os << curIndent << INDENT << INDENT + << "if (gutils->getWidth() > 1) next = " + "gutils->extractMeta(Builder2, next, idx);\n"; + os << curIndent << INDENT << INDENT + << "if (prev) next = Builder2.CreateFAdd(prev, " + "next);\n"; + os << curIndent << INDENT << INDENT + << "out = (gutils->getWidth() > 1) ? " + "Builder2.CreateInsertValue(out, next, idx) : next;\n"; + os << curIndent << INDENT << "}\n"; + os << curIndent << INDENT << "toadd = out;\n"; + } + os << curIndent << "}\n"; + + } else if (ListInit *lst = dyn_cast(ival)) { + unsigned i = 0; + for (auto elem : *lst) { + SmallVector next(idx.begin(), idx.end()); + next.push_back(i); + revres(argIdx, next, elem); + i++; + } + } else + assert(0); + }; + + for (auto argOpEn : enumerate(*argOps)) { + size_t argIdx = argOpEn.index(); + if (DagInit *resultRoot = dyn_cast(argOpEn.value())) { + auto opName = resultRoot->getOperator()->getAsString(); + auto Def = cast(resultRoot->getOperator())->getDef(); + if (opName == "InactiveArgSpec" || Def->isSubClassOf("InactiveArgSpec")) { + continue; + } + } + + const char *curIndent = " "; + if (intrinsic == MLIRDerivatives) + os << curIndent << "if (!gutils->isConstantValue(" << origName + << "->getOperand(" << argIdx << "))) {\n"; + else + os << curIndent << "if (!gutils->isConstantValue(" << origName + << ".getOperand(" << argIdx << "))) {\n"; + initializeNames(Twine(curIndent) + INDENT, os, argOpEn.value(), "local"); + if (intrinsic == MLIRDerivatives) + os << curIndent << INDENT << "mlir::Value toadd = nullptr;\n"; + else + os << curIndent << INDENT << "Value *toadd = nullptr;\n"; + revres(argIdx, {}, argOpEn.value()); + + if (intrinsic == MLIRDerivatives) { + os << curIndent << INDENT << "if (toadd) gutils->addToDiffe(" << origName + << "->getOperand(" << argIdx << "), toadd, builder);\n"; + } else { + os << curIndent << INDENT << "if (toadd) addToDiffe(" << origName + << ".getOperand(" << argIdx << "), toadd"; + os << ", Builder2, " << origName << ".getOperand(" << argIdx + << ")->getType());\n"; + } + os << curIndent << "}\n"; + } +} static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os, ActionType intrinsic) { emitSourceFileHeader("Rewriters", os); @@ -1467,45 +1762,7 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os, } } - VariableSetting nameToOrdinal; - - std::function)> insert = - [&](DagInit *ptree, ArrayRef prev) { - unsigned i = 0; - for (auto tree : ptree->getArgs()) { - SmallVector next(prev.begin(), prev.end()); - next.push_back(i); - if (auto dg = dyn_cast(tree)) - insert(dg, next); - - if (ptree->getArgNameStr(i).size()) { - std::string op; - if (intrinsic != MLIRDerivatives) - op = (origName + ".getOperand(" + Twine(next[0]) + ")").str(); - else - op = (origName + "->getOperand(" + Twine(next[0]) + ")").str(); - if (prev.size() > 0) { - op = "gutils->extractMeta(Builder2, " + op + - ", ArrayRef({"; - bool first = true; - for (unsigned i = 1; i < next.size(); i++) { - if (!first) - op += ", "; - op += std::to_string(next[i]); - } - op += "}))"; - } - nameToOrdinal.insert(ptree->getArgNameStr(i), op, false); - } - i++; - } - }; - - insert(tree, {}); - - if (tree->getNameStr().size()) - nameToOrdinal.insert(tree->getNameStr(), - (Twine("(&") + origName + ")").str(), false); + VariableSetting nameToOrdinal = parseVariables(tree, intrinsic, origName); if (intrinsic != BinopDerivatives && intrinsic != InstDerivatives && intrinsic != MLIRDerivatives) { @@ -1706,248 +1963,10 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os, os << " Value *dif = nullptr;\n"; } else { os << "};\n"; - auto opName = pattern->getValueAsString("opName"); - auto dialect = pattern->getValueAsString("dialect"); - os << "struct " << opName << "RevDerivative : \n"; - os << " public " - "ReverseAutoDiffOpInterface::ExternalModel<" - << opName << "RevDerivative, " << dialect << "::" << opName << "> {\n"; - os << " SmallVector cachedArguments(Operation *op,\n"; - os << " MGradientUtilsReverse *gutils) " - "const {\n"; - os << " SmallVector toret(op->getNumOperands(), false);\n"; - StringMap> varNameToCondition; - - std::function)> insert = - [&](DagInit *ptree, ArrayRef prev) { - for (auto treeEn : llvm::enumerate(ptree->getArgs())) { - auto tree = treeEn.value(); - auto name = ptree->getArgNameStr(treeEn.index()); - SmallVector next(prev.begin(), prev.end()); - next.push_back(treeEn.index()); - if (auto dg = dyn_cast(tree)) - insert(dg, next); - - if (name.size()) { - varNameToCondition[name] = std::make_tuple( - "idx == " + std::to_string(treeEn.index()), "", false); - } - } - }; - - insert(tree, {}); - - if (tree->getNameStr().size()) - varNameToCondition[tree->getNameStr()] = - std::make_tuple("ILLEGAL", "ILLEGAL", false); - - os << " for (size_t idx=0; idxgetNumOperands(); idx++) {\n"; - os << " bool used = false;\n"; - printDiffUse(os, " ", argOps, origName, intrinsic, tree, - varNameToCondition); - os << " toret[idx] = used;\n"; - os << " }\n"; - os << " return toret;\n"; - os << " }\n"; - - os << " SmallVector cacheValues(Operation *op,\n"; - os << " MGradientUtilsReverse *gutils) " - "const {\n"; - os << " if (gutils->isConstantInstruction(op) || " - "gutils->isConstantValue(op->getResult(0))) return {};\n"; - os << " auto neededArgs = cachedArguments(op, gutils);\n"; - os << " SmallVector toret;\n"; - os << " OpBuilder builder(gutils->getNewFromOriginal(op));\n"; - os << " for (auto en : llvm::enumerate(neededArgs))\n"; - os << " if (en.value()) {\n"; - os << " Value cache = " - "gutils->initAndPushCache(gutils->getNewFromOriginal(op->" - "getOperand(en.index())), builder);\n"; - os << " toret.push_back(cache);\n"; - os << " }\n"; - os << " return toret;\n"; - os << " }\n"; - os << "\n"; - os << " void createShadowValues(Operation *op, OpBuilder &builder,\n"; - os << " MGradientUtilsReverse *gutils) const " - "{}\n"; - - os << " void createReverseModeAdjoint(Operation *op0, OpBuilder " - "&builder,\n"; - os << " MGradientUtilsReverse *gutils,\n"; - os << " SmallVector caches) const {\n"; - os << " auto op = cast<" << dialect << "::" << opName << ">(op0);\n"; - os << " mlir::Value dif = nullptr;\n"; - } - // TODO vector - - bool seen = false; - for (auto argOpEn : enumerate(*argOps)) { - size_t argIdx = argOpEn.index(); - if (DagInit *resultRoot = dyn_cast(argOpEn.value())) { - auto opName = resultRoot->getOperator()->getAsString(); - auto Def = cast(resultRoot->getOperator())->getDef(); - if (opName == "InactiveArgSpec" || - Def->isSubClassOf("InactiveArgSpec")) { - if (Def->getValueAsBit("asserting")) - os << " assert(gutils->isConstantValue(" << origName - << ".getOperand(" << argIdx << ")));\n"; - continue; - } - } - - os << " "; - if (seen) - os << "} else "; - seen = true; - if (intrinsic == MLIRDerivatives) { - os << "if (!dif && !gutils->isConstantValue(" << origName - << "->getOperand(" << argIdx << "))) {\n"; - } else { - os << "if (!dif && !gutils->isConstantValue(" << origName - << ".getOperand(" << argIdx << "))) {\n"; - } - DagInit *resultTree = cast(argOpEn.value()); - if (hasDiffeRet(resultTree)) { - if (intrinsic == MLIRDerivatives) { - os << " dif = gutils->diffe(" << origName << ", builder);\n"; - os << " gutils->zeroDiffe(" << origName << ", builder);\n"; - } else { - os << " dif = diffe(&" << origName << ", Builder2);\n"; - os << " setDiffe(&" << origName - << ", " - "Constant::getNullValue(gutils->getShadowType(" - << origName - << ".getType())), " - "Builder2);\n"; - } - } + emitMLIRReverse(os, pattern, tree, intrinsic, origName, argOps); } - if (seen) - os << " }\n"; - - if (intrinsic == MLIRDerivatives) { - os << " SmallVector operands(op->getNumOperands(), nullptr);\n"; - os << " auto neededArgs = cachedArguments(op, gutils);\n"; - os << " size_t count = 0;\n"; - os << " for (auto en : llvm::enumerate(neededArgs))\n"; - os << " if (en.value()) {\n"; - os << " operands[en.index()] = " - "gutils->popCache(caches[count], builder);\n"; - os << " count++;\n"; - os << " }\n"; - } - - std::function, Init *)> revres = - [&](size_t argIdx, ArrayRef idx, Init *ival) { - if (DagInit *resultTree = dyn_cast(ival)) { - auto Def = cast(resultTree->getOperator())->getDef(); - if (Def->isSubClassOf("MultiReturn")) { - unsigned i = 0; - for (auto r : resultTree->getArgs()) { - SmallVector next(idx.begin(), idx.end()); - next.push_back(i); - revres(argIdx, next, r); - i++; - } - return; - } - if (Def->isSubClassOf("InactiveArgSpec")) { - return; - } - const char *curIndent = " "; - os << curIndent << "{\n"; - if (intrinsic == MLIRDerivatives) - os << curIndent << INDENT << "mlir::Value tmp = "; - else - os << curIndent << INDENT << "Value *tmp = "; - bool vectorValued = handle( - Twine(curIndent) + INDENT, "revarg", os, pattern, resultTree, - (intrinsic == MLIRDerivatives) ? "builder" : "Builder2", - nameToOrdinal, /*lookup*/ true, idx, origName, - /*newFromOriginal*/ true, intrinsic); - os << ";\n"; - - if (intrinsic == MLIRDerivatives) { - os << "assert(toadd == nullptr); toadd = tmp;\n"; - } else { - os << curIndent << INDENT - << "Value *out = " - "UndefValue::get(gutils->getShadowType(" - << origName << ".getOperand(" << argIdx << ")->getType()));\n"; - - os << curIndent << INDENT - << "for(unsigned int idx=0, W=gutils->getWidth(); " - "idxgetWidth() == " - "1 ? toadd : gutils->extractMeta(Builder2, toadd, idx)) : " - "nullptr;\n"; - os << curIndent << INDENT << INDENT << "Value *next = tmp;\n"; - if (vectorValued) - os << curIndent << INDENT << INDENT - << "if (gutils->getWidth() > 1) next = " - "gutils->extractMeta(Builder2, next, idx);\n"; - os << curIndent << INDENT << INDENT - << "if (prev) next = Builder2.CreateFAdd(prev, " - "next);\n"; - os << curIndent << INDENT << INDENT - << "out = (gutils->getWidth() > 1) ? " - "Builder2.CreateInsertValue(out, next, idx) : next;\n"; - os << curIndent << INDENT << "}\n"; - os << curIndent << INDENT << "toadd = out;\n"; - } - os << curIndent << "}\n"; - - } else if (ListInit *lst = dyn_cast(ival)) { - unsigned i = 0; - for (auto elem : *lst) { - SmallVector next(idx.begin(), idx.end()); - next.push_back(i); - revres(argIdx, next, elem); - i++; - } - } else - assert(0); - }; - for (auto argOpEn : enumerate(*argOps)) { - size_t argIdx = argOpEn.index(); - if (DagInit *resultRoot = dyn_cast(argOpEn.value())) { - auto opName = resultRoot->getOperator()->getAsString(); - auto Def = cast(resultRoot->getOperator())->getDef(); - if (opName == "InactiveArgSpec" || - Def->isSubClassOf("InactiveArgSpec")) { - continue; - } - } - - const char *curIndent = " "; - if (intrinsic == MLIRDerivatives) - os << curIndent << "if (!gutils->isConstantValue(" << origName - << "->getOperand(" << argIdx << "))) {\n"; - else - os << curIndent << "if (!gutils->isConstantValue(" << origName - << ".getOperand(" << argIdx << "))) {\n"; - initializeNames(Twine(curIndent) + INDENT, os, argOpEn.value(), "local"); - if (intrinsic == MLIRDerivatives) - os << curIndent << INDENT << "mlir::Value toadd = nullptr;\n"; - else - os << curIndent << INDENT << "Value *toadd = nullptr;\n"; - revres(argIdx, {}, argOpEn.value()); - - if (intrinsic == MLIRDerivatives) { - os << curIndent << INDENT << "if (toadd) gutils->addToDiffe(" - << origName << "->getOperand(" << argIdx << "), toadd, builder);\n"; - } else { - os << curIndent << INDENT << "if (toadd) addToDiffe(" << origName - << ".getOperand(" << argIdx << "), toadd"; - os << ", Builder2, " << origName << ".getOperand(" << argIdx - << ")->getType());\n"; - } - os << curIndent << "}\n"; - } + emitReverseCommon(os, pattern, tree, intrinsic, origName, argOps); if (intrinsic != MLIRDerivatives) { os << " auto found = gutils->invertedPointers.find(&(" << origName @@ -2036,6 +2055,18 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os, } os << " return true;\n }\n"; os << "};\n"; + + DagInit *tree = pattern->getValueAsDag("PatternToMatch"); + + if (tree->getOperator()->getAsString() != "Unimplemented") { + ListInit *argOps = pattern->getValueAsListInit("reverse"); + auto origName = "op"; + emitMLIRReverse(os, pattern, tree, intrinsic, origName, argOps); + emitReverseCommon(os, pattern, tree, intrinsic, origName, argOps); + os << " return;\n"; + os << " }\n"; + os << " };\n"; + } } const auto &brpatterns = recordKeeper.getAllDerivedDefinitions("BranchOp"); @@ -2081,6 +2112,11 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os, for (auto storedarg : pattern->getValueAsListOfInts("storedargs")) os << ", " << storedarg; os << ">(*context);\n"; + DagInit *tree = pattern->getValueAsDag("PatternToMatch"); + if (tree->getOperator()->getAsString() != "Unimplemented") { + os << " " << dialect << "::" << opName << "::attachInterface<" + << opName << "RevDerivative>(*context);\n"; + } } for (Record *pattern : brpatterns) { auto opName = pattern->getValueAsString("opName"); From 0b621884bc531329095d202f042f6599a86614ec Mon Sep 17 00:00:00 2001 From: William Moses Date: Fri, 1 Mar 2024 18:02:00 -0800 Subject: [PATCH 106/106] [MLIR] Fix reverse wrap pass infra (#1775) --- enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.cpp | 1 - enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.h | 7 +- .../MLIR/Interfaces/EnzymeLogicReverse.cpp | 80 +------------------ .../MLIR/Interfaces/GradientUtilsReverse.cpp | 16 ++-- .../MLIR/Interfaces/GradientUtilsReverse.h | 8 +- enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp | 2 +- enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp | 21 +++-- 7 files changed, 30 insertions(+), 105 deletions(-) diff --git a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.cpp b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.cpp index 7d4ccead7764..ead8baad9261 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.cpp +++ b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.cpp @@ -5,7 +5,6 @@ #include "Interfaces/GradientUtils.h" #include "Interfaces/GradientUtilsReverse.h" #include "mlir/IR/Matchers.h" -#include "mlir/IR/SymbolTable.h" #include "mlir/Interfaces/FunctionInterfaces.h" // TODO: this shouldn't depend on specific dialects except Enzyme. diff --git a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.h b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.h index dc98a5c4c6a6..56d49bf79b09 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.h +++ b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.h @@ -119,8 +119,7 @@ class MEnzymeLogic { std::vector constants, MTypeAnalysis &TA, bool returnUsed, DerivativeMode mode, bool freeMemory, size_t width, mlir::Type addedType, MFnTypeInfo type_args, - std::vector volatile_args, void *augmented, - SymbolTableCollection &symbolTable); + std::vector volatile_args, void *augmented); void initializeShadowValues(SmallVector &dominatorToposortBlocks, MGradientUtilsReverse *gutils); @@ -132,8 +131,6 @@ class MEnzymeLogic { MGradientUtilsReverse *gutils); void visitChild(Operation *op, OpBuilder &builder, MGradientUtilsReverse *gutils); - bool visitChildCustom(Operation *op, OpBuilder &builder, - MGradientUtilsReverse *gutils); void mapInvertArguments(Block *oBB, Block *reverseBB, MGradientUtilsReverse *gutils); SmallVector getDominatorToposort(MGradientUtilsReverse *gutils, @@ -145,4 +142,4 @@ class MEnzymeLogic { }; } // Namespace enzyme -} // Namespace mlir \ No newline at end of file +} // Namespace mlir diff --git a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp index f39766a0a639..25e8f1818cd2 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp +++ b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp @@ -35,76 +35,6 @@ void handleReturns(Block *oBB, Block *newBB, Block *reverseBB, } } -bool MEnzymeLogic::visitChildCustom(Operation *op, OpBuilder &builder, - MGradientUtilsReverse *gutils) { - std::string nameDiffe = "diffe_" + op->getName().getDialectNamespace().str() + - "_" + op->getName().stripDialect().str(); - std::string nameStore = "store_" + op->getName().getDialectNamespace().str() + - "_" + op->getName().stripDialect().str(); - - StringRef srDiffe(nameDiffe); - StringRef srStore(nameStore); - - OperationName opNameDiffe(srDiffe, op->getContext()); - OperationName opNameStore(srStore, op->getContext()); - - Operation *symbolDiffe = gutils->symbolTable.lookupNearestSymbolFrom( - op, opNameDiffe.getIdentifier()); - Operation *symbolStore = gutils->symbolTable.lookupNearestSymbolFrom( - op, opNameStore.getIdentifier()); - - if (symbolDiffe != nullptr) { - SmallVector caches; - if (symbolStore != nullptr) { - Operation *newOp = gutils->getNewFromOriginal(op); - - func::FuncOp funcStore = cast(symbolStore); - - SmallVector storeResultTypes; - for (auto x : funcStore.getFunctionType().getResults()) { - storeResultTypes.push_back(x); - } - - SmallVector storeArgs; - for (auto x : newOp->getOperands()) { - storeArgs.push_back(x); - } - - OpBuilder storeBuilder(newOp); - func::CallOp storeCI = storeBuilder.create( - op->getLoc(), srStore, storeResultTypes, storeArgs); - for (auto x : storeCI.getResults()) { - caches.push_back(gutils->initAndPushCache(x, storeBuilder)); - } - } - - SmallVector args; - for (Value opResult : op->getResults()) { - if (!gutils->isConstantValue(opResult)) { - Value invertValue = gutils->invertPointerM(opResult, builder); - args.push_back(invertValue); - } - } - for (Value cache : caches) { - args.push_back(gutils->popCache(cache, builder)); - } - - SmallVector resultTypes; - for (auto x : op->getOperands()) { - resultTypes.push_back(x.getType()); - } - - func::CallOp dCI = - builder.create(op->getLoc(), srDiffe, resultTypes, args); - for (int i = 0; i < (int)op->getNumOperands(); i++) { - gutils->setDiffe(op->getOperand(i), dCI.getResult(i), builder); - } - - return true; - } - return false; -} - /* Create reverse mode adjoint for an operation. */ @@ -139,10 +69,7 @@ void MEnzymeLogic::visitChildren(Block *oBB, Block *reverseBB, auto last = oBB->rend(); for (auto it = first; it != last; ++it) { Operation *op = &*it; - bool customFound = visitChildCustom(op, revBuilder, gutils); - if (!customFound) { - visitChild(op, revBuilder, gutils); - } + visitChild(op, revBuilder, gutils); } } } @@ -257,8 +184,7 @@ FunctionOpInterface MEnzymeLogic::CreateReverseDiff( FunctionOpInterface fn, DIFFE_TYPE retType, std::vector constants, MTypeAnalysis &TA, bool returnUsed, DerivativeMode mode, bool freeMemory, size_t width, mlir::Type addedType, - MFnTypeInfo type_args, std::vector volatile_args, void *augmented, - SymbolTableCollection &symbolTable) { + MFnTypeInfo type_args, std::vector volatile_args, void *augmented) { if (fn.getFunctionBody().empty()) { llvm::errs() << fn << "\n"; @@ -268,7 +194,7 @@ FunctionOpInterface MEnzymeLogic::CreateReverseDiff( ReturnType returnValue = ReturnType::Args; MGradientUtilsReverse *gutils = MGradientUtilsReverse::CreateFromClone( *this, mode, width, fn, TA, type_args, retType, /*diffeReturnArg*/ true, - constants, returnValue, addedType, symbolTable); + constants, returnValue, addedType); Region &oldRegion = gutils->oldFunc.getFunctionBody(); Region &newRegion = gutils->newFunc.getFunctionBody(); diff --git a/enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.cpp b/enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.cpp index 20d34247fb3a..b57fbe68b594 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.cpp +++ b/enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.cpp @@ -35,12 +35,11 @@ mlir::enzyme::MGradientUtilsReverse::MGradientUtilsReverse( const SmallPtrSetImpl &activevals_, DIFFE_TYPE ReturnActivity, ArrayRef ArgDiffeTypes_, IRMapping &originalToNewFn_, std::map &originalToNewFnOps_, - DerivativeMode mode_, unsigned width, SymbolTableCollection &symbolTable_) + DerivativeMode mode_, unsigned width) : MDiffeGradientUtils(Logic, newFunc_, oldFunc_, TA_, /*MTypeResults*/ {}, invertedPointers_, constantvalues_, activevals_, ReturnActivity, ArgDiffeTypes_, originalToNewFn_, - originalToNewFnOps_, mode_, width, /*omp*/ false), - symbolTable(symbolTable_) {} + originalToNewFnOps_, mode_, width, /*omp*/ false) {} Type mlir::enzyme::MGradientUtilsReverse::getIndexCacheType() { Type indexType = getIndexType(); @@ -135,8 +134,7 @@ MGradientUtilsReverse *MGradientUtilsReverse::CreateFromClone( MEnzymeLogic &Logic, DerivativeMode mode_, unsigned width, FunctionOpInterface todiff, MTypeAnalysis &TA, MFnTypeInfo &oldTypeInfo, DIFFE_TYPE retType, bool diffeReturnArg, ArrayRef constant_args, - ReturnType returnValue, mlir::Type additionalArg, - SymbolTableCollection &symbolTable_) { + ReturnType returnValue, mlir::Type additionalArg) { std::string prefix; switch (mode_) { @@ -168,8 +166,8 @@ MGradientUtilsReverse *MGradientUtilsReverse::CreateFromClone( prefix + todiff.getName(), originalToNew, originalToNewOps, diffeReturnArg, additionalArg); - return new MGradientUtilsReverse( - Logic, newFunc, todiff, TA, invertedPointers, constant_values, - nonconstant_values, retType, constant_args, originalToNew, - originalToNewOps, mode_, width, symbolTable_); + return new MGradientUtilsReverse(Logic, newFunc, todiff, TA, invertedPointers, + constant_values, nonconstant_values, retType, + constant_args, originalToNew, + originalToNewOps, mode_, width); } diff --git a/enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.h b/enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.h index d3b2e818391f..96e899939538 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.h +++ b/enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.h @@ -31,13 +31,10 @@ class MGradientUtilsReverse : public MDiffeGradientUtils { ArrayRef ArgDiffeTypes_, IRMapping &originalToNewFn_, std::map &originalToNewFnOps_, - DerivativeMode mode_, unsigned width, - SymbolTableCollection &symbolTable_); + DerivativeMode mode_, unsigned width); IRMapping mapReverseModeBlocks; - SymbolTableCollection &symbolTable; - void addToDiffe(mlir::Value oldGradient, mlir::Value addedGradient, OpBuilder &builder); @@ -67,8 +64,7 @@ class MGradientUtilsReverse : public MDiffeGradientUtils { FunctionOpInterface todiff, MTypeAnalysis &TA, MFnTypeInfo &oldTypeInfo, DIFFE_TYPE retType, bool diffeReturnArg, ArrayRef constant_args, - ReturnType returnValue, mlir::Type additionalArg, - SymbolTableCollection &symbolTable_); + ReturnType returnValue, mlir::Type additionalArg); }; } // namespace enzyme diff --git a/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp b/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp index de7328dcb186..b7d33b6faedc 100644 --- a/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp +++ b/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp @@ -153,7 +153,7 @@ struct DifferentiatePass : public DifferentiatePassBase { fn, retType, constants, TA, /*should return*/ false, mode, freeMemory, width, /*addedType*/ nullptr, type_args, volatile_args, - /*augmented*/ nullptr, symbolTable); + /*augmented*/ nullptr); if (!newFunc) return failure(); diff --git a/enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp b/enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp index b8d91f2c82e8..b48705c220d1 100644 --- a/enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp +++ b/enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp @@ -91,12 +91,21 @@ struct DifferentiateWrapperPass volatile_args.push_back(!(mode == DerivativeMode::ReverseModeCombined)); } - FunctionOpInterface newFunc = Logic.CreateForwardDiff( - fn, retType, constants, TA, - /*should return*/ (retType == DIFFE_TYPE::DUP_ARG), mode, freeMemory, - width, - /*addedType*/ nullptr, type_args, volatile_args, - /*augmented*/ nullptr); + FunctionOpInterface newFunc; + if (mode == DerivativeMode::ForwardMode) { + newFunc = Logic.CreateForwardDiff( + fn, retType, constants, TA, + /*should return*/ (retType == DIFFE_TYPE::DUP_ARG), mode, freeMemory, + width, + /*addedType*/ nullptr, type_args, volatile_args, + /*augmented*/ nullptr); + } else { + newFunc = Logic.CreateReverseDiff( + fn, retType, constants, TA, + /*should return*/ false, mode, freeMemory, width, + /*addedType*/ nullptr, type_args, volatile_args, + /*augmented*/ nullptr); + } if (!newFunc) { signalPassFailure(); return;