diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json index 368b7933e347..a9b75aefae89 100644 --- a/.devcontainer/devcontainer.json +++ b/.devcontainer/devcontainer.json @@ -1,8 +1,8 @@ // available ubuntu versions: [20, 22] -// available llvm versions: [9, 10, 11, 12, 13, 14, 15] +// available llvm versions: [11, 12, 13, 14, 15, 16, 17, 18] { "name": "Enzyme", - "image": "ghcr.io/enzymead/enzyme-dev-docker/ubuntu-20-llvm-12:latest", + "image": "ghcr.io/enzymead/enzyme-dev-docker/ubuntu-22-llvm-16:latest", "mounts": [ "source=enzyme-bashhistory,target=/commandhistory,type=volume", "source=enzyme-extensions,target=/home/vscode/.vscode-server/extensions,type=volume", @@ -14,7 +14,9 @@ "customizations": { "vscode": { "extensions": [ - "ms-vscode.cpptools-extension-pack" + "llvm-vs-code-extensions.vscode-clangd", + "BazelBuild.vscode-bazel", + "twxs.cmake" ] } } diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 000000000000..edb6037854cc --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,12 @@ +# To get started with Dependabot version updates, you'll need to specify which +# package ecosystems to update and where the package manifests are located. +# Please see the documentation for all configuration options: +# https://docs.github.com/code-security/dependabot/dependabot-version-updates/configuration-options-for-the-dependabot.yml-file + +version: 2 +updates: +# Maintain dependencies for GitHub Actions + - package-ecosystem: "github-actions" + directory: "/" + schedule: + interval: "monthly" diff --git a/.github/workflows/bcload.yml b/.github/workflows/bcload.yml index 33e409fd8bfe..9cbbdb7094ff 100644 --- a/.github/workflows/bcload.yml +++ b/.github/workflows/bcload.yml @@ -27,7 +27,7 @@ jobs: sudo sed -i 's/add_executable(llvm-omp-device-info IMPORTED)//g' /usr/lib/llvm-${{matrix.llvm}}/lib/cmake/llvm/LLVMExports*.cmake sudo sed -i 's/llvm-omp-device-info//g' /usr/lib/llvm-${{matrix.llvm}}/lib/cmake/llvm/LLVMExports*.cmake fi - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: mkdir run: cd enzyme && rm -rf build && mkdir build - name: cmake diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index 46fa9edb1bed..d859468708b5 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -37,7 +37,7 @@ jobs: sudo sed -i 's/add_executable(llvm-omp-device-info IMPORTED)//g' /usr/lib/llvm-${{matrix.llvm}}/lib/cmake/llvm/LLVMExports*.cmake sudo sed -i 's/llvm-omp-device-info//g' /usr/lib/llvm-${{matrix.llvm}}/lib/cmake/llvm/LLVMExports*.cmake fi - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: mkdir run: rm -rf build && mkdir build - name: cmake diff --git a/.github/workflows/ccpp.yml b/.github/workflows/ccpp.yml index 46c64d5b369e..1b5b293a24b7 100644 --- a/.github/workflows/ccpp.yml +++ b/.github/workflows/ccpp.yml @@ -27,6 +27,7 @@ jobs: - name: add llvm run: | wget -O - https://apt.llvm.org/llvm-snapshot.gpg.key|sudo apt-key add - + sudo apt-get install -y libmpfr-dev sudo apt-add-repository "deb http://apt.llvm.org/`lsb_release -c | cut -f2`/ llvm-toolchain-`lsb_release -c | cut -f2`-${{ matrix.llvm }} main" || true sudo apt-get install -y cmake gcc g++ llvm-${{ matrix.llvm }}-dev libomp-${{ matrix.llvm }}-dev lld-${{ matrix.llvm }} clang-${{ matrix.llvm }} libclang-${{ matrix.llvm }}-dev libeigen3-dev libboost-dev libzstd-dev sudo python3 -m pip install --upgrade pip lit @@ -34,7 +35,7 @@ jobs: sudo sed -i 's/add_executable(llvm-omp-device-info IMPORTED)//g' /usr/lib/llvm-${{matrix.llvm}}/lib/cmake/llvm/LLVMExports*.cmake sudo sed -i 's/llvm-omp-device-info//g' /usr/lib/llvm-${{matrix.llvm}}/lib/cmake/llvm/LLVMExports*.cmake fi - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: mkdir run: rm -rf build && mkdir build - name: cmake diff --git a/.github/workflows/doxygen.yml b/.github/workflows/doxygen.yml index 659ed659589a..12e64dc5d8ce 100644 --- a/.github/workflows/doxygen.yml +++ b/.github/workflows/doxygen.yml @@ -9,9 +9,9 @@ jobs: docs: runs-on: ubuntu-20.04 steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - - uses: mattnotmitt/doxygen-action@v1.9.2 + - uses: mattnotmitt/doxygen-action@v1.9.8 with: working-directory: 'enzyme/' doxyfile-path: 'doxygen.cfg' diff --git a/.github/workflows/enzyme-ci.yml b/.github/workflows/enzyme-ci.yml index 4a1ef073160d..5f8e9ea6e524 100644 --- a/.github/workflows/enzyme-ci.yml +++ b/.github/workflows/enzyme-ci.yml @@ -32,7 +32,7 @@ jobs: sudo sed -i 's/add_executable(llvm-omp-device-info IMPORTED)//g' /usr/lib/llvm-${{matrix.llvm}}/lib/cmake/llvm/LLVMExports*.cmake sudo sed -i 's/llvm-omp-device-info//g' /usr/lib/llvm-${{matrix.llvm}}/lib/cmake/llvm/LLVMExports*.cmake fi - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: mkdir run: rm -rf build && mkdir build - name: cmake @@ -61,7 +61,7 @@ jobs: strategy: fail-fast: false matrix: - llvm: ["11", "12", "13", "14", "15"] + llvm: ["12", "13", "14", "15", "16"] build: ["Release", "Debug"] # "RelWithDebInfo" timeout-minutes: 30 @@ -71,7 +71,7 @@ jobs: brew update brew install llvm@${{ matrix.llvm }} make cmake sudo python3 -m pip install --upgrade pip lit requests - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: mkdir run: rm -rf build && mkdir build - name: cmake @@ -109,7 +109,7 @@ jobs: run: | brew install llvm@${{ matrix.llvm }} make cmake gcc sudo python3 -m pip install --upgrade pip lit - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: mkdir run: rm -rf build && mkdir build - name: cmake diff --git a/.github/workflows/enzyme-julia.yml b/.github/workflows/enzyme-julia.yml index fcb6ff659a07..4aabf7204f5b 100644 --- a/.github/workflows/enzyme-julia.yml +++ b/.github/workflows/enzyme-julia.yml @@ -28,8 +28,8 @@ jobs: - x64 timeout-minutes: 60 steps: - - uses: actions/checkout@v3 - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 + - uses: actions/checkout@v4 with: repository: 'wsmoses/Enzyme.jl' path: ./jl diff --git a/.github/workflows/enzyme-mlir.yml b/.github/workflows/enzyme-mlir.yml index b4c581a3a720..34a7828af856 100644 --- a/.github/workflows/enzyme-mlir.yml +++ b/.github/workflows/enzyme-mlir.yml @@ -29,14 +29,14 @@ jobs: sudo apt-get update sudo apt-get install -y binutils ninja-build cmake gcc g++ python3 python3-dev - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: path: 'Enzyme' - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: repository: 'llvm/llvm-project' - ref: 'bc82cfb38d83f1afeb2c290aa472c2e2e88919cb' + ref: '2c9b6c1b36b8185299de083c3058e0c1e7760442' path: 'llvm-project' - name: Get MLIR commit hash @@ -46,7 +46,7 @@ jobs: - name: Cache MLIR id: cache-mlir - uses: actions/cache@v3 + uses: actions/cache@v4 with: path: llvm-project/mlir-build key: ${{ matrix.llbuild }}-${{ matrix.os }}-mlir-${{ steps.mlir-commit.outputs.sha_short }} diff --git a/.github/workflows/format.yml b/.github/workflows/format.yml index 55b2c99749be..c01757d0ff86 100644 --- a/.github/workflows/format.yml +++ b/.github/workflows/format.yml @@ -11,7 +11,7 @@ jobs: runs-on: ubuntu-20.04 steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - uses: DoozyX/clang-format-lint-action@v0.16.2 with: source: 'enzyme/Enzyme enzyme/tools/enzyme-tblgen' diff --git a/.github/workflows/fortran.yml b/.github/workflows/fortran.yml index a31294c90b4f..e5e960345f63 100644 --- a/.github/workflows/fortran.yml +++ b/.github/workflows/fortran.yml @@ -40,7 +40,7 @@ jobs: sudo apt-get update && sudo apt-get install -y intel-oneapi-compiler-fortran-${{ matrix.ifx }} intel-oneapi-mpi-${{ matrix.mpi }} intel-oneapi-mpi-devel-${{ matrix.mpi }} source /opt/intel/oneapi/setvars.sh printenv >> $GITHUB_ENV - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: generate build system run: | rm -rf build && mkdir build && cd build diff --git a/.github/workflows/tagger.yml b/.github/workflows/tagger.yml index 9d712fb0cd6c..46f8874e92e3 100644 --- a/.github/workflows/tagger.yml +++ b/.github/workflows/tagger.yml @@ -10,19 +10,19 @@ jobs: name: Enzyme Tag CI runs-on: ubuntu-latest steps: - - uses: tibdex/github-app-token@v1 + - uses: actions/create-github-app-token@v1 id: generate_token with: app_id: ${{ secrets.APP_ID }} private_key: ${{ secrets.APP_PRIVATE_KEY }} - repository: JuliaPackaging/Yggdrasil + repositories: JuliaPackaging/Yggdrasil - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: repository: 'JuliaPackaging/Yggdrasil' path: ygg - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: path: enz - name: replace @@ -37,7 +37,7 @@ jobs: git add . - name: Create Pull Request id: cpr - uses: peter-evans/create-pull-request@v3 + uses: peter-evans/create-pull-request@v6 with: path: ygg commit-message: "Upgrade enzyme to ${{ github.ref }}" diff --git a/.gitignore b/.gitignore index aad84254cf78..33a6c3f8cdd8 100644 --- a/.gitignore +++ b/.gitignore @@ -9,3 +9,4 @@ enzyme/benchmarks/ReverseMode/*/*.o enzyme/benchmarks/ReverseMode/*/*.exe enzyme/benchmarks/ReverseMode/*/results.txt enzyme/benchmarks/ReverseMode/*/results.json +.cache diff --git a/.packaging/build_tarballs.jl b/.packaging/build_tarballs.jl index 67d214bf2c86..0b41f026ba38 100644 --- a/.packaging/build_tarballs.jl +++ b/.packaging/build_tarballs.jl @@ -28,7 +28,7 @@ platforms = expand_cxxstring_abis(supported_platforms(; experimental=true)) script = raw""" cd Enzyme -if [[ "${bb_full_target}" == x86_64-apple-darwin*llvm_version+15.asserts* ]]; then +if [[ "${bb_full_target}" == x86_64-apple-darwin*llvm_version+15* ]] || [[ "${bb_full_target}" == x86_64-apple-darwin*llvm_version+16* ]]; then # LLVM 15 requires macOS SDK 10.14. pushd $WORKSPACE/srcdir/MacOSX10.*.sdk rm -rf /opt/${target}/${target}/sys-root/System @@ -54,11 +54,11 @@ cmake -B build-native -S enzyme -GNinja "${NATIVE_CMAKE_FLAGS[@]}" # Only build blasheaders and tblgen ninja -C build-native -j ${nproc} blasheaders enzyme-tblgen - # 2. Cross-compile CMAKE_FLAGS=() CMAKE_FLAGS+=(-DENZYME_EXTERNAL_SHARED_LIB=ON) CMAKE_FLAGS+=(-DBC_LOAD_HEADER=`pwd`/build-native/BCLoad/gsl/blas_headers.h) +CMAKE_FLAGS+=(-DEnzyme_TABLEGEN=`pwd`/build-native/tools/enzyme-tblgen/enzyme-tblgen) CMAKE_FLAGS+=(-DEnzyme_TABLEGEN_EXE=`pwd`/build-native/tools/enzyme-tblgen/enzyme-tblgen) CMAKE_FLAGS+=(-DENZYME_CLANG=OFF) # RelWithDebInfo for decent performance, with debugability @@ -66,7 +66,11 @@ CMAKE_FLAGS+=(-DCMAKE_BUILD_TYPE=RelWithDebInfo) # Install things into $prefix CMAKE_FLAGS+=(-DCMAKE_INSTALL_PREFIX=${prefix}) # Explicitly use our cmake toolchain file and tell CMake we're cross-compiling -CMAKE_FLAGS+=(-DCMAKE_TOOLCHAIN_FILE=${CMAKE_TARGET_TOOLCHAIN}) +if [[ "${target}" == *mingw* && "${bb_full_target}" == *llvm_version+16* ]]; then + CMAKE_FLAGS+=(-DCMAKE_TOOLCHAIN_FILE=${CMAKE_TARGET_TOOLCHAIN%.*}_clang.cmake) +else + CMAKE_FLAGS+=(-DCMAKE_TOOLCHAIN_FILE=${CMAKE_TARGET_TOOLCHAIN}) +fi CMAKE_FLAGS+=(-DCMAKE_CROSSCOMPILING:BOOL=ON) # Tell CMake where LLVM is CMAKE_FLAGS+=(-DLLVM_DIR="${prefix}/lib/cmake/llvm") @@ -74,10 +78,18 @@ CMAKE_FLAGS+=(-DLLVM_DIR="${prefix}/lib/cmake/llvm") CMAKE_FLAGS+=(-DLLVM_LINK_LLVM_DYLIB=ON) # Build the library CMAKE_FLAGS+=(-DBUILD_SHARED_LIBS=ON) + +if [[ "${bb_full_target}" == x86_64-apple-darwin*llvm_version+15* ]] || [[ "${bb_full_target}" == x86_64-apple-darwin*llvm_version+16* ]]; then +if [[ "${target}" == x86_64-apple* ]]; then + CMAKE_FLAGS+=(-DCMAKE_OSX_DEPLOYMENT_TARGET:STRING=10.14) +fi +else if [[ "${target}" == x86_64-apple* ]]; then CMAKE_FLAGS+=(-DCMAKE_OSX_DEPLOYMENT_TARGET:STRING=10.12) fi +fi +echo ${CMAKE_FLAGS[@]} cmake -B build -S enzyme -GNinja ${CMAKE_FLAGS[@]} ninja -C build -j ${nproc} install @@ -113,11 +125,11 @@ for llvm_version in llvm_versions, llvm_assertions in (false, true) # We don't build LLVM 15 for i686-linux-musl. filter!(p -> !(arch(p) == "i686" && libc(p) == "musl"), platforms) end - + for platform in platforms augmented_platform = deepcopy(platform) augmented_platform[LLVM.platform_name] = LLVM.platform(llvm_version, llvm_assertions) - gcc_version = version > v"15" ? v"10" : v"8" + gcc_version = llvm_version > v"15" ? v"10" : v"8" should_build_platform(triplet(augmented_platform)) || continue push!(builds, (; dependencies, products, diff --git a/enzyme/BCLoad/BCLoader.cpp b/enzyme/BCLoad/BCLoader.cpp index 6aad07eb6739..286bfc2e421e 100644 --- a/enzyme/BCLoad/BCLoader.cpp +++ b/enzyme/BCLoad/BCLoader.cpp @@ -24,28 +24,30 @@ static inline bool endsWith(llvm::StringRef string, llvm::StringRef suffix) { #endif // LLVM_VERSION_MAJOR } -bool provideDefinitions(Module &M, std::set ignoreFunctions = {}) { +bool provideDefinitions(Module &M, std::set ignoreFunctions, + std::vector &replaced) { std::vector todo; bool seen32 = false; bool seen64 = false; for (auto &F : M) { if (!F.empty()) continue; + if (ignoreFunctions.count(F.getName().str())) + continue; int index = 0; for (auto postfix : {"", "_", "_64_"}) { std::string str; if (strlen(postfix) == 0) { str = F.getName().str(); - if (ignoreFunctions.count(str)) continue; } else if (endsWith(F.getName(), postfix)) { auto blasName = F.getName().substr(0, F.getName().size() - strlen(postfix)).str(); - if (ignoreFunctions.count(blasName)) continue; str = "cblas_" + blasName; } auto found = EnzymeBlasBC.find(str); if (found != EnzymeBlasBC.end()) { + replaced.push_back(F.getName().str()); todo.push_back(found->second); if (index == 1) seen32 = true; @@ -81,13 +83,24 @@ bool provideDefinitions(Module &M, std::set ignoreFunctions = {}) { }); #endif - if (!BC) + if (!BC) { Err.print("bcloader", llvm::errs()); + continue; + } assert(BC); SmallVector toReplace; for (auto &F : *BC) { if (F.empty()) continue; + if (ignoreFunctions.count(F.getName().str())) { + F.dropAllReferences(); +#if LLVM_VERSION_MAJOR >= 16 + F.erase(F.begin(), F.end()); +#else + F.getBasicBlockList().erase(F.begin(), F.end()); +#endif + continue; + } toReplace.push_back(F.getName().str()); } BC->setTargetTriple(""); @@ -106,12 +119,29 @@ bool provideDefinitions(Module &M, std::set ignoreFunctions = {}) { extern "C" { uint8_t EnzymeBitcodeReplacement(LLVMModuleRef M, char **FncsNamesToIgnore, - size_t numFncNames) { + size_t numFncNames, const char ***foundP, + size_t *foundLen) { std::set ignoreFunctions = {}; for (size_t i = 0; i < numFncNames; i++) { ignoreFunctions.insert(std::string(FncsNamesToIgnore[i])); } - return provideDefinitions(*unwrap(M), ignoreFunctions); + std::vector replaced; + auto res = provideDefinitions(*unwrap(M), ignoreFunctions, replaced); + + const char **found = nullptr; + if (replaced.size()) { + found = (const char **)malloc(replaced.size() * sizeof(const char **)); + for (size_t i = 0; i < replaced.size(); i++) { + char *data = (char *)malloc(replaced[i].size() + 1); + memcpy(data, replaced[i].data(), replaced[i].size()); + data[replaced[i].size()] = 0; + found[i] = data; + } + } + *foundP = found; + *foundLen = replaced.size(); + + return res; } } @@ -121,7 +151,10 @@ class BCLoader final : public ModulePass { static char ID; BCLoader() : ModulePass(ID) {} - bool runOnModule(Module &M) override { return provideDefinitions(M, {}); } + bool runOnModule(Module &M) override { + std::vector replaced; + return provideDefinitions(M, {}, replaced); + } }; } // namespace diff --git a/enzyme/BCLoad/CMakeLists.txt b/enzyme/BCLoad/CMakeLists.txt index f2b04238fb67..f88f515ec77f 100644 --- a/enzyme/BCLoad/CMakeLists.txt +++ b/enzyme/BCLoad/CMakeLists.txt @@ -4,10 +4,10 @@ set(CMAKE_CXX_STANDARD_REQUIRED ON) set(BC_LOAD_FLAGS "" CACHE STRING "") set(BC_LOAD_HEADER "" CACHE STRING "") -if (${LLVM_VERSION_MAJOR} LESS 15) - set(BC_LOAD_FLAGS2 "${BC_LOAD_FLAGS}") -else() +if (${LLVM_VERSION_MAJOR} EQUAL 15 OR ${LLVM_VERSION_MAJOR} EQUAL 16) set(BC_LOAD_FLAGS2 "${BC_LOAD_FLAGS} -Xclang -no-opaque-pointers") +else() + set(BC_LOAD_FLAGS2 "${BC_LOAD_FLAGS}") endif() if (APPLE) diff --git a/enzyme/BUILD b/enzyme/BUILD index c9ae1c3cdb71..bfb41ed3f9ff 100644 --- a/enzyme/BUILD +++ b/enzyme/BUILD @@ -1,13 +1,11 @@ load("@llvm-project//llvm:tblgen.bzl", "gentbl") load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") -load("@llvm-project//llvm:lit_test.bzl", "lit_test", "package_path") -load("@bazel_skylib//rules:expand_template.bzl", "expand_template") licenses(["notice"]) package( default_applicable_licenses = [], - default_visibility = ["//:__subpackages__"], + default_visibility = ["//visibility:public"], ) cc_library( @@ -29,6 +27,7 @@ cc_binary( "@llvm-project//llvm:TableGen", "@llvm-project//llvm:config", ], + visibility = ["//visibility:public"], ) gentbl( @@ -143,13 +142,30 @@ gentbl( ], ) +gentbl( + name = "include-utils", + tbl_outs = [( + "-gen-header-strings", + "IncludeUtils.inc", + )], + tblgen = ":enzyme-tblgen", + td_file = "Enzyme/Clang/include_utils.td", + td_srcs = ["Enzyme/Clang/include_utils.td"], + deps = [ + ":enzyme-tblgen", + ], +) + cc_library( name = "EnzymeStatic", - srcs = glob([ - "Enzyme/*.cpp", - "Enzyme/TypeAnalysis/*.cpp", - "Enzyme/Clang/EnzymeClang.cpp", - ], exclude=["Enzyme/eopt.cpp"]), + srcs = glob( + [ + "Enzyme/*.cpp", + "Enzyme/TypeAnalysis/*.cpp", + "Enzyme/Clang/EnzymeClang.cpp", + ], + exclude = ["Enzyme/eopt.cpp"], + ), hdrs = glob([ "Enzyme/*.h", "Enzyme/TypeAnalysis/*.h", @@ -159,10 +175,13 @@ cc_library( "-DENZYME_VERSION_MAJOR=0", "-DENZYME_VERSION_MINOR=0", "-DENZYME_VERSION_PATCH=79", + "-Wno-unused-variable", + "-Wno-return-type", ], data = ["@llvm-project//clang:builtin_headers_gen"], visibility = ["//visibility:public"], deps = [ + "include-utils", ":binop-derivatives", ":blas-attributor", ":blas-derivatives", @@ -194,7 +213,7 @@ cc_library( "@llvm-project//llvm:TransformUtils", "@llvm-project//llvm:config", ], - alwayslink = 1 + alwayslink = 1, ) cc_binary( @@ -223,6 +242,8 @@ cc_binary( srcs = ["Enzyme/eopt.cpp"], deps = [ ":EnzymeStatic", + "@llvm-project//llvm:Passes", + "@llvm-project//llvm:Support", "@llvm-project//llvm:opt-driver", ], ) @@ -230,16 +251,16 @@ cc_binary( td_library( name = "EnzymeDialectTdFiles", srcs = [ - "Enzyme/MLIR/Dialect/Dialect.td", + "Enzyme/MLIR/Dialect/Dialect.td", ], - deps = [ + deps = [ + "@llvm-project//mlir:ControlFlowInterfacesTdFiles", + "@llvm-project//mlir:FunctionInterfacesTdFiles", + "@llvm-project//mlir:LoopLikeInterfaceTdFiles", "@llvm-project//mlir:OpBaseTdFiles", "@llvm-project//mlir:SideEffectInterfacesTdFiles", "@llvm-project//mlir:ViewLikeInterfaceTdFiles", - "@llvm-project//mlir:FunctionInterfacesTdFiles", - "@llvm-project//mlir:ControlFlowInterfacesTdFiles", - "@llvm-project//mlir:LoopLikeInterfaceTdFiles", - ] + ], ) gentbl_cc_library( @@ -277,9 +298,9 @@ td_library( name = "EnzymePassesTdFiles", srcs = [ ], - deps = [ + deps = [ "@llvm-project//mlir:PassBaseTdFiles", - ] + ], ) gentbl_cc_library( @@ -349,7 +370,6 @@ gentbl_cc_library( deps = [":EnzymeDialectTdFiles"], ) - gentbl_cc_library( name = "EnzymeTypeInterfacesIncGen", tbl_outs = [ @@ -384,6 +404,30 @@ gentbl_cc_library( deps = [":EnzymeDialectTdFiles"], ) +td_library( + name = "ImplementationsCommonTdFiles", + srcs = [ + "Enzyme/MLIR/Implementations/Common.td", + ], +) + +gentbl( + name = "affine-derivatives", + tbl_outs = [( + "-gen-mlir-derivatives", + "Enzyme/MLIR/Implementations/AffineDerivatives.inc", + )], + tblgen = ":enzyme-tblgen", + td_file = "Enzyme/MLIR/Implementations/AffineDerivatives.td", + td_srcs = [ + "Enzyme/MLIR/Implementations/AffineDerivatives.td", + "Enzyme/MLIR/Implementations/Common.td", + ], + deps = [ + ":enzyme-tblgen", + ], +) + gentbl( name = "arith-derivatives", tbl_outs = [( @@ -392,7 +436,129 @@ gentbl( )], tblgen = ":enzyme-tblgen", td_file = "Enzyme/MLIR/Implementations/ArithDerivatives.td", - td_srcs = ["Enzyme/MLIR/Implementations/ArithDerivatives.td"], + td_srcs = [ + "Enzyme/MLIR/Implementations/ArithDerivatives.td", + "Enzyme/MLIR/Implementations/Common.td", + ], + deps = [ + ":enzyme-tblgen", + ], +) + +gentbl( + name = "llvm-derivatives", + tbl_outs = [( + "-gen-mlir-derivatives", + "Enzyme/MLIR/Implementations/LLVMDerivatives.inc", + )], + tblgen = ":enzyme-tblgen", + td_file = "Enzyme/MLIR/Implementations/LLVMDerivatives.td", + td_srcs = [ + "Enzyme/MLIR/Implementations/LLVMDerivatives.td", + "Enzyme/MLIR/Implementations/Common.td", + ], + deps = [ + ":enzyme-tblgen", + ], +) + +gentbl( + name = "nvvm-derivatives", + tbl_outs = [( + "-gen-mlir-derivatives", + "Enzyme/MLIR/Implementations/NVVMDerivatives.inc", + )], + tblgen = ":enzyme-tblgen", + td_file = "Enzyme/MLIR/Implementations/NVVMDerivatives.td", + td_srcs = [ + "Enzyme/MLIR/Implementations/NVVMDerivatives.td", + "Enzyme/MLIR/Implementations/Common.td", + ], + deps = [ + ":enzyme-tblgen", + ], +) + +gentbl( + name = "scf-derivatives", + tbl_outs = [( + "-gen-mlir-derivatives", + "Enzyme/MLIR/Implementations/SCFDerivatives.inc", + )], + tblgen = ":enzyme-tblgen", + td_file = "Enzyme/MLIR/Implementations/SCFDerivatives.td", + td_srcs = [ + "Enzyme/MLIR/Implementations/SCFDerivatives.td", + "Enzyme/MLIR/Implementations/Common.td", + ], + deps = [ + ":enzyme-tblgen", + ], +) + +gentbl( + name = "cf-derivatives", + tbl_outs = [( + "-gen-mlir-derivatives", + "Enzyme/MLIR/Implementations/CFDerivatives.inc", + )], + tblgen = ":enzyme-tblgen", + td_file = "Enzyme/MLIR/Implementations/CFDerivatives.td", + td_srcs = [ + "Enzyme/MLIR/Implementations/CFDerivatives.td", + "Enzyme/MLIR/Implementations/Common.td", + ], + deps = [ + ":enzyme-tblgen", + ], +) + +gentbl( + name = "memref-derivatives", + tbl_outs = [( + "-gen-mlir-derivatives", + "Enzyme/MLIR/Implementations/MemRefDerivatives.inc", + )], + tblgen = ":enzyme-tblgen", + td_file = "Enzyme/MLIR/Implementations/MemRefDerivatives.td", + td_srcs = [ + "Enzyme/MLIR/Implementations/MemRefDerivatives.td", + "Enzyme/MLIR/Implementations/Common.td", + ], + deps = [ + ":enzyme-tblgen", + ], +) + +gentbl( + name = "math-derivatives", + tbl_outs = [( + "-gen-mlir-derivatives", + "Enzyme/MLIR/Implementations/MathDerivatives.inc", + )], + tblgen = ":enzyme-tblgen", + td_file = "Enzyme/MLIR/Implementations/MathDerivatives.td", + td_srcs = [ + "Enzyme/MLIR/Implementations/MathDerivatives.td", + "Enzyme/MLIR/Implementations/Common.td", + ], + deps = [ + ":enzyme-tblgen", + ], +) + +gentbl( + name = "func-derivatives", + tbl_outs = [( + "-gen-mlir-derivatives", + "Enzyme/MLIR/Implementations/FuncDerivatives.inc", + )], + tblgen = ":enzyme-tblgen", + td_file = "Enzyme/MLIR/Implementations/FuncDerivatives.td", + td_srcs = [ + "Enzyme/MLIR/Implementations/FuncDerivatives.td", + "Enzyme/MLIR/Implementations/Common.td", + ], deps = [ ":enzyme-tblgen", ], @@ -414,83 +580,98 @@ cc_library( "Enzyme/MLIR/Analysis/*.h", "Enzyme/MLIR/Implementations/*.h", "Enzyme/Utils.h", - "Enzyme/TypeAnalysis/*.h" + "Enzyme/TypeAnalysis/*.h", ]), - includes = ["Enzyme/MLIR", "Enzyme"], + includes = [ + "Enzyme", + "Enzyme/MLIR", + ], visibility = ["//visibility:public"], deps = [ - ":arith-derivatives", + ":EnzymeAttributesIncGen", + ":EnzymeEnumsIncGen", + ":EnzymeOpInterfacesIncGen", ":EnzymeOpsIncGen", ":EnzymePassesIncGen", - ":EnzymeTypesIncGen", - ":EnzymeEnumsIncGen", - ":EnzymeAttributesIncGen", - ":EnzymeTypeInterfacesIncGen", - ":EnzymeOpInterfacesIncGen", + ":EnzymeTypeInterfacesIncGen", + ":EnzymeTypesIncGen", + ":affine-derivatives", + ":arith-derivatives", + ":cf-derivatives", + ":llvm-derivatives", + ":func-derivatives", + ":math-derivatives", + ":memref-derivatives", + ":nvvm-derivatives", + ":scf-derivatives", + "@llvm-project//llvm:Analysis", + "@llvm-project//llvm:Core", + "@llvm-project//llvm:Demangle", + "@llvm-project//llvm:Passes", + "@llvm-project//llvm:Support", + "@llvm-project//llvm:TransformUtils", + "@llvm-project//llvm:config", "@llvm-project//mlir:AffineDialect", - "@llvm-project//mlir:LLVMCommonConversion", - "@llvm-project//mlir:ConversionPasses", - "@llvm-project//mlir:SCFDialect", - "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:Analysis", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:ArithUtils", "@llvm-project//mlir:AsyncDialect", + "@llvm-project//mlir:CastInterfaces", + "@llvm-project//mlir:ControlFlowDialect", + "@llvm-project//mlir:ControlFlowInterfaces", + "@llvm-project//mlir:ConversionPasses", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:FuncExtensions", + "@llvm-project//mlir:FunctionInterfaces", + "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:LLVMCommonConversion", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:LinalgDialect", + "@llvm-project//mlir:LinalgStructuredOpsIncGen", + "@llvm-project//mlir:LinalgTransforms", "@llvm-project//mlir:MathDialect", "@llvm-project//mlir:MemRefDialect", + "@llvm-project//mlir:NVVMDialect", + "@llvm-project//mlir:OpenMPDialect", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Rewrite", + "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:ShapeDialect", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:Transforms", + "@llvm-project//mlir:ViewLikeInterface", ], ) cc_binary( name = "enzymemlir-opt", srcs = ["Enzyme/MLIR/enzymemlir-opt.cpp"], - visibility = ["//visibility:public"], includes = ["Enzyme/MLIR"], + visibility = ["//visibility:public"], deps = [ ":EnzymeMLIR", - "@llvm-project//mlir:MlirOptLib", + "@llvm-project//mlir:AffineDialect", "@llvm-project//mlir:AllPassesAndDialects", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:AsyncDialect", + "@llvm-project//mlir:ControlFlowDialect", + "@llvm-project//mlir:ConversionPasses", + "@llvm-project//mlir:DLTIDialect", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:LinalgDialect", + "@llvm-project//mlir:MathDialect", + "@llvm-project//mlir:MemRefDialect", + "@llvm-project//mlir:MlirOptLib", + "@llvm-project//mlir:NVVMDialect", + "@llvm-project//mlir:OpenMPDialect", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:Transforms", ], ) -# Generates lit config input file by applying path placeholder substitutions -# similar to the configure_lit_site_cfg CMake macro. -expand_template( - name = "lit_site_cfg_py", - testonly = True, - out = "test/lit.site.cfg.py", - substitutions = { - "@LLVM_VERSION_MAJOR@": "18", - "@LIT_SITE_CFG_IN_HEADER@": "# Autogenerated, do not edit.", - "@LLVM_BINARY_DIR@": package_path("@llvm-project//llvm:BUILD"), - "@LLVM_TOOLS_BINARY_DIR@": package_path("@llvm-project//llvm:BUILD"), - "@LLVM_LIBS_DIR@": package_path("@llvm-project//llvm:BUILD"), - "@ENZYME_SOURCE_DIR@": "", - "@ENZYME_BINARY_DIR@": "", - "@TARGET_TRIPLE@": "", - "@TARGETS_TO_BUILD@": "ALL", - "@LLVM_SHLIBEXT@": ".so", - }, - template = "test/lit.site.cfg.py.in", - visibility = ["//visibility:private"], -) - -[ - lit_test( - name = "%s.test" % src, - srcs = [src], - data = [ - ":test/lit.cfg.py", - ":test/lit.site.cfg.py", - "@llvm-project//llvm:FileCheck", - "@llvm-project//llvm:count", - "@llvm-project//llvm:not", - "@llvm-project//llvm:lli", - ":enzyme-opt", - "@llvm-project//clang:builtin_headers_gen", - ":enzyme-clang", - ":enzyme-clang++", - ":enzymemlir-opt" - ] + glob(["test/**/*.h"]) - ) - for src in glob(["test/**/*.mlir", "test/Integration/**/*.c", "test/Integration/**/.cpp"], exclude=["test/**/*omp*.c"]) -] +exports_files(["run_lit.sh"]) + diff --git a/enzyme/CMakeLists.txt b/enzyme/CMakeLists.txt index 13a9b4973c54..2ab67dd3d2fb 100644 --- a/enzyme/CMakeLists.txt +++ b/enzyme/CMakeLists.txt @@ -2,6 +2,8 @@ cmake_minimum_required(VERSION 3.13) project(Enzyme) include(CMakePackageConfigHelpers) +include(CheckIncludeFile) +include(CheckIncludeFileCXX) set(ENZYME_MAJOR_VERSION 0) set(ENZYME_MINOR_VERSION 0) @@ -13,8 +15,9 @@ add_definitions(-DENZYME_VERSION_MAJOR=${ENZYME_MAJOR_VERSION}) add_definitions(-DENZYME_VERSION_MINOR=${ENZYME_MINOR_VERSION}) add_definitions(-DENZYME_VERSION_PATCH=${ENZYME_PATCH_VERSION}) +set(CMAKE_EXPORT_COMPILE_COMMANDS ON) set(CMAKE_POSITION_INDEPENDENT_CODE ON) -SET(CMAKE_CXX_FLAGS "-Wall -fno-rtti ${CMAKE_CXX_FLAGS} -Werror=unused-variable -Werror=dangling-else") +SET(CMAKE_CXX_FLAGS "-Wall -fno-rtti ${CMAKE_CXX_FLAGS} -Werror=unused-variable -Werror=dangling-else -Werror=unused-but-set-variable -Werror=return-type -Werror=nonnull") SET(CMAKE_CXX_FLAGS_RELWITHDEBINFO "-O2 -g -ggdb") SET(CMAKE_CXX_FLAGS_RELEASE "-O2") @@ -264,6 +267,11 @@ string(REPLACE "};\n}" "};\n}}" INPUT_TEXT "${INPUT_TEXT}") string(REPLACE "const SCEV* S;\n};\n" "const SCEV* S;\n};\n}\n" INPUT_TEXT "${INPUT_TEXT}") endif() +find_library(MPFR_LIB_PATH mpfr) +CHECK_INCLUDE_FILE("mpfr.h" HAS_MPFR_H) +message("MPFR lib: " ${MPFR_LIB_PATH}) +message("MPFR header: " ${HAS_MPFR_H}) + file(WRITE "${CMAKE_CURRENT_BINARY_DIR}/include/SCEV/ScalarEvolutionExpander.h" "${INPUT_TEXT}") include_directories("${CMAKE_CURRENT_BINARY_DIR}/include") diff --git a/enzyme/Enzyme/ActivityAnalysis.cpp b/enzyme/Enzyme/ActivityAnalysis.cpp index b3eb29626d8e..3c5a4ae7e964 100644 --- a/enzyme/Enzyme/ActivityAnalysis.cpp +++ b/enzyme/Enzyme/ActivityAnalysis.cpp @@ -104,20 +104,7 @@ cl::opt EnzymeEnableRecursiveHypotheses( #include // clang-format off -const char *KnownInactiveFunctionsStartingWith[] = { - "f90io", - "$ss5print", - "_ZTv0_n24_NSoD", //"1Ev, 0Ev - "_ZNSt16allocator_traitsISaIdEE10deallocate", - "_ZNSaIcED1Ev", - "_ZNSaIcEC1Ev", -}; - -const char *KnownInactiveFunctionsContains[] = { - "__enzyme_float", "__enzyme_double", "__enzyme_integer", - "__enzyme_pointer"}; - -const StringSet<> InactiveGlobals = { +static const StringSet<> InactiveGlobals = { "small_typeof", "ompi_request_null", "ompi_mpi_double", @@ -128,9 +115,11 @@ const StringSet<> InactiveGlobals = { "_ZSt3cin", "_ZSt4cout", "_ZNSt3__14coutE", + "_ZNSt3__15wcoutE", "_ZNSt3__113basic_ostreamIcNS_11char_traitsIcEEE6sentryC1ERS3_", "_ZSt5wcout", "_ZSt4cerr", + "_ZNSt3__14cerrE", "_ZTVNSt7__cxx1115basic_stringbufIcSt11char_traitsIcESaIcEEE", "_ZTVSt15basic_streambufIcSt11char_traitsIcEE", "_ZTVSt9basic_iosIcSt11char_traitsIcEE", @@ -168,17 +157,24 @@ const llvm::StringMap MPIInactiveCommAllocators = { {"MPI_Comm_idup", 1}, {"MPI_Comm_join", 1}, }; +// clang-format on -// Instructions which themselves are inactive -// the returned value, however, may still be active -const StringSet<> KnownInactiveFunctionInsts = { - "__dynamic_cast", - "_ZSt18_Rb_tree_decrementPKSt18_Rb_tree_node_base", - "_ZSt18_Rb_tree_incrementPKSt18_Rb_tree_node_base", - "_ZSt18_Rb_tree_decrementPSt18_Rb_tree_node_base", - "_ZSt18_Rb_tree_incrementPSt18_Rb_tree_node_base", - "jl_ptr_to_array", - "jl_ptr_to_array_1d"}; +/// Return whether the call is always inactive by definition. +bool isInactiveCall(CallBase &CI) { + + // clang-format off +const char *KnownInactiveFunctionsStartingWith[] = { + "f90io", + "$ss5print", + "_ZTv0_n24_NSoD", //"1Ev, 0Ev + "_ZNSt16allocator_traitsISaIdEE10deallocate", + "_ZNSaIcED1Ev", + "_ZNSaIcEC1Ev", +}; + +const char *KnownInactiveFunctionsContains[] = { + "__enzyme_float", "__enzyme_double", "__enzyme_integer", + "__enzyme_pointer"}; const StringSet<> KnownInactiveFunctions = { "mpfr_greater_p", @@ -290,13 +286,17 @@ const StringSet<> KnownInactiveFunctions = { "cuDevicePrimaryCtxRetain", "floor", "floorf", - "floorl" + "floorl", + "\01_fopen", + "fopen", + "fclose", }; const std::set KnownInactiveIntrinsics = { #if LLVM_VERSION_MAJOR >= 12 Intrinsic::experimental_noalias_scope_decl, #endif + Intrinsic::objectsize, Intrinsic::floor, Intrinsic::ceil, Intrinsic::trunc, @@ -395,6 +395,8 @@ const char *DemangledKnownInactiveFunctionsStartingWith[] = { // libc++ + "std::__1::locale", + "std::__1::ios_base", "std::__1::basic_string", "std::__1::__do_string_hash", "std::__1::hash", @@ -411,17 +413,125 @@ const char *DemangledKnownInactiveFunctionsStartingWith[] = { "std::__1::shuffle_order_engine", "std::__1::basic_streambuf", "std::__1::basic_stringbuf", + "std::__1::basic_istream", + "std::__1::basic_filebuf", + "std::__1::basic_iostream", + "std::__1::basic_ios", + "virtual thunk to std::__1::basic_istream", + "virtual thunk to std::__1::basic_ostream", "std::__detail::_Prime_rehash_policy", "std::__detail::_Hash_code_base", + }; -// clang-format on + // clang-format on + + if (CI.hasFnAttr("enzyme_inactive")) + return true; + + if (auto iasm = dyn_cast(CI.getCalledOperand())) { + if (StringRef(iasm->getAsmString()).contains("exit") || + StringRef(iasm->getAsmString()).contains("cpuid")) + return true; + } + + if (auto F = getFunctionFromCall(&CI)) { + if (F->hasFnAttribute("enzyme_inactive")) { + return true; + } + if (KnownInactiveIntrinsics.count(F->getIntrinsicID())) { + return true; + } + } + + auto Name = getFuncNameFromCall(&CI); + + std::string demangledName = llvm::demangle(Name.str()); + auto dName = StringRef(demangledName); + for (auto FuncName : DemangledKnownInactiveFunctionsStartingWith) { + if (startsWith(dName, FuncName)) { + return true; + } + } + + for (auto FuncName : KnownInactiveFunctionsStartingWith) { + if (startsWith(Name, FuncName)) { + return true; + } + } + + for (auto FuncName : KnownInactiveFunctionsContains) { + if (Name.contains(FuncName)) { + return true; + } + } + if (KnownInactiveFunctions.count(Name)) { + return true; + } + + if (MPIInactiveCommAllocators.find(Name) != MPIInactiveCommAllocators.end()) { + return true; + } + Intrinsic::ID ID; + if (isMemFreeLibMFunction(Name, &ID)) + if (KnownInactiveIntrinsics.count(ID)) { + return true; + } + + // Copies of size 1 are inactive [cannot move differentiable data in one byte] + if (auto MTI = dyn_cast(&CI)) { + if (auto sz = dyn_cast(MTI->getOperand(2))) { + if (sz->getValue() == 1) + return true; + } + } + + return false; +} + +bool isInactiveCallInst(CallBase &CB, llvm::TargetLibraryInfo &TLI) { + // clang-format off +// Instructions which themselves are inactive +// the returned value, however, may still be active +static const StringSet<> KnownInactiveFunctionInsts = { + "__dynamic_cast", + "_ZSt18_Rb_tree_decrementPKSt18_Rb_tree_node_base", + "_ZSt18_Rb_tree_incrementPKSt18_Rb_tree_node_base", + "_ZSt18_Rb_tree_decrementPSt18_Rb_tree_node_base", + "_ZSt18_Rb_tree_incrementPSt18_Rb_tree_node_base", + "jl_ptr_to_array", + "jl_ptr_to_array_1d"}; + // clang-format on + if (isInactiveCall(CB)) + return true; + if (CB.hasFnAttr("enzyme_inactive_inst")) { + return true; + } + auto called = getFunctionFromCall(&CB); + + if (called) { + if (called->hasFnAttribute("enzyme_inactive_inst")) { + return true; + } + } + + auto funcName = getFuncNameFromCall(&CB); + if (KnownInactiveFunctionInsts.count(funcName)) + return true; + + if (isAllocationFunction(funcName, TLI) || + isDeallocationFunction(funcName, TLI)) { + return true; + } + + return false; +} /// Is the use of value val as an argument of call CI known to be inactive /// This tool can only be used when in DOWN mode bool ActivityAnalyzer::isFunctionArgumentConstant(CallInst *CI, Value *val) { assert(directions & DOWN); - if (CI->hasFnAttr("enzyme_inactive")) + if (isInactiveCall(*CI)) return true; auto F = getFunctionFromCall(CI); @@ -451,10 +561,6 @@ bool ActivityAnalyzer::isFunctionArgumentConstant(CallInst *CI, Value *val) { if (F == nullptr) return false; - if (F->hasFnAttribute("enzyme_inactive")) { - return true; - } - auto Name = getFuncNameFromCall(CI); // Only the 1-th arg impacts activity @@ -466,43 +572,6 @@ bool ActivityAnalyzer::isFunctionArgumentConstant(CallInst *CI, Value *val) { if (isAllocationFunction(Name, TLI) || isDeallocationFunction(Name, TLI)) return true; - std::string demangledName = llvm::demangle(Name.str()); - auto dName = StringRef(demangledName); - for (auto FuncName : DemangledKnownInactiveFunctionsStartingWith) { - if (startsWith(dName, FuncName)) { - return true; - } - } - if (demangledName == Name.str()) { - // Either demangeling failed - // or they are equal but matching failed - // if (!startsWith(Name, "llvm.")) - // llvm::errs() << "matching failed: " << Name.str() << " " - // << demangledName << "\n"; - } - - for (auto FuncName : KnownInactiveFunctionsStartingWith) { - if (startsWith(Name, FuncName)) { - return true; - } - } - - for (auto FuncName : KnownInactiveFunctionsContains) { - if (Name.contains(FuncName)) { - return true; - } - } - if (KnownInactiveFunctions.count(Name)) { - return true; - } - - if (MPIInactiveCommAllocators.find(Name) != MPIInactiveCommAllocators.end()) { - return true; - } - if (KnownInactiveIntrinsics.count(F->getIntrinsicID())) { - return true; - } - /// Only the first argument (magnitude) of copysign is active if (F->getIntrinsicID() == Intrinsic::copysign && CI->getArgOperand(0) != val) { @@ -549,6 +618,8 @@ bool ActivityAnalyzer::isFunctionArgumentConstant(CallInst *CI, Value *val) { static inline void propagateArgumentInformation( TargetLibraryInfo &TLI, CallInst &CI, llvm::function_ref propagateFromOperand) { + if (isInactiveCall(CI)) + return; // These functions are known to only have the first argument impact // the activity of the call instruction @@ -596,12 +667,6 @@ static inline void propagateArgumentInformation( return; } - // Certain intrinsics are inactive by definition - // and have nothing to propagate. - if (KnownInactiveIntrinsics.count(F->getIntrinsicID())) { - return; - } - if (F->getIntrinsicID() == Intrinsic::memcpy || F->getIntrinsicID() == Intrinsic::memmove) { propagateFromOperand(CI.getOperand(0)); @@ -731,13 +796,6 @@ bool ActivityAnalyzer::isConstantInstruction(TypeResults const &TR, ActiveInstructions.insert(I); return false; } - if (CI->hasFnAttr("enzyme_inactive") || - CI->hasFnAttr("enzyme_inactive_inst")) { - if (EnzymePrintActivity) - llvm::errs() << "forced inactive " << *I << "\n"; - InsertConstantInstruction(TR, I); - return true; - } auto called = getFunctionFromCall(CI); if (called) { @@ -748,27 +806,17 @@ bool ActivityAnalyzer::isConstantInstruction(TypeResults const &TR, ActiveInstructions.insert(I); return false; } - if (called->hasFnAttribute("enzyme_inactive") || - called->hasFnAttribute("enzyme_inactive_inst")) { - if (EnzymePrintActivity) - llvm::errs() << "forced inactive " << *I << "\n"; - InsertConstantInstruction(TR, I); - return true; - } } - if (KnownInactiveFunctionInsts.count(getFuncNameFromCall(CI))) { + if (isInactiveCallInst(*CI, TLI)) { + if (EnzymePrintActivity) + llvm::errs() << "known inactive instruction from call " << *I << "\n"; InsertConstantInstruction(TR, I); return true; } } if (auto II = dyn_cast(I)) { - if (KnownInactiveIntrinsics.count(II->getIntrinsicID())) { - if (EnzymePrintActivity) - llvm::errs() << "known inactive intrinsic " << *I << "\n"; - InsertConstantInstruction(TR, I); - return true; - } else if (isIntelSubscriptIntrinsic(*II)) { + if (isIntelSubscriptIntrinsic(*II)) { // The intrinsic "llvm.intel.subscript" does not propogate deriviative // information directly. But its returned pointer may be active. InsertConstantInstruction(TR, I); @@ -865,7 +913,8 @@ bool ActivityAnalyzer::isConstantInstruction(TypeResults const &TR, if (isMemFreeLibMFunction(funcName)) { noActiveWrite = true; } else if (funcName == "frexp" || funcName == "frexpf" || - funcName == "frexpl") { + funcName == "frexpl" || funcName == "modf" || + funcName == "modff" || funcName == "modfl") { noActiveWrite = true; } } @@ -1013,9 +1062,11 @@ bool ActivityAnalyzer::isConstantValue(TypeResults const &TR, Value *Val) { } assert(TR.getFunction() == I->getParent()->getParent()); } +#ifndef NDEBUG if (auto Arg = dyn_cast(Val)) { assert(TR.getFunction() == Arg->getParent()); } +#endif // Void values are definitionally inactive if (Val->getType()->isVoidTy()) @@ -1077,13 +1128,6 @@ bool ActivityAnalyzer::isConstantValue(TypeResults const &TR, Value *Val) { return true; } - if (auto II = dyn_cast(Val)) { - if (KnownInactiveIntrinsics.count(II->getIntrinsicID())) { - InsertConstantValue(TR, Val); - return true; - } - } - // All arguments must be marked constant/nonconstant ahead of time if (isa(Val) && !cast(Val)->hasByValAttr()) { llvm::errs() << *(cast(Val)->getParent()) << "\n"; @@ -1342,6 +1386,12 @@ bool ActivityAnalyzer::isConstantValue(TypeResults const &TR, Value *Val) { return true; } } + if (isInactiveCall(*CI)) { + if (EnzymePrintActivity) + llvm::errs() << "known inactive val from call" << *Val << "\n"; + InsertConstantValue(TR, Val); + return true; + } } if (auto BO = dyn_cast(Val)) { // x & 0b100000 is definitionally inactive @@ -1380,6 +1430,12 @@ bool ActivityAnalyzer::isConstantValue(TypeResults const &TR, Value *Val) { if (containsPointer && !isValuePotentiallyUsedAsPointer(Val)) { containsPointer = false; + if (auto Arg = dyn_cast(Val)) { + assert(Arg->hasByValAttr()); + (void)Arg; + InsertConstantValue(TR, Val); + return true; + } } // We do this pointer dance here to ensure that any derived pointers from @@ -1530,8 +1586,7 @@ bool ActivityAnalyzer::isConstantValue(TypeResults const &TR, Value *Val) { } } } else if (auto op = dyn_cast(TmpOrig)) { - if (op->hasFnAttr("enzyme_inactive") || - op->hasFnAttr("enzyme_inactive_val") || + if (isInactiveCall(*op) || op->hasFnAttr("enzyme_inactive_val") || op->getAttributes().hasAttribute(llvm::AttributeList::ReturnIndex, "enzyme_inactive")) { InsertConstantValue(TR, Val); @@ -1543,8 +1598,7 @@ bool ActivityAnalyzer::isConstantValue(TypeResults const &TR, Value *Val) { StringRef funcName = getFuncNameFromCall(op); if (called && - (called->hasFnAttribute("enzyme_inactive") || - called->hasFnAttribute("enzyme_inactive_val") || + (called->hasFnAttribute("enzyme_inactive_val") || called->getAttributes().hasAttribute( llvm::AttributeList::ReturnIndex, "enzyme_inactive"))) { InsertConstantValue(TR, Val); @@ -1558,45 +1612,6 @@ bool ActivityAnalyzer::isConstantValue(TypeResults const &TR, Value *Val) { return true; } - auto dName = demangle(funcName.str()); - for (auto FuncName : DemangledKnownInactiveFunctionsStartingWith) { - if (startsWith(dName, FuncName)) { - InsertConstantValue(TR, Val); - insertConstantsFrom(TR, *UpHypothesis); - return true; - } - } - - for (auto FuncName : KnownInactiveFunctionsStartingWith) { - if (startsWith(funcName, FuncName)) { - InsertConstantValue(TR, Val); - insertConstantsFrom(TR, *UpHypothesis); - return true; - } - } - - for (auto FuncName : KnownInactiveFunctionsContains) { - if (funcName.contains(FuncName)) { - InsertConstantValue(TR, Val); - insertConstantsFrom(TR, *UpHypothesis); - return true; - } - } - - if (KnownInactiveFunctions.count(funcName) || - MPIInactiveCommAllocators.find(funcName) != - MPIInactiveCommAllocators.end()) { - InsertConstantValue(TR, Val); - insertConstantsFrom(TR, *UpHypothesis); - return true; - } - - if (called && called->getIntrinsicID() == Intrinsic::trap) { - InsertConstantValue(TR, Val); - insertConstantsFrom(TR, *UpHypothesis); - return true; - } - // If requesting empty unknown functions to be considered inactive, // abide by those rules if (called && EnzymeEmptyFnInactive && called->empty() && @@ -1851,57 +1866,14 @@ bool ActivityAnalyzer::isConstantValue(TypeResults const &TR, Value *Val) { // If this is a malloc or free, this doesn't impact the activity if (auto CI = dyn_cast(I)) { - if (CI->hasFnAttr("enzyme_inactive") || - CI->hasFnAttr("enzyme_inactive_inst")) + if (isInactiveCallInst(*CI, TLI)) return false; - if (auto iasm = dyn_cast(CI->getCalledOperand())) { - if (StringRef(iasm->getAsmString()).contains("exit") || - StringRef(iasm->getAsmString()).contains("cpuid")) - return false; - } - - auto F = getFunctionFromCall(CI); StringRef funcName = getFuncNameFromCall(CI); - - if (F && (F->hasFnAttribute("enzyme_inactive") || - F->hasFnAttribute("enzyme_inactive_inst"))) { - return false; - } - if (isAllocationFunction(funcName, TLI) || - isDeallocationFunction(funcName, TLI)) { - return false; - } - if (KnownInactiveFunctions.count(funcName) || - MPIInactiveCommAllocators.find(funcName) != - MPIInactiveCommAllocators.end()) { - return false; - } - if (KnownInactiveFunctionInsts.count(funcName)) { - return false; - } if (isMemFreeLibMFunction(funcName)) { return false; } - auto dName = demangle(funcName.str()); - for (auto FuncName : DemangledKnownInactiveFunctionsStartingWith) { - if (startsWith(dName, FuncName)) { - return false; - } - } - - for (auto FuncName : KnownInactiveFunctionsStartingWith) { - if (startsWith(funcName, FuncName)) { - return false; - } - } - for (auto FuncName : KnownInactiveFunctionsContains) { - if (funcName.contains(FuncName)) { - return false; - } - } - if (funcName == "__cxa_guard_acquire" || funcName == "__cxa_guard_release" || funcName == "__cxa_guard_abort" || funcName == "posix_memalign" || @@ -1911,12 +1883,6 @@ bool ActivityAnalyzer::isConstantValue(TypeResults const &TR, Value *Val) { funcName == "cudaMallocFromPoolAsync") { return false; } - - if (F) { - if (KnownInactiveIntrinsics.count(F->getIntrinsicID())) { - return false; - } - } } Value *memval = Val; @@ -2103,14 +2069,14 @@ bool ActivityAnalyzer::isConstantValue(TypeResults const &TR, Value *Val) { << "\n"; if (auto SI = dyn_cast(I)) { bool cop = !Hypothesis->isConstantValue(TR, SI->getValueOperand()); - bool cop2 = !Hypothesis->isConstantValue(TR, SI->getPointerOperand()); + // bool cop2 = !Hypothesis->isConstantValue(TR, + // SI->getPointerOperand()); if (EnzymePrintActivity) - llvm::errs() << " -- store potential activity: " << (int)cop << "," - << (int)cop2 << "," + llvm::errs() << " -- store potential activity: " << (int)cop << " - " << *SI << " of " << " Val=" << *Val << "\n"; potentialStore = I; - if (cop && cop2) + if (cop) // && cop2) potentiallyActiveStore = SI; } else if (auto MTI = dyn_cast(I)) { bool cop = !Hypothesis->isConstantValue(TR, MTI->getArgOperand(1)); @@ -2337,6 +2303,9 @@ bool ActivityAnalyzer::isConstantValue(TypeResults const &TR, Value *Val) { // this value is inactive, we are inactive Since we won't look at uses to // prove, we can inductively assume this is inactive if (directions & UP) { + if (!UpHypothesis) + UpHypothesis = + std::shared_ptr(new ActivityAnalyzer(*this, UP)); if (directions == UP && !isa(Val)) { if (isInstructionInactiveFromOrigin(TR, Val, true)) { InsertConstantValue(TR, Val); @@ -2352,8 +2321,6 @@ bool ActivityAnalyzer::isConstantValue(TypeResults const &TR, Value *Val) { } } } else { - UpHypothesis = - std::shared_ptr(new ActivityAnalyzer(*this, UP)); UpHypothesis->ConstantValues.insert(Val); if (UpHypothesis->isInstructionInactiveFromOrigin(TR, Val, true)) { insertConstantsFrom(TR, *UpHypothesis); @@ -2532,8 +2499,10 @@ bool ActivityAnalyzer::isInstructionInactiveFromOrigin(TypeResults const &TR, } if (auto op = dyn_cast(inst)) { - if (op->hasFnAttr("enzyme_inactive") || - op->hasFnAttr("enzyme_inactive_val")) { + if (isInactiveCall(*op)) + return true; + + if (op->hasFnAttr("enzyme_inactive_val")) { return true; } // Calls to print/assert/cxa guard are definitionally inactive @@ -2542,8 +2511,7 @@ bool ActivityAnalyzer::isInstructionInactiveFromOrigin(TypeResults const &TR, StringRef funcName = getFuncNameFromCall(op); auto called = getFunctionFromCall(op); - if (called && (called->hasFnAttribute("enzyme_inactive") || - called->hasFnAttribute("enzyme_inactive_val"))) { + if (called && (called->hasFnAttribute("enzyme_inactive_val"))) { return true; } if (funcName == "free" || funcName == "_ZdlPv" || funcName == "_ZdlPvm" || @@ -2551,37 +2519,6 @@ bool ActivityAnalyzer::isInstructionInactiveFromOrigin(TypeResults const &TR, return true; } - auto dName = demangle(funcName.str()); - for (auto FuncName : DemangledKnownInactiveFunctionsStartingWith) { - if (startsWith(dName, FuncName)) { - return true; - } - } - - for (auto FuncName : KnownInactiveFunctionsStartingWith) { - if (startsWith(funcName, FuncName)) { - return true; - } - } - - for (auto FuncName : KnownInactiveFunctionsContains) { - if (funcName.contains(FuncName)) { - return true; - } - } - - if (KnownInactiveFunctions.count(funcName) || - MPIInactiveCommAllocators.find(funcName) != - MPIInactiveCommAllocators.end()) { - if (EnzymePrintActivity) - llvm::errs() << "constant(" << (int)directions - << ") up-knowninactivecall " << *inst << "\n"; - return true; - } - - if (called && called->getIntrinsicID() == Intrinsic::trap) - return true; - // If requesting empty unknown functions to be considered inactive, abide // by those rules if (called && EnzymeEmptyFnInactive && called->empty() && @@ -2603,12 +2540,6 @@ bool ActivityAnalyzer::isInstructionInactiveFromOrigin(TypeResults const &TR, } // Intrinsics known always to be inactive if (auto II = dyn_cast(inst)) { - if (KnownInactiveIntrinsics.count(II->getIntrinsicID())) { - if (EnzymePrintActivity) - llvm::errs() << "constant(" << (int)directions << ") up-intrinsic " - << *inst << "\n"; - return true; - } if (isIntelSubscriptIntrinsic(*II)) { // The only argument that can make an llvm.intel.subscript intrinsic // active is the pointer operand diff --git a/enzyme/Enzyme/ActivityAnalysisPrinter.cpp b/enzyme/Enzyme/ActivityAnalysisPrinter.cpp index cecb6f833c5b..684fcd89e04d 100644 --- a/enzyme/Enzyme/ActivityAnalysisPrinter.cpp +++ b/enzyme/Enzyme/ActivityAnalysisPrinter.cpp @@ -95,14 +95,19 @@ bool printActivityAnalysis(llvm::Function &F, TargetLibraryInfo &TLI) { if (a.getType()->isFPOrFPVectorTy()) { dt = ConcreteType(a.getType()->getScalarType()); } else if (a.getType()->isPointerTy()) { -#if LLVM_VERSION_MAJOR >= 17 -#else - auto et = a.getType()->getPointerElementType(); - if (et->isFPOrFPVectorTy()) { - dt = TypeTree(ConcreteType(et->getScalarType())).Only(-1, nullptr); - } else if (et->isPointerTy()) { - dt = TypeTree(ConcreteType(BaseType::Pointer)).Only(-1, nullptr); +#if LLVM_VERSION_MAJOR < 17 +#if LLVM_VERSION_MAJOR >= 13 + if (a.getContext().supportsTypedPointers()) { +#endif + auto et = a.getType()->getPointerElementType(); + if (et->isFPOrFPVectorTy()) { + dt = TypeTree(ConcreteType(et->getScalarType())).Only(-1, nullptr); + } else if (et->isPointerTy()) { + dt = TypeTree(ConcreteType(BaseType::Pointer)).Only(-1, nullptr); + } +#if LLVM_VERSION_MAJOR >= 13 } +#endif #endif } else if (a.getType()->isIntOrIntVectorTy()) { dt = ConcreteType(BaseType::Integer); @@ -119,14 +124,19 @@ bool printActivityAnalysis(llvm::Function &F, TargetLibraryInfo &TLI) { if (F.getReturnType()->isFPOrFPVectorTy()) { dt = ConcreteType(F.getReturnType()->getScalarType()); } else if (F.getReturnType()->isPointerTy()) { -#if LLVM_VERSION_MAJOR >= 17 -#else - auto et = F.getReturnType()->getPointerElementType(); - if (et->isFPOrFPVectorTy()) { - dt = TypeTree(ConcreteType(et->getScalarType())).Only(-1, nullptr); - } else if (et->isPointerTy()) { - dt = TypeTree(ConcreteType(BaseType::Pointer)).Only(-1, nullptr); +#if LLVM_VERSION_MAJOR < 17 +#if LLVM_VERSION_MAJOR >= 13 + if (F.getContext().supportsTypedPointers()) { +#endif + auto et = F.getReturnType()->getPointerElementType(); + if (et->isFPOrFPVectorTy()) { + dt = TypeTree(ConcreteType(et->getScalarType())).Only(-1, nullptr); + } else if (et->isPointerTy()) { + dt = TypeTree(ConcreteType(BaseType::Pointer)).Only(-1, nullptr); + } +#if LLVM_VERSION_MAJOR >= 13 } +#endif #endif } else if (F.getReturnType()->isIntOrIntVectorTy()) { dt = ConcreteType(BaseType::Integer); diff --git a/enzyme/Enzyme/AdjointGenerator.h b/enzyme/Enzyme/AdjointGenerator.h index c6d558cf5ba0..a3abf6991e95 100644 --- a/enzyme/Enzyme/AdjointGenerator.h +++ b/enzyme/Enzyme/AdjointGenerator.h @@ -47,9 +47,7 @@ #define DEBUG_TYPE "enzyme" // Helper instruction visitor that generates adjoints -template -class AdjointGenerator - : public llvm::InstVisitor> { +class AdjointGenerator : public llvm::InstVisitor { private: // Type of code being generated (forward, reverse, or both) const DerivativeMode Mode; @@ -63,7 +61,7 @@ class AdjointGenerator const std::map> overwritten_args_map; const llvm::SmallPtrSetImpl *returnuses; - AugmentedReturnType augmentedReturn; + const AugmentedReturn *augmentedReturn; const std::map *replacedReturns; const llvm::SmallPtrSetImpl &unnecessaryValues; @@ -83,7 +81,7 @@ class AdjointGenerator const std::map> overwritten_args_map, const llvm::SmallPtrSetImpl *returnuses, - AugmentedReturnType augmentedReturn, + const AugmentedReturn *augmentedReturn, const std::map *replacedReturns, const llvm::SmallPtrSetImpl &unnecessaryValues, const llvm::SmallPtrSetImpl @@ -102,7 +100,7 @@ class AdjointGenerator using namespace llvm; assert(TR.getFunction() == gutils->oldFunc); - for (auto &pair : TR.analyzer.analysis) { + for (auto &pair : TR.analyzer->analysis) { if (auto in = dyn_cast(pair.first)) { if (in->getParent()->getParent() != gutils->oldFunc) { llvm::errs() << "inf: " << *in->getParent()->getParent() << "\n"; @@ -410,6 +408,7 @@ class AdjointGenerator constantval |= gutils->isConstantValue(&I); Type *type = gutils->getShadowType(I.getType()); + (void)type; auto *newi = dyn_cast(gutils->getNewFromOriginal(&I)); @@ -623,6 +622,7 @@ class AdjointGenerator if (primalNeededInReverse) { inst = gutils->cacheForReverse(BuilderZ, newi, getIndex(&I, CacheType::Self, BuilderZ)); + (void)inst; assert(inst->getType() == type); if (Mode == DerivativeMode::ReverseModeGradient || @@ -758,7 +758,7 @@ class AdjointGenerator auto alignment = LI.getAlign(); auto &DL = gutils->newFunc->getParent()->getDataLayout(); - bool constantval = parseTBAA(LI, DL, nullptr).Inner0().isIntegral(); + bool constantval = parseTBAA(LI, DL, nullptr)[{-1}].isIntegral(); visitLoadLike(LI, alignment, constantval); eraseIfUnused(LI); } @@ -900,6 +900,9 @@ class AdjointGenerator setDiffe(&I, Constant::getNullValue(gutils->getShadowType(I.getType())), BuilderZ); } + gutils->replaceAWithB(gutils->getNewFromOriginal(&I), + UndefValue::get(I.getType())); + eraseIfUnused(I, /*erase*/ true, /*check*/ false); return; } @@ -994,7 +997,7 @@ class AdjointGenerator NewI->setMetadata(LLVMContext::MD_noalias, noscope); bool constantval = gutils->isConstantValue(orig_val) || - parseTBAA(I, DL, nullptr).Inner0().isIntegral(); + parseTBAA(I, DL, nullptr)[{-1}].isIntegral(); IRBuilder<> BuilderZ(NewI); BuilderZ.setFastMathFlags(getFast()); @@ -1053,7 +1056,7 @@ class AdjointGenerator } Value *diff = nullptr; - if (!EnzymeRuntimeActivityCheck && CustomErrorHandler && constantval) { + if (!EnzymeRuntimeActivityCheck && constantval) { if (dt.isPossiblePointer() && vd[{-1, -1}] != BaseType::Integer) { if (!isa(orig_val) && !isa(orig_val)) { @@ -1061,9 +1064,12 @@ class AdjointGenerator raw_string_ostream ss(str); ss << "Mismatched activity for: " << I << " const val: " << *orig_val; - diff = unwrap(CustomErrorHandler( - str.c_str(), wrap(&I), ErrorType::MixedActivityError, gutils, - wrap(orig_val), wrap(&BuilderZ))); + if (CustomErrorHandler) + diff = unwrap(CustomErrorHandler( + str.c_str(), wrap(&I), ErrorType::MixedActivityError, gutils, + wrap(orig_val), wrap(&BuilderZ))); + else + EmitWarning("MixedActivityError", I, ss.str()); } } } @@ -1282,7 +1288,7 @@ class AdjointGenerator Value *valueop = nullptr; if (constantval) { - if (!EnzymeRuntimeActivityCheck && CustomErrorHandler) { + if (!EnzymeRuntimeActivityCheck) { if (dt.isPossiblePointer() && vd[{-1, -1}] != BaseType::Integer) { if (!isa(orig_val) && !isa(orig_val)) { @@ -1290,9 +1296,12 @@ class AdjointGenerator raw_string_ostream ss(str); ss << "Mismatched activity for: " << I << " const val: " << *orig_val; - valueop = unwrap(CustomErrorHandler( - str.c_str(), wrap(&I), ErrorType::MixedActivityError, - gutils, wrap(orig_val), wrap(&BuilderZ))); + if (CustomErrorHandler) + valueop = unwrap(CustomErrorHandler( + str.c_str(), wrap(&I), ErrorType::MixedActivityError, + gutils, wrap(orig_val), wrap(&BuilderZ))); + else + EmitWarning("MixedActivityError", I, ss.str()); } } } @@ -1535,7 +1544,7 @@ class AdjointGenerator lc) && gutils->getNewFromOriginal(P0->getParent()) == lc.header) { SmallVector Latches; - gutils->OrigLI.getLoopFor(P0->getParent())->getLoopLatches(Latches); + gutils->OrigLI->getLoopFor(P0->getParent())->getLoopLatches(Latches); bool allIncoming = true; for (auto Latch : Latches) { if (&SI != P0->getIncomingValueForBlock(Latch)) { @@ -2199,7 +2208,7 @@ class AdjointGenerator lc) && gutils->getNewFromOriginal(P0->getParent()) == lc.header) { SmallVector Latches; - gutils->OrigLI.getLoopFor(P0->getParent())->getLoopLatches(Latches); + gutils->OrigLI->getLoopFor(P0->getParent())->getLoopLatches(Latches); bool allIncoming = true; for (auto Latch : Latches) { if (&BO != P0->getIncomingValueForBlock(Latch)) { @@ -3040,8 +3049,8 @@ class AdjointGenerator } if (!vd.isKnownPastPointer()) { if (looseTypeAnalysis) { - if (auto CI = dyn_cast(MS.getOperand(0))) { #if LLVM_VERSION_MAJOR < 17 + if (auto CI = dyn_cast(MS.getOperand(0))) { if (auto PT = dyn_cast(CI->getSrcTy())) { auto ET = PT->getPointerElementType(); while (1) { @@ -3070,8 +3079,8 @@ class AdjointGenerator goto known; } } -#endif } +#endif if (auto gep = dyn_cast(MS.getOperand(0))) { if (auto AT = dyn_cast(gep->getSourceElementType())) { if (AT->getElementType()->isIntegerTy()) { @@ -3271,10 +3280,18 @@ class AdjointGenerator return; } + // memcpy of size 1 cannot move differentiable data [single byte copy] + if (auto ci = dyn_cast(new_size)) { + if (ci->getValue() == 1) { + eraseIfUnused(MTI); + return; + } + } + // copying into nullptr is invalid (not sure why it exists here), but we // shouldn't do it in reverse pass or shadow if (isa(orig_dst) || - TR.query(orig_dst).Inner0() == BaseType::Anything) { + TR.query(orig_dst)[{-1}] == BaseType::Anything) { eraseIfUnused(MTI); return; } @@ -3312,8 +3329,8 @@ class AdjointGenerator if (!vd.isKnownPastPointer()) { if (looseTypeAnalysis) { for (auto val : {orig_dst, orig_src}) { - if (auto CI = dyn_cast(val)) { #if LLVM_VERSION_MAJOR < 17 + if (auto CI = dyn_cast(val)) { if (auto PT = dyn_cast(CI->getSrcTy())) { auto ET = PT->getPointerElementType(); while (1) { @@ -3342,8 +3359,8 @@ class AdjointGenerator goto known; } } -#endif } +#endif if (auto gep = dyn_cast(val)) { if (auto AT = dyn_cast(gep->getSourceElementType())) { if (AT->getElementType()->isIntegerTy()) { @@ -3353,6 +3370,29 @@ class AdjointGenerator } } } + // If the type is known, but outside of the known range + // (but the memcpy size is a variable), attempt to use + // the first type out of range as the memcpy type. + if (size == 1 && !isa(new_size)) { + for (auto ptr : {orig_dst, orig_src}) { + vd = TR.query(ptr).Data0().ShiftIndices(DL, 0, -1, 0); + if (vd.isKnownPastPointer()) { + ConcreteType mv(BaseType::Unknown); + size_t minInt = 0xFFFFFFFF; + for (const auto &pair : vd.getMapping()) { + if (pair.first.size() != 1) + continue; + if (minInt < (size_t)pair.first[0]) + continue; + minInt = pair.first[0]; + mv = pair.second; + } + assert(mv != BaseType::Unknown); + vd.insert({0}, mv); + goto known; + } + } + } if (errorIfNoType) EmitWarning("CannotDeduceType", MTI, "failed to deduce type of copy ", MTI); @@ -3368,6 +3408,7 @@ class AdjointGenerator &TR.analyzer, nullptr, wrap(&BuilderZ)); } else { ss << "\n"; + ss << *gutils->oldFunc << "\n"; TR.dump(ss); EmitFailure("CannotDeduceType", MTI.getDebugLoc(), &MTI, ss.str()); } @@ -3667,7 +3708,7 @@ class AdjointGenerator auto align0 = cast(I.getOperand(1))->getZExtValue(); auto align = MaybeAlign(align0); auto &DL = gutils->newFunc->getParent()->getDataLayout(); - bool constantval = parseTBAA(I, DL, nullptr).Inner0().isIntegral(); + bool constantval = parseTBAA(I, DL, nullptr)[{-1}].isIntegral(); visitLoadLike(I, align, constantval, /*mask*/ gutils->getNewFromOriginal(I.getOperand(2)), /*orig_maskInit*/ I.getOperand(3)); @@ -3738,6 +3779,7 @@ class AdjointGenerator setDiffe(&I, Constant::getNullValue(gutils->getShadowType(I.getType())), Builder2); } + (void)vdiff; switch (ID) { @@ -4046,7 +4088,7 @@ class AdjointGenerator assert(whatType(argType, Mode) == DIFFE_TYPE::DUP_ARG || whatType(argType, Mode) == DIFFE_TYPE::CONSTANT); } else { - assert(TR.query(call.getArgOperand(i)).Inner0().isFloat()); + assert(TR.query(call.getArgOperand(i))[{-1}].isFloat()); OutTypes.push_back(call.getArgOperand(i)); OutFPTypes.push_back(argType); assert(whatType(argType, Mode) == DIFFE_TYPE::OUT_DIFF || @@ -4113,7 +4155,7 @@ class AdjointGenerator if (called) { subdata = &gutils->Logic.CreateAugmentedPrimal( RequestContext(&call, &BuilderZ), cast(called), - subretType, argsInverted, TR.analyzer.interprocedural, + subretType, argsInverted, TR.analyzer->interprocedural, /*return is used*/ false, /*shadowReturnUsed*/ false, nextTypeInfo, overwritten_args, false, gutils->getWidth(), @@ -4329,7 +4371,7 @@ class AdjointGenerator : nullptr, .forceAnonymousTape = false, .typeInfo = nextTypeInfo}, - TR.analyzer.interprocedural, subdata, + TR.analyzer->interprocedural, subdata, /*omp*/ true); if (subdata->returns.find(AugmentedStruct::Tape) != @@ -4372,7 +4414,6 @@ class AdjointGenerator } } } - size_t freeCount = 0; for (auto LI : geps) { CallInst *freeCall = nullptr; for (auto LU : LI->users()) { @@ -4400,7 +4441,6 @@ class AdjointGenerator } if (freeCall) { freeCall->eraseFromParent(); - freeCount++; } } } @@ -4571,7 +4611,7 @@ class AdjointGenerator EmitFailure("CannotDeduceType", call.getDebugLoc(), &call, "failed to deduce type of copy ", call); } -#if LLVM_VERSION_MAJOR < 18 +#if LLVM_VERSION_MAJOR < 17 knownF: #endif unsigned start = 0; @@ -4717,10 +4757,11 @@ class AdjointGenerator // call.getParamAttr(i, Attribute::StructRet).getValueAsType())); #endif } - if (call.getAttributes().hasParamAttr(i, "enzymejl_returnRoots")) { - structAttrs[args.size()].push_back( - call.getParamAttr(i, "enzymejl_returnRoots")); - } + for (auto attr : {"enzymejl_returnRoots", "enzymejl_parmtype", + "enzymejl_parmtype_ref", "enzyme_type"}) + if (call.getAttributes().hasParamAttr(i, attr)) { + structAttrs[args.size()].push_back(call.getParamAttr(i, attr)); + } for (auto ty : PrimalParamAttrsToPreserve) if (call.getAttributes().hasParamAttr(i, ty)) { auto attr = call.getAttributes().getParamAttr(i, ty); @@ -4775,15 +4816,16 @@ class AdjointGenerator structAttrs[args.size()].push_back(attr); } - if (call.getAttributes().hasParamAttr(i, "enzymejl_returnRoots")) { - if (gutils->getWidth() == 1) { - structAttrs[args.size()].push_back( - call.getParamAttr(i, "enzymejl_returnRoots")); - } else { - structAttrs[args.size()].push_back( - Attribute::get(call.getContext(), "enzyme_sret_v")); + for (auto attr : {"enzymejl_returnRoots", "enzymejl_parmtype", + "enzymejl_parmtype_ref", "enzyme_type"}) + if (call.getAttributes().hasParamAttr(i, attr)) { + if (gutils->getWidth() == 1) { + structAttrs[args.size()].push_back(call.getParamAttr(i, attr)); + } else if (attr == std::string("enzymejl_returnRoots")) { + structAttrs[args.size()].push_back( + Attribute::get(call.getContext(), "enzymejl_returnRoots_v")); + } } - } if (call.paramHasAttr(i, Attribute::StructRet)) { if (gutils->getWidth() == 1) { structAttrs[args.size()].push_back( @@ -4835,18 +4877,9 @@ class AdjointGenerator } } Value *tape = nullptr; -#if LLVM_VERSION_MAJOR >= 16 - if (tapeIdx.has_value()) -#else - if (tapeIdx.hasValue()) -#endif - { + if (tapeIdx) { -#if LLVM_VERSION_MAJOR >= 16 - auto idx = tapeIdx.value(); -#else - auto idx = tapeIdx.getValue(); -#endif + auto idx = *tapeIdx; FunctionType *FT = subdata->fn->getFunctionType(); tape = BuilderZ.CreatePHI( @@ -4868,7 +4901,7 @@ class AdjointGenerator if (called) { newcalled = gutils->Logic.CreateForwardDiff( RequestContext(&call, &BuilderZ), cast(called), - subretType, argsInverted, TR.analyzer.interprocedural, + subretType, argsInverted, TR.analyzer->interprocedural, /*returnValue*/ subretused, Mode, ((DiffeGradientUtils *)gutils)->FreeMemory, gutils->getWidth(), tape ? tape->getType() : nullptr, nextTypeInfo, overwritten_args, @@ -5019,10 +5052,11 @@ class AdjointGenerator if (call.isByValArgument(i)) { preByVal[pre_args.size()] = call.getParamByValType(i); } - if (call.getAttributes().hasParamAttr(i, "enzymejl_returnRoots")) { - structAttrs[pre_args.size()].push_back( - call.getParamAttr(i, "enzymejl_returnRoots")); - } + for (auto attr : {"enzymejl_returnRoots", "enzymejl_parmtype", + "enzymejl_parmtype_ref", "enzyme_type"}) + if (call.getAttributes().hasParamAttr(i, attr)) { + structAttrs[pre_args.size()].push_back(call.getParamAttr(i, attr)); + } if (call.paramHasAttr(i, Attribute::StructRet)) { structAttrs[pre_args.size()].push_back( #if LLVM_VERSION_MAJOR >= 12 @@ -5115,15 +5149,17 @@ class AdjointGenerator structAttrs[pre_args.size()].push_back(attr); } - if (call.getAttributes().hasParamAttr(i, "enzymejl_returnRoots")) { - if (gutils->getWidth() == 1) { - structAttrs[pre_args.size()].push_back( - call.getParamAttr(i, "enzymejl_returnRoots")); - } else { - structAttrs[pre_args.size()].push_back( - Attribute::get(call.getContext(), "enzymejl_returnRoots_v")); + for (auto attr : {"enzymejl_returnRoots", "enzymejl_parmtype", + "enzymejl_parmtype_ref", "enzyme_type"}) + if (call.getAttributes().hasParamAttr(i, attr)) { + if (gutils->getWidth() == 1) { + structAttrs[pre_args.size()].push_back( + call.getParamAttr(i, attr)); + } else if (attr == std::string("enzymejl_returnRoots")) { + structAttrs[pre_args.size()].push_back( + Attribute::get(call.getContext(), "enzymejl_returnRoots_v")); + } } - } if (call.paramHasAttr(i, Attribute::StructRet)) { if (gutils->getWidth() == 1) { structAttrs[pre_args.size()].push_back( @@ -5173,6 +5209,7 @@ class AdjointGenerator // Note sometimes whattype mistakenly says something should be // constant [because composed of integer pointers alone] + (void)argType; assert(whatType(argType, Mode) == DIFFE_TYPE::DUP_ARG || whatType(argType, Mode) == DIFFE_TYPE::CONSTANT); } else { @@ -5283,7 +5320,7 @@ class AdjointGenerator Mode == DerivativeMode::ReverseModeCombined) { subdata = &gutils->Logic.CreateAugmentedPrimal( RequestContext(&call, &BuilderZ), cast(called), - subretType, argsInverted, TR.analyzer.interprocedural, + subretType, argsInverted, TR.analyzer->interprocedural, /*return is used*/ subretused, shadowReturnUsed, nextTypeInfo, overwritten_args, false, gutils->getWidth(), gutils->AtomicAdd); if (Mode == DerivativeMode::ReverseModePrimal) { @@ -5392,17 +5429,8 @@ class AdjointGenerator if (!augmentcall->getType()->isVoidTy()) augmentcall->setName(call.getName() + "_augmented"); -#if LLVM_VERSION_MAJOR >= 16 - if (tapeIdx.has_value()) -#else - if (tapeIdx.hasValue()) -#endif - { -#if LLVM_VERSION_MAJOR >= 16 - auto tval = tapeIdx.value(); -#else - auto tval = tapeIdx.getValue(); -#endif + if (tapeIdx) { + auto tval = *tapeIdx; tape = (tval == -1) ? augmentcall : BuilderZ.CreateExtractValue( augmentcall, {(unsigned)tval}, "subcache"); @@ -5421,11 +5449,7 @@ class AdjointGenerator Value *dcall = nullptr; assert(returnIdx); assert(augmentcall); -#if LLVM_VERSION_MAJOR >= 16 - auto rval = returnIdx.value(); -#else - auto rval = returnIdx.getValue(); -#endif + auto rval = *returnIdx; dcall = (rval < 0) ? augmentcall : BuilderZ.CreateExtractValue(augmentcall, {(unsigned)rval}); @@ -5437,8 +5461,7 @@ class AdjointGenerator assert(dcall); if (!gutils->isConstantValue(&call)) { - if (!call.getType()->isFPOrFPVectorTy() && - TR.query(&call).Inner0().isPossiblePointer()) { + if (!call.getType()->isFPOrFPVectorTy() && TR.anyPointer(&call)) { } else if (Mode != DerivativeMode::ReverseModePrimal) { ((DiffeGradientUtils *)gutils)->differentials[dcall] = ((DiffeGradientUtils *)gutils)->differentials[newCall]; @@ -5496,13 +5519,8 @@ class AdjointGenerator // assert(!tape); // assert(subdata); if (!tape) { -#if LLVM_VERSION_MAJOR >= 16 - assert(tapeIdx.has_value()); - auto tval = tapeIdx.value(); -#else - assert(tapeIdx.hasValue()); - auto tval = tapeIdx.getValue(); -#endif + assert(tapeIdx); + auto tval = *tapeIdx; tape = BuilderZ.CreatePHI( (tapeIdx == -1) ? FT->getReturnType() : cast(FT->getReturnType()) @@ -5561,11 +5579,7 @@ class AdjointGenerator if (Mode == DerivativeMode::ReverseModeCombined || Mode == DerivativeMode::ReverseModePrimal) { -#if LLVM_VERSION_MAJOR >= 16 - auto drval = differetIdx.value(); -#else - auto drval = differetIdx.getValue(); -#endif + auto drval = *differetIdx; newip = (drval < 0) ? augmentcall : BuilderZ.CreateExtractValue(augmentcall, @@ -5684,7 +5698,7 @@ class AdjointGenerator .additionalType = tape ? tape->getType() : nullptr, .forceAnonymousTape = false, .typeInfo = nextTypeInfo}, - TR.analyzer.interprocedural, subdata); + TR.analyzer->interprocedural, subdata); if (!newcalled) return; FT = cast(newcalled)->getFunctionType(); @@ -5907,8 +5921,7 @@ class AdjointGenerator gutils->originalToNewFn[&call] = dcall; gutils->newToOriginalFn.erase(newCall); gutils->newToOriginalFn[dcall] = &call; - if (!call.getType()->isFPOrFPVectorTy() && - TR.query(&call).Inner0().isPossiblePointer()) { + if (!call.getType()->isFPOrFPVectorTy() && TR.anyPointer(&call)) { } else { ((DiffeGradientUtils *)gutils)->differentials[dcall] = ((DiffeGradientUtils *)gutils)->differentials[newCall]; diff --git a/enzyme/Enzyme/BlasDerivatives.td b/enzyme/Enzyme/BlasDerivatives.td index 70682e540ef5..f031aa7f191c 100644 --- a/enzyme/Enzyme/BlasDerivatives.td +++ b/enzyme/Enzyme/BlasDerivatives.td @@ -192,7 +192,7 @@ def gemv : CallBlasPattern<(Op $layout, $transa, $m, $n, $alpha, $A, $lda, $x, $ ["y"], [cblas_layout, trans, len, len, fp, mld<["m", "n"]>, vinc<["transa", "n", "m"]>, fp, vinc<["transa", "m", "n"]>], [ /* alpha */ (Seq<["Ax", "is_normal", "transa", "m", "n"]> - (b<"gemv"> $layout, $transa, $m, $n, Constant<"1.0">, $A, (ld $A, Char<"N">, $lda, $m, $n), $x, Constant<"0.0">, use<"Ax">, ConstantInt<1>), + (b<"gemv"> $layout, $transa, $m, $n, Constant<"1.0">, $A, (ld $A, Char<"N">, $lda, $m, $m), $x, Constant<"0.0">, use<"Ax">, ConstantInt<1>), (b<"dot"> (Rows $transa, $m, $n), adj<"y">, use<"Ax">, ConstantInt<1>)), //if (is_normal $transa) { @@ -201,7 +201,7 @@ def gemv : CallBlasPattern<(Op $layout, $transa, $m, $n, $alpha, $A, $lda, $x, $ // call sger(m, n, alpha, x, incx, ya, incy, Aa, lda) //} /* A */ (b<"ger"> $layout, $m, $n, $alpha, (Rows $transa, (Concat adj<"y">, $x), (Concat $x, adj<"y">)), adj<"A">), - /* x */ (b<"gemv"> $layout, transpose<"transa">, $m, $n, $alpha, $A, (ld $A, Char<"N">, $lda, $m, $n), adj<"y">, Constant<"1.0">, adj<"x">), + /* x */ (b<"gemv"> $layout, transpose<"transa">, $m, $n, $alpha, $A, (ld $A, Char<"N">, $lda, $m, $m), adj<"y">, Constant<"1.0">, adj<"x">), /* beta */ (b<"dot"> (Rows $transa, $m, $n), adj<"y">, input<"y">), /* y */ (b<"scal"> (Rows $transa, $m, $n), $beta, adj<"y">) ] @@ -218,7 +218,7 @@ def ger : CallBlasPattern<(Op $layout, $m, $n, $alpha, $x, $incx, $y, $incy, $A, >; //(ld $A, $transa, $lda, $m, $k) // if (cache_A) { -// ld_A = (arg_transa == 'N') ? arg_m : arg_k; +// ld_A = (arg_transa == 'N') ? arg_k : arg_m; // } else { // ld_A = arg_lda; // } @@ -229,15 +229,15 @@ def gemm : CallBlasPattern<(Op $layout, $transa, $transb, $m, $n, $k, $alpha, $A [ /* alpha */ (Seq<["AB", "product", "m", "n"]> - (b<"gemm"> $layout, $transa, $transb, $m, $n, $k, Constant<"1.0">, $A, (ld $A, $transa, $lda, $m, $k), $B, (ld $B, $transb, $ldb, $k, $n), Constant<"0.0">, use<"AB">, $m),// TODO: check if last arg should be $m or $n + (b<"gemm"> $layout, $transa, $transb, $m, $n, $k, Constant<"1.0">, $A, (ld $A, $transa, $lda, $k, $m), $B, (ld $B, $transb, $ldb, $k, $n), Constant<"0.0">, use<"AB">, $m),// TODO: check if last arg should be $m or $n (FrobInnerProd<""> $m, $n, adj<"C">, use<"AB">)), /* A */ (b<"gemm"> $layout, (Rows $transa, (Concat $transa, transpose<"transb">, $m, $k), (Concat $transb, $transa, $k, $m)), $n, $alpha, (Rows $transa, - (Concat adj<"C">, $B, (ld $B, $transb, $ldb, $k, $n)), - (Concat $B, (ld $B, $transb, $ldb, $k, $n), adj<"C">)), + (Concat adj<"C">, $B, (ld $B, $transb, $ldb, $n, $k)), + (Concat $B, (ld $B, $transb, $ldb, $n, $k), adj<"C">)), Constant<"1.0">, adj<"A">), /* B */ (b<"gemm"> $layout, (Rows $transb, @@ -245,8 +245,8 @@ def gemm : CallBlasPattern<(Op $layout, $transa, $transb, $m, $n, $k, $alpha, $A (Concat $transb, $transa, $n, $k)), $m, $alpha, (Rows $transb, - (Concat $A, (ld $A, $transa, $lda, $m, $k), adj<"C">), - (Concat adj<"C">, $A, (ld $A, $transa, $lda, $m, $k))), + (Concat $A, (ld $A, $transa, $lda, $k, $m), adj<"C">), + (Concat adj<"C">, $A, (ld $A, $transa, $lda, $k, $m))), Constant<"1.0">, adj<"B">), /* beta */ (FrobInnerProd<""> $m, $n, adj<"C">, input<"C">), /* C */ (b<"lascl"> $layout, Char<"G">, ConstantInt<0>, ConstantInt<0>, Constant<"1.0">, $beta, $m, $n, adj<"C">, Alloca<1>) diff --git a/enzyme/Enzyme/CApi.cpp b/enzyme/Enzyme/CApi.cpp index 15bc6c795aa2..8c1c295b54f6 100644 --- a/enzyme/Enzyme/CApi.cpp +++ b/enzyme/Enzyme/CApi.cpp @@ -297,7 +297,7 @@ void FreeTypeAnalysis(EnzymeTypeAnalysisRef TAR) { void *EnzymeAnalyzeTypes(EnzymeTypeAnalysisRef TAR, CFnTypeInfo CTI, LLVMValueRef F) { FnTypeInfo FTI(eunwrap(CTI, cast(unwrap(F)))); - return (void *)&((TypeAnalysis *)TAR)->analyzeFunction(FTI).analyzer; + return (void *)((TypeAnalysis *)TAR)->analyzeFunction(FTI).analyzer; } void *EnzymeGradientUtilsTypeAnalyzer(GradientUtils *G) { @@ -1184,6 +1184,7 @@ LLVMValueRef EnzymeCloneFunctionWithoutReturnOrArgs(LLVMValueRef FC, for (auto s : sub) { uint64_t ival; bool b = s.getAsInteger(10, ival); + (void)b; assert(!b); previdx.push_back(ival); } @@ -1241,6 +1242,7 @@ LLVMValueRef EnzymeComputeByteOffsetOfGEP(LLVMBuilderRef B_r, LLVMValueRef V_r, APInt Offset(width, 0); bool success = collectOffset(cast(gep), DL, width, VariableOffsets, Offset); + (void)success; assert(success); Value *start = ConstantInt::get(T, Offset); for (auto &pair : VariableOffsets) diff --git a/enzyme/Enzyme/CApi.h b/enzyme/Enzyme/CApi.h index d8e04724631a..fe3760457926 100644 --- a/enzyme/Enzyme/CApi.h +++ b/enzyme/Enzyme/CApi.h @@ -113,11 +113,14 @@ struct CFnTypeInfo { }; typedef enum { - DFT_OUT_DIFF = 0, // add differential to an output struct - DFT_DUP_ARG = 1, // duplicate the argument and store differential inside - DFT_CONSTANT = 2, // no differential + DFT_OUT_DIFF = 0, // add differential to an output struct. Only for scalar + // values in ReverseMode variants. + DFT_DUP_ARG = 1, // duplicate the argument and store differential inside. + // For references, pointers, or integers in ReverseMode + // variants. For all types in ForwardMode variants. + DFT_CONSTANT = 2, // no differential. Usable everywhere. DFT_DUP_NONEED = 3 // duplicate this argument and store differential inside, - // but don't need the forward + // but don't need the forward. Same as DUP_ARG otherwise. } CDIFFE_TYPE; typedef enum { BT_SCALAR = 0, BT_VECTOR = 1 } CBATCH_TYPE; diff --git a/enzyme/Enzyme/CMakeLists.txt b/enzyme/Enzyme/CMakeLists.txt index 1cd6e84c5be1..b27e4beb08cd 100644 --- a/enzyme/Enzyme/CMakeLists.txt +++ b/enzyme/Enzyme/CMakeLists.txt @@ -37,6 +37,10 @@ add_public_tablegen_target(BlasDeclarationsIncGen) add_public_tablegen_target(BlasTAIncGen) add_public_tablegen_target(BlasDiffUseIncGen) +set(LLVM_TARGET_DEFINITIONS Clang/include_utils.td) +enzyme_tablegen(IncludeUtils.inc -gen-header-strings) +add_public_tablegen_target(IncludeUtilsIncGen) + include_directories(${CMAKE_CURRENT_BINARY_DIR}) set(LLVM_LINK_COMPONENTS Demangle) @@ -74,6 +78,7 @@ if (${Clang_FOUND}) LLVM ) target_compile_definitions(ClangEnzyme-${LLVM_VERSION_MAJOR} PUBLIC ENZYME_RUNPASS) +add_dependencies(ClangEnzyme-${LLVM_VERSION_MAJOR} IncludeUtilsIncGen) endif() add_llvm_library( LLDEnzyme-${LLVM_VERSION_MAJOR} ${ENZYME_SRC} Clang/EnzymePassLoader.cpp @@ -107,6 +112,7 @@ if (${Clang_FOUND}) clang ) target_compile_definitions(ClangEnzyme-${LLVM_VERSION_MAJOR} PUBLIC ENZYME_RUNPASS) +add_dependencies(ClangEnzyme-${LLVM_VERSION_MAJOR} IncludeUtilsIncGen) endif() add_llvm_library( LLDEnzyme-${LLVM_VERSION_MAJOR} ${ENZYME_SRC} Clang/EnzymePassLoader.cpp diff --git a/enzyme/Enzyme/CacheUtility.cpp b/enzyme/Enzyme/CacheUtility.cpp index 68c6e46784cf..cba31f314000 100644 --- a/enzyme/Enzyme/CacheUtility.cpp +++ b/enzyme/Enzyme/CacheUtility.cpp @@ -200,6 +200,7 @@ std::pair FindCanonicalIV(Loop *L, Type *Ty) { } llvm::errs() << *Header << "\n"; assert(0 && "Could not find canonical IV"); + return std::pair(nullptr, nullptr); } // Attempt to rewrite all phinode's in the loop in terms of the @@ -1330,8 +1331,10 @@ void CacheUtility::storeInstructionInCache(LimitContext ctx, IRBuilder<> &BuilderM, Value *val, AllocaInst *cache, MDNode *TBAA) { assert(BuilderM.GetInsertBlock()->getParent() == newFunc); +#ifndef NDEBUG if (auto inst = dyn_cast(val)) assert(inst->getParent()->getParent() == newFunc); +#endif IRBuilder<> v(BuilderM.GetInsertBlock()); v.SetInsertPoint(BuilderM.GetInsertBlock(), BuilderM.GetInsertPoint()); v.setFastMathFlags(getFast()); diff --git a/enzyme/Enzyme/CallDerivatives.cpp b/enzyme/Enzyme/CallDerivatives.cpp index efd802db406b..6280dc9f1b44 100644 --- a/enzyme/Enzyme/CallDerivatives.cpp +++ b/enzyme/Enzyme/CallDerivatives.cpp @@ -32,10 +32,8 @@ extern "C" { void (*EnzymeShadowAllocRewrite)(LLVMValueRef, void *) = nullptr; } -template -void AdjointGenerator::handleMPI(llvm::CallInst &call, - llvm::Function *called, - llvm::StringRef funcName) { +void AdjointGenerator::handleMPI(llvm::CallInst &call, llvm::Function *called, + llvm::StringRef funcName) { using namespace llvm; assert(called); @@ -2214,8 +2212,7 @@ void AdjointGenerator::handleMPI(llvm::CallInst &call, llvm_unreachable("Unhandled MPI FUNCTION"); } -template -bool AdjointGenerator::handleKnownCallDerivatives( +bool AdjointGenerator::handleKnownCallDerivatives( CallInst &call, Function *called, StringRef funcName, const std::vector &overwritten_args, CallInst *const newCall) { bool subretused = false; @@ -2254,11 +2251,7 @@ bool AdjointGenerator::handleKnownCallDerivatives( } if (auto blas = extractBLAS(funcName)) { -#if LLVM_VERSION_MAJOR >= 16 - if (handleBLAS(call, called, blas.value(), overwritten_args)) -#else - if (handleBLAS(call, called, blas.getValue(), overwritten_args)) -#endif + if (handleBLAS(call, called, *blas, overwritten_args)) return true; } @@ -3235,7 +3228,7 @@ bool AdjointGenerator::handleKnownCallDerivatives( llvm_unreachable("Unknown allocation to upgrade"); Size = gutils->getNewFromOriginal(Size); - if (auto CI = dyn_cast(Size)) { + if (isa(Size)) { B.SetInsertPoint(gutils->inversionAllocs); } Type *elTy = Type::getInt8Ty(call.getContext()); @@ -3330,7 +3323,7 @@ bool AdjointGenerator::handleKnownCallDerivatives( // rematerialization is loop level. This is because one can have a // loop level cache, but a function level allocation (e.g. for stack // allocas). If we deleted it here, we would have no allocation! - auto AllocationLoop = gutils->OrigLI.getLoopFor(call.getParent()); + auto AllocationLoop = gutils->OrigLI->getLoopFor(call.getParent()); // An allocation within a loop, must definitionally be a loop level // allocation (but not always the other way around. if (AllocationLoop) @@ -3579,7 +3572,8 @@ bool AdjointGenerator::handleKnownCallDerivatives( ConstantInt::getFalse(call.getContext())); return true; } - if (funcName == "memset" || funcName == "memset_pattern16") { + if (funcName == "memset" || funcName == "memset_pattern16" || + funcName == "__memset_chk") { visitMemSetCommon(call); return true; } @@ -4053,6 +4047,7 @@ bool AdjointGenerator::handleKnownCallDerivatives( return true; } assert(!unnecessaryValues.count(rmat.first)); + (void)primalNeededInReverse; assert(primalNeededInReverse); } } @@ -4159,17 +4154,3 @@ bool AdjointGenerator::handleKnownCallDerivatives( return false; } - -template bool AdjointGenerator::handleKnownCallDerivatives( - CallInst &call, Function *called, StringRef funcName, - const std::vector &overwritten_args, CallInst *const newCall); -template bool -AdjointGenerator::handleKnownCallDerivatives( - CallInst &call, Function *called, StringRef funcName, - const std::vector &overwritten_args, CallInst *const newCall); - -template void -AdjointGenerator::handleMPI(CallInst &call, Function *called, - StringRef funcName); -template void AdjointGenerator::handleMPI( - CallInst &call, Function *called, StringRef funcName); diff --git a/enzyme/Enzyme/Clang/EnzymeClang.cpp b/enzyme/Enzyme/Clang/EnzymeClang.cpp index ed01f1bf5739..0072c958b517 100644 --- a/enzyme/Enzyme/Clang/EnzymeClang.cpp +++ b/enzyme/Enzyme/Clang/EnzymeClang.cpp @@ -25,16 +25,20 @@ #include "clang/AST/Attr.h" #include "clang/AST/DeclGroup.h" #include "clang/AST/RecursiveASTVisitor.h" +#include "clang/Basic/FileManager.h" #include "clang/Basic/MacroBuilder.h" #include "clang/Frontend/CompilerInstance.h" #include "clang/Frontend/FrontendAction.h" #include "clang/Frontend/FrontendPluginRegistry.h" +#include "clang/Lex/HeaderSearch.h" #include "clang/Lex/PreprocessorOptions.h" #include "clang/Sema/Sema.h" #include "clang/Sema/SemaDiagnostic.h" #include "../Utils.h" +#include "IncludeUtils.inc" + using namespace clang; #if LLVM_VERSION_MAJOR >= 18 @@ -134,6 +138,39 @@ class EnzymePlugin final : public clang::ASTConsumer { Builder.defineMacro("ENZYME_VERSION_PATCH", std::to_string(ENZYME_VERSION_PATCH)); CI.getPreprocessor().setPredefines(Predefines.str()); + + auto baseFS = &CI.getFileManager().getVirtualFileSystem(); + llvm::vfs::OverlayFileSystem *fuseFS( + new llvm::vfs::OverlayFileSystem(baseFS)); + IntrusiveRefCntPtr fs( + new llvm::vfs::InMemoryFileSystem()); + + struct tm y2k = {}; + + y2k.tm_hour = 0; + y2k.tm_min = 0; + y2k.tm_sec = 0; + y2k.tm_year = 100; + y2k.tm_mon = 0; + y2k.tm_mday = 1; + time_t timer = mktime(&y2k); + for (const auto &pair : include_headers) { + fs->addFile(StringRef(pair[0]), timer, + llvm::MemoryBuffer::getMemBuffer( + StringRef(pair[1]), StringRef(pair[0]), + /*RequiresNullTerminator*/ true)); + } + + fuseFS->pushOverlay(fs); + fuseFS->pushOverlay(baseFS); + CI.getFileManager().setVirtualFileSystem(fuseFS); + + auto DE = CI.getFileManager().getDirectoryRef("/enzymeroot"); + assert(DE); + auto DL = DirectoryLookup(*DE, SrcMgr::C_User, + /*isFramework=*/false); + CI.getPreprocessor().getHeaderSearchInfo().AddSearchPath(DL, + /*isAngled=*/true); } ~EnzymePlugin() {} void HandleTranslationUnit(ASTContext &context) override {} @@ -141,10 +178,10 @@ class EnzymePlugin final : public clang::ASTConsumer { using namespace clang; DeclGroupRef::iterator it; - Visitor v(CI); + // Visitor v(CI); // Forcibly require emission of all libdevice for (it = dg.begin(); it != dg.end(); ++it) { - v.TraverseDecl(*it); + // v.TraverseDecl(*it); if (auto FD = dyn_cast(*it)) { if (!FD->hasAttr()) continue; diff --git a/enzyme/Enzyme/Clang/include_utils.td b/enzyme/Enzyme/Clang/include_utils.td new file mode 100644 index 000000000000..cb7cdd839c20 --- /dev/null +++ b/enzyme/Enzyme/Clang/include_utils.td @@ -0,0 +1,583 @@ +class Headers { + string filename = filename_; + string contents = contents_; +} + +def : Headers<"/enzymeroot/enzyme/utils", [{ +#pragma once + +extern int enzyme_dup; +extern int enzyme_dupnoneed; +extern int enzyme_out; +extern int enzyme_const; + +extern int enzyme_const_return; +extern int enzyme_active_return; +extern int enzyme_dup_return; + +extern int enzyme_primal_return; +extern int enzyme_noret; + +template +Return __enzyme_autodiff(T...); + +template +Return __enzyme_fwddiff(T...); + +#include + +namespace enzyme { + + struct nodiff{}; + + template + struct ReverseMode { + + }; + using Reverse = ReverseMode; + using ReverseWithPrimal = ReverseMode; + + template < typename T > + struct Active{ + T value; + Active(T &&v) : value(v) {} + operator T&() { return value; } + }; + + template < typename T > + struct Duplicated{ + T value; + T shadow; + Duplicated(T &&v, T&& s) : value(v), shadow(s) {} + }; + + template < typename T > + struct Const{ + T value; + Const(T &&v) : value(v) {} + operator T&() { return value; } + }; + + template < typename T > + struct type_info { + static constexpr bool is_active = false; + using type = nodiff; + }; + + template < typename T > + struct type_info < Active >{ + static constexpr bool is_active = true; + using type = T; + }; + + template < typename ... T > + struct concatenated; + + template < typename ... S, typename T, typename ... rest > + struct concatenated < tuple < S ... >, T, rest ... > { + using type = typename concatenated< tuple< S ..., T>, rest ... >::type; + }; + + template < typename T > + struct concatenated < T > { + using type = T; + }; + + // Yikes! + // slightly cleaner in C++20, with std::remove_cvref + template < typename ... T > + struct autodiff_return; + + template < typename RetType, typename ... T > + struct autodiff_return, RetType, T...> + { + using type = tuple, + typename type_info< + typename remove_cvref< T >::type + >::type ... + >::type>; + }; + + template < typename RetType, typename ... T > + struct autodiff_return, RetType, T...> + { + using type = tuple< + typename type_info::type, + typename concatenated< tuple< >, + typename type_info< + typename remove_cvref< T >::type + >::type ... + >::type + >; + }; + + template < typename T > + __attribute__((always_inline)) + auto expand_args(const enzyme::Duplicated & arg) { + return enzyme::tuple{enzyme_dup, arg.value, arg.shadow}; + } + + template < typename T > + __attribute__((always_inline)) + auto expand_args(const enzyme::Active & arg) { + return enzyme::tuple{enzyme_out, arg.value}; + } + + template < typename T > + __attribute__((always_inline)) + auto expand_args(const enzyme::Const & arg) { + return enzyme::tuple{enzyme_const, arg.value}; + } + + template < typename T > + __attribute__((always_inline)) + auto primal_args(const enzyme::Duplicated & arg) { + return enzyme::tuple{arg.value}; + } + + template < typename T > + __attribute__((always_inline)) + auto primal_args(const enzyme::Active & arg) { + return enzyme::tuple{arg.value}; + } + + template < typename T > + __attribute__((always_inline)) + auto primal_args(const enzyme::Const & arg) { + return enzyme::tuple{arg.value}; + } + + namespace detail { + template + __attribute__((always_inline)) + constexpr decltype(auto) push_return_last(T &&t); + + template + __attribute__((always_inline)) + constexpr decltype(auto) push_return_last(tuple> &&t) { + return tuple>{get<0>(t)}; + } + + template + __attribute__((always_inline)) + constexpr decltype(auto) push_return_last(tuple> &&t) { + return tuple{get<1>(t), get<0>(t)}; + } + + template + __attribute__((always_inline)) + constexpr decltype(auto) rev_apply_impl(void* f, int* ret_attr, Tuple&& t, std::index_sequence) { + return push_return_last(__enzyme_autodiff(f, ret_attr, enzyme::get(impl::forward(t))...)); + } + + template + __attribute__((always_inline)) + constexpr decltype(auto) primal_apply_impl(function &&f, Tuple&& t, std::index_sequence) { + return f(enzyme::get(impl::forward(t))...); + } + + template < typename T > + struct default_ret_activity { + using type = Const; + }; + + template <> + struct default_ret_activity { + using type = Active; + }; + + template <> + struct default_ret_activity { + using type = Active; + }; + + template < typename T > + struct ret_global; + + template + struct ret_global> { + static constexpr int* value = &enzyme_const_return; + }; + + template + struct ret_global> { + static constexpr int* value = &enzyme_active_return; + }; + + template + struct ret_global> { + static constexpr int* value = &enzyme_dup_return; + }; + + template + struct ret_used; + + template + struct ret_used, RetAct> { + static constexpr int* value = &enzyme_primal_return; + }; + + template + struct ret_used, RetAct> { + static constexpr int* value = &enzyme_noret; + }; + + } // namespace detail + + + + template < typename return_type, typename function, typename ... enz_arg_types > + __attribute__((always_inline)) + auto primal_impl(function && f, enzyme::tuple< enz_arg_types ... > && arg_tup) { + using Tuple = enzyme::tuple< enz_arg_types ... >; + return detail::primal_apply_impl(f, impl::forward(arg_tup), std::make_index_sequence>{}); + } + + template < typename function, typename ... arg_types> + auto primal_call(function && f, arg_types && ... args) { + return primal_impl(impl::forward(f), enzyme::tuple_cat(primal_args(args)...)); + } + + template < typename return_type, typename function, typename RetActivity, typename ... enz_arg_types > + __attribute__((always_inline)) + auto rev_autodiff_impl(function && f, enzyme::tuple< enz_arg_types ... > && arg_tup) { + using Tuple = enzyme::tuple< enz_arg_types ... >; + return detail::rev_apply_impl((void*)f, detail::ret_global::value, impl::forward(arg_tup), std::make_index_sequence>{}); + } + + template < typename DiffMode, typename RetActivity, typename function, typename ... arg_types> + __attribute__((always_inline)) + auto autodiff(function && f, arg_types && ... args) { + using return_type = typename autodiff_return::type; + return rev_autodiff_impl(impl::forward(f), enzyme::tuple_cat(enzyme::tuple{detail::ret_used::value}, expand_args(args)...)); + } + + template < typename DiffMode, typename function, typename ... arg_types> + __attribute__((always_inline)) + auto autodiff(function && f, arg_types && ... args) { + using primal_return_type = decltype(primal_call(impl::forward(f), impl::forward(args)...)); + using RetActivity = typename detail::default_ret_activity::type; + using return_type = typename autodiff_return::type; + return rev_autodiff_impl(impl::forward(f), enzyme::tuple_cat(enzyme::tuple{detail::ret_used::value}, expand_args(args)...)); + } +} +}]>; + +def : Headers<"/enzymeroot/enzyme/type_traits", [{ +#pragma once + +#include + +namespace enzyme { + +// this is already in C++20, but we reimplement it here for older C++ versions +template < typename T > +struct remove_cvref { + using type = + typename std::remove_reference< + typename std::remove_cv< + T + >::type + >::type; +}; + +template < typename T > +using remove_cvref_t = typename remove_cvref::type; + +namespace impl { + template + __attribute__((always_inline)) + constexpr _Tp&& + forward(std::remove_reference_t<_Tp>& __t) noexcept + { return static_cast<_Tp&&>(__t); } + + /** + * @brief Forward an rvalue. + * @return The parameter cast to the specified type. + * + * This function is used to implement "perfect forwarding". + */ + template + __attribute__((always_inline)) + constexpr _Tp&& + forward(std::remove_reference_t<_Tp>&& __t) noexcept + { + static_assert(!std::is_lvalue_reference<_Tp>::value, + "enzyme::impl::forward must not be used to convert an rvalue to an lvalue"); + return static_cast<_Tp&&>(__t); + } + +} + +} +}]>; + +def : Headers<"/enzymeroot/enzyme/tuple", [{ +#pragma once + +///////////// +// tuple.h // +///////////// + +// why reinvent the wheel and implement a tuple class? +// - ensure data is laid out in the same order the types are specified +// see: https://github.com/EnzymeAD/Enzyme/issues/1191#issuecomment-1556239213 +// - CUDA compatibility: std::tuple has some compatibility issues when used +// in a __device__ context (this may get better in c++20 with the improved +// constexpr support for std::tuple). Owning the implementation lets +// us add __host__ __device__ annotations to any part of it + +#include // for std::integer_sequence + +#include + +#define _NOEXCEPT noexcept +namespace enzyme { + +template +struct Index {}; + +template +struct value_at_position { + __attribute__((always_inline)) + T & operator[](Index) { return value; } + + __attribute__((always_inline)) + constexpr const T & operator[](Index) const { return value; } + T value; +}; + +template +struct tuple_base; + +template +struct tuple_base, T...> + : public value_at_position... { + using value_at_position::operator[]...; +}; + +template +struct tuple : public tuple_base, T...> {}; + +template +__attribute__((always_inline)) +tuple(T ...) -> tuple; + +template < int i, typename Tuple > +__attribute__((always_inline)) +decltype(auto) get(Tuple && tup) { + constexpr bool is_lvalue = std::is_lvalue_reference_v; + constexpr bool is_const = std::is_const_v>; + using T = remove_cvref_t< decltype(tup[Index{ } ]) >; + if constexpr ( is_lvalue && is_const) { return static_cast(tup[Index{} ]); } + if constexpr ( is_lvalue && !is_const) { return static_cast(tup[Index{} ]); } + if constexpr (!is_lvalue && is_const) { return static_cast(tup[Index{} ]); } + if constexpr (!is_lvalue && !is_const) { return static_cast(tup[Index{} ]); } +} + +template < int i, typename ... T> +__attribute__((always_inline)) +decltype(auto) get(const tuple< T ... > & tup) { + return tup[Index{} ]; +} + +template +struct tuple_size; + +template +struct tuple_size> : std::integral_constant {}; + +template +static constexpr size_t tuple_size_v = tuple_size::value; + +template +__attribute__((always_inline)) +constexpr auto forward_as_tuple(T&&... args) noexcept { + return tuple{impl::forward(args)...}; +} + +namespace impl { + +template +struct make_tuple_from_fwd_tuple; + +template +struct make_tuple_from_fwd_tuple> { + template + __attribute__((always_inline)) + static constexpr auto f(FWD_TUPLE&& fwd) { + return tuple{get(impl::forward(fwd))...}; + } +}; + +template +struct concat_with_fwd_tuple; + +template < typename Tuple > +using iseq = std::make_index_sequence > >; + +template +struct concat_with_fwd_tuple, std::index_sequence> { + template + __attribute__((always_inline)) + static constexpr auto f(FWD_TUPLE&& fwd, TUPLE&& t) { + return forward_as_tuple(get(impl::forward(fwd))..., get(impl::forward(t))...); + } +}; + +template +__attribute__((always_inline)) +static constexpr auto tuple_cat(Tuple&& ret) { + return make_tuple_from_fwd_tuple< iseq< Tuple > >::f(impl::forward< Tuple >(ret)); +} + +template +__attribute__((always_inline)) +static constexpr auto tuple_cat(FWD_TUPLE&& fwd, first&& t, rest&&... ts) { + return tuple_cat(concat_with_fwd_tuple< iseq, iseq >::f(impl::forward(fwd), impl::forward(t)), impl::forward(ts)...); +} + +} // namespace impl + +template +__attribute__((always_inline)) +constexpr auto tuple_cat(Tuples&&... tuples) { + return impl::tuple_cat(impl::forward(tuples)...); +} + +} // namespace enzyme +#undef _NOEXCEPT +}]>; + +def : Headers<"/enzymeroot/enzyme/enzyme", [{ +#ifdef __cplusplus +#include "enzyme/utils" +#else +#warning "Enzyme wrapper templates only available in C++" +#endif +}]>; + +def : Headers<"/enzymeroot/enzyme/mpfr", [{ +//===- EnzymeMPFR.h - MPFR wrappers ---------------------------------------===// +// +// Enzyme Project +// +// Part of the Enzyme Project, under the Apache License v2.0 with LLVM +// Exceptions. See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +// If using this code in an academic setting, please cite the following: +// @incollection{enzymeNeurips, +// title = {Instead of Rewriting Foreign Code for Machine Learning, +// Automatically Synthesize Fast Gradients}, +// author = {Moses, William S. and Churavy, Valentin}, +// booktitle = {Advances in Neural Information Processing Systems 33}, +// year = {2020}, +// note = {To appear in}, +// } +// +//===----------------------------------------------------------------------===// +// +// This file contains easy to use wrappers around MPFR functions. +// +//===----------------------------------------------------------------------===// +#ifndef __ENZYME_RUNTIME_ENZYME_MPFR__ +#define __ENZYME_RUNTIME_ENZYME_MPFR__ + +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +// TODO s +// +// (for MPFR ver. 2.1) +// +// We need to set the range of the allowed exponent using `mpfr_set_emin` and +// `mpfr_set_emax`. (This means we can also play with whether the range is +// centered around 0 (1?) or somewhere else) +// +// (also these need to be mutex'ed as the exponent change is global in mpfr and +// not float-specific) ... (mpfr seems to have thread safe mode - check if it is +// enabled or if it is enabled by default) +// +// For that we need to do this check: +// If the user changes the exponent range, it is her/his responsibility to +// check that all current floating-point variables are in the new allowed +// range (for example using mpfr_check_range), otherwise the subsequent +// behavior will be undefined, in the sense of the ISO C standard. +// +// MPFR docs state the following: +// Note: Overflow handling is still experimental and currently implemented +// partially. If an overflow occurs internally at the wrong place, anything +// can happen (crash, wrong results, etc). +// +// Which we would like to avoid somehow. +// +// MPFR also has this limitation that we need to address for accurate +// simulation: +// [...] subnormal numbers are not implemented. +// + +#define __ENZYME_MPFR_SINGOP(OP_TYPE, LLVM_OP_NAME, MPFR_FUNC_NAME, FROM_TYPE, \ + RET, MPFR_GET, ARG1, MPFR_SET_ARG1, \ + ROUNDING_MODE) \ + __attribute__((weak)) \ + RET __enzyme_mpfr_##FROM_TYPE##_##OP_TYPE##_##LLVM_OP_NAME( \ + ARG1 a, int64_t exponent, int64_t significand) { \ + mpfr_t ma, mc; \ + mpfr_init2(ma, significand); \ + mpfr_init2(mc, significand); \ + mpfr_set_##MPFR_SET_ARG1(ma, a, ROUNDING_MODE); \ + mpfr_##MPFR_FUNC_NAME(mc, ma, ROUNDING_MODE); \ + RET c = mpfr_get_##MPFR_GET(mc, ROUNDING_MODE); \ + mpfr_clear(ma); \ + mpfr_clear(mc); \ + return c; \ + } + +#define __ENZYME_MPFR_BINOP(OP_TYPE, LLVM_OP_NAME, MPFR_FUNC_NAME, FROM_TYPE, \ + RET, MPFR_GET, ARG1, MPFR_SET_ARG1, ARG2, \ + MPFR_SET_ARG2, ROUNDING_MODE) \ + __attribute__((weak)) \ + RET __enzyme_mpfr_##FROM_TYPE##_##OP_TYPE##_##LLVM_OP_NAME( \ + ARG1 a, ARG2 b, int64_t exponent, int64_t significand) { \ + mpfr_t ma, mb, mc; \ + mpfr_init2(ma, significand); \ + mpfr_init2(mb, significand); \ + mpfr_init2(mc, significand); \ + mpfr_set_##MPFR_SET_ARG1(ma, a, ROUNDING_MODE); \ + mpfr_set_##MPFR_SET_ARG1(mb, b, ROUNDING_MODE); \ + mpfr_##MPFR_FUNC_NAME(mc, ma, mb, ROUNDING_MODE); \ + RET c = mpfr_get_##MPFR_GET(mc, ROUNDING_MODE); \ + mpfr_clear(ma); \ + mpfr_clear(mb); \ + mpfr_clear(mc); \ + return c; \ + } + +#define __ENZYME_MPFR_DEFAULT_ROUNDING_MODE GMP_RNDN +#define __ENZYME_MPFR_DOUBLE_BINOP(LLVM_OP_NAME, MPFR_FUNC_NAME, \ + ROUNDING_MODE) \ + __ENZYME_MPFR_BINOP(binop, LLVM_OP_NAME, MPFR_FUNC_NAME, 64_52, double, d, \ + double, d, double, d, ROUNDING_MODE) +#define __ENZYME_MPFR_DOUBLE_BINOP_DEFAULT_ROUNDING(LLVM_OP_NAME, \ + MPFR_FUNC_NAME) \ + __ENZYME_MPFR_DOUBLE_BINOP(LLVM_OP_NAME, MPFR_FUNC_NAME, \ + __ENZYME_MPFR_DEFAULT_ROUNDING_MODE) + +__ENZYME_MPFR_DOUBLE_BINOP_DEFAULT_ROUNDING(fmul, mul) +__ENZYME_MPFR_DOUBLE_BINOP_DEFAULT_ROUNDING(fadd, add) +__ENZYME_MPFR_DOUBLE_BINOP_DEFAULT_ROUNDING(fdiv, div) + +__ENZYME_MPFR_SINGOP(func, sqrt, sqrt, 64_52, double, d, double, d, + __ENZYME_MPFR_DEFAULT_ROUNDING_MODE) + +#ifdef __cplusplus +} +#endif + +#endif // #ifndef __ENZYME_RUNTIME_ENZYME_MPFR__ +}]>; diff --git a/enzyme/Enzyme/DiffeGradientUtils.cpp b/enzyme/Enzyme/DiffeGradientUtils.cpp index 68a7e02d048d..552d2894d759 100644 --- a/enzyme/Enzyme/DiffeGradientUtils.cpp +++ b/enzyme/Enzyme/DiffeGradientUtils.cpp @@ -164,10 +164,12 @@ DiffeGradientUtils *DiffeGradientUtils::CreateFromClone( AllocaInst *DiffeGradientUtils::getDifferential(Value *val) { assert(val); +#ifndef NDEBUG if (auto arg = dyn_cast(val)) assert(arg->getParent() == oldFunc); if (auto inst = dyn_cast(val)) assert(inst->getParent()->getParent() == oldFunc); +#endif assert(inversionAllocs); Type *type = getShadowType(val->getType()); @@ -195,10 +197,12 @@ AllocaInst *DiffeGradientUtils::getDifferential(Value *val) { } Value *DiffeGradientUtils::diffe(Value *val, IRBuilder<> &BuilderM) { +#ifndef NDEBUG if (auto arg = dyn_cast(val)) assert(arg->getParent() == oldFunc); if (auto inst = dyn_cast(val)) assert(inst->getParent()->getParent() == oldFunc); +#endif if (isConstantValue(val)) { llvm::errs() << *newFunc << "\n"; @@ -336,6 +340,7 @@ DiffeGradientUtils::addToDiffe(Value *val, Value *dif, IRBuilder<> &BuilderM, llvm::errs() << "} start=" << start << " size=" << size << " storeSize=" << storeSize << " val=" << *val << "\n"; assert(0 && "unhandled accumulate with partial sizes"); + return {}; } SmallVector @@ -345,10 +350,12 @@ DiffeGradientUtils::addToDiffe(Value *val, Value *dif, IRBuilder<> &BuilderM, assert(mode == DerivativeMode::ReverseModeGradient || mode == DerivativeMode::ReverseModeCombined); +#ifndef NDEBUG if (auto arg = dyn_cast(val)) assert(arg->getParent() == oldFunc); if (auto inst = dyn_cast(val)) assert(inst->getParent()->getParent() == oldFunc); +#endif SmallVector addedSelects; @@ -659,6 +666,7 @@ DiffeGradientUtils::addToDiffe(Value *val, Value *dif, IRBuilder<> &BuilderM, void DiffeGradientUtils::setDiffe(Value *val, Value *toset, IRBuilder<> &BuilderM) { +#ifndef NDEBUG if (auto arg = dyn_cast(val)) assert(arg->getParent() == oldFunc); if (auto inst = dyn_cast(val)) @@ -668,6 +676,7 @@ void DiffeGradientUtils::setDiffe(Value *val, Value *toset, llvm::errs() << *val << "\n"; } assert(!isConstantValue(val)); +#endif toset = SanitizeDerivatives(val, toset, BuilderM); if (mode == DerivativeMode::ForwardMode || mode == DerivativeMode::ForwardModeSplit) { @@ -928,11 +937,20 @@ void DiffeGradientUtils::addToInvertedPtrDiffe(Instruction *orig, applyChainRule(PointerType::get(addingType, 1), BuilderM, rule, ptr); } - assert(!mask); if (mask) { - llvm::errs() << "unhandled masked atomic fadd on llvm version " << *ptr - << " " << *dif << " mask: " << *mask << "\n"; - llvm_unreachable("unhandled masked atomic fadd"); + std::string s; + llvm::raw_string_ostream ss(s); + ss << "Unimplemented masked atomic fadd for ptr:" << *ptr + << " dif:" << *dif << " mask: " << *mask << " orig: " << *orig << "\n"; + if (CustomErrorHandler) { + CustomErrorHandler(ss.str().c_str(), wrap(orig), + ErrorType::NoDerivative, this, nullptr, + wrap(&BuilderM)); + return; + } else { + EmitFailure("NoDerivative", orig->getDebugLoc(), orig, ss.str()); + return; + } } /* @@ -966,14 +984,8 @@ void DiffeGradientUtils::addToInvertedPtrDiffe(Instruction *orig, if (alignv) { if (start != 0) { // todo make better alignment calculation -#if LLVM_VERSION_MAJOR >= 16 - assert(alignv.value().value() != 0); - if (start % alignv.value().value() != 0) -#else - assert(alignv.getValue().value() != 0); - if (start % alignv.getValue().value() != 0) -#endif - { + assert((*alignv).value() != 0); + if (start % (*alignv).value() != 0) { alignv = Align(1); } } @@ -1007,13 +1019,8 @@ void DiffeGradientUtils::addToInvertedPtrDiffe(Instruction *orig, if (alignv) { if (start != 0) { // todo make better alignment calculation -#if LLVM_VERSION_MAJOR >= 16 - assert(alignv.value().value() != 0); - if (start % alignv.value().value() != 0) { -#else - assert(alignv.getValue().value() != 0); - if (start % alignv.getValue().value() != 0) { -#endif + assert((*alignv).value() != 0); + if (start % (*alignv).value() != 0) { alignv = Align(1); } } @@ -1093,11 +1100,7 @@ void DiffeGradientUtils::addToInvertedPtrDiffe(Instruction *orig, st->setDebugLoc(getNewFromOriginal(orig->getDebugLoc())); if (align) { -#if LLVM_VERSION_MAJOR >= 16 - auto alignv = align ? align.value().value() : 0; -#else - auto alignv = align ? align.getValue().value() : 0; -#endif + auto alignv = align ? (*align).value() : 0; if (alignv != 0) { if (start != 0) { // todo make better alignment calculation diff --git a/enzyme/Enzyme/DifferentialUseAnalysis.cpp b/enzyme/Enzyme/DifferentialUseAnalysis.cpp index 4a2cff5d79a9..b0b8c48ab5c9 100644 --- a/enzyme/Enzyme/DifferentialUseAnalysis.cpp +++ b/enzyme/Enzyme/DifferentialUseAnalysis.cpp @@ -54,9 +54,11 @@ bool DifferentialUseAnalysis::is_use_directly_needed_in_reverse( const SmallPtrSetImpl &oldUnreachable, QueryType qtype, bool *recursiveUse) { TypeResults const &TR = gutils->TR; +#ifndef NDEBUG if (auto ainst = dyn_cast(val)) { assert(ainst->getParent()->getParent() == gutils->oldFunc); } +#endif bool shadow = qtype == QueryType::Shadow || qtype == QueryType::ShadowByConstPrimal; @@ -79,8 +81,7 @@ bool DifferentialUseAnalysis::is_use_directly_needed_in_reverse( if (!user) { if (EnzymePrintDiffUse) - llvm::errs() << " Need: of " << *val << " in reverse as unknown user " - << *user << "\n"; + llvm::errs() << " Need: of " << *val << " in reverse as nullptr user\n"; return true; } @@ -759,33 +760,26 @@ int DifferentialUseAnalysis::cmpLoopNest(Loop *prev, Loop *next) { return -1; } -void DifferentialUseAnalysis::minCut( - const DataLayout &DL, LoopInfo &OrigLI, - const SetVector &Recomputes, - const SetVector &Intermediates, SetVector &Required, - SetVector &MinReq, - const ValueMap - &rematerializableAllocations, - llvm::TargetLibraryInfo &TLI) { +void DifferentialUseAnalysis::minCut(const DataLayout &DL, LoopInfo &OrigLI, + const SetVector &Recomputes, + const SetVector &Intermediates, + SetVector &Required, + SetVector &MinReq, + const GradientUtils *gutils, + llvm::TargetLibraryInfo &TLI) { Graph G; for (auto V : Intermediates) { G[Node(V, false)].insert(Node(V, true)); - for (auto U : V->users()) { - if (auto I = dyn_cast(U)) { - for (auto pair : rematerializableAllocations) { - if (Intermediates.count(pair.first) && pair.second.stores.count(I)) { - if (V != pair.first) - G[Node(V, true)].insert(Node(pair.first, false)); + forEachDifferentialUser( + [&](Value *U) { + if (Intermediates.count(U)) { + if (V != U) + G[Node(V, true)].insert(Node(U, false)); } - } - } - if (Intermediates.count(U)) { - if (V != U) - G[Node(V, true)].insert(Node(U, false)); - } - } + }, + gutils, V); } - for (auto pair : rematerializableAllocations) { + for (auto pair : gutils->rematerializableAllocations) { if (Intermediates.count(pair.first)) { for (LoadInst *L : pair.second.loads) { if (Intermediates.count(L)) { @@ -801,12 +795,14 @@ void DifferentialUseAnalysis::minCut( } } } +#ifndef NDEBUG for (auto R : Required) { assert(Intermediates.count(R)); } for (auto R : Recomputes) { assert(Intermediates.count(R)); } +#endif Graph Orig = G; @@ -856,7 +852,9 @@ void DifferentialUseAnalysis::minCut( assert(pair.first.outgoing == 0 && N.outgoing == 1); assert(pair.first.V == N.V); MinReq.insert(N.V); - todo.push_back(N.V); + if (Orig.find(Node(N.V, true)) != Orig.end()) { + todo.push_back(N.V); + } } } } @@ -867,20 +865,20 @@ void DifferentialUseAnalysis::minCut( auto V = todo.front(); todo.pop_front(); auto found = Orig.find(Node(V, true)); - if (found->second.size() == 1 && !Required.count(V)) { + assert(found != Orig.end()); + const auto &mp = found->second; + if (mp.size() == 1 && !Required.count(V)) { bool potentiallyRecursive = - isa((*found->second.begin()).V) && - OrigLI.isLoopHeader( - cast((*found->second.begin()).V)->getParent()); + isa((*mp.begin()).V) && + OrigLI.isLoopHeader(cast((*mp.begin()).V)->getParent()); int moreOuterLoop = cmpLoopNest( OrigLI.getLoopFor(cast(V)->getParent()), - OrigLI.getLoopFor( - cast(((*found->second.begin()).V))->getParent())); + OrigLI.getLoopFor(cast(((*mp.begin()).V))->getParent())); if (potentiallyRecursive) continue; if (moreOuterLoop == -1) continue; - if (auto ASC = dyn_cast((*found->second.begin()).V)) { + if (auto ASC = dyn_cast((*mp.begin()).V)) { if (ASC->getDestAddressSpace() == 11 || ASC->getDestAddressSpace() == 13) continue; @@ -888,8 +886,8 @@ void DifferentialUseAnalysis::minCut( continue; } // If an allocation call, we cannot cache any "capturing" users - if (isAllocationCall(V, TLI)) { - auto next = (*found->second.begin()).V; + if (isAllocationCall(V, TLI) || isa(V)) { + auto next = (*mp.begin()).V; bool noncapture = false; if (isa(next)) { noncapture = true; @@ -925,10 +923,12 @@ void DifferentialUseAnalysis::minCut( if (moreOuterLoop == 1 || (moreOuterLoop == 0 && DL.getTypeSizeInBits(V->getType()) >= - DL.getTypeSizeInBits((*found->second.begin()).V->getType()))) { + DL.getTypeSizeInBits((*mp.begin()).V->getType()))) { MinReq.remove(V); - MinReq.insert((*found->second.begin()).V); - todo.push_back((*found->second.begin()).V); + auto nnode = (*mp.begin()).V; + MinReq.insert(nnode); + if (Orig.find(Node(nnode, true)) != Orig.end()) + todo.push_back(nnode); } } } diff --git a/enzyme/Enzyme/DifferentialUseAnalysis.h b/enzyme/Enzyme/DifferentialUseAnalysis.h index d9d689eeee64..565e18724869 100644 --- a/enzyme/Enzyme/DifferentialUseAnalysis.h +++ b/enzyme/Enzyme/DifferentialUseAnalysis.h @@ -166,6 +166,60 @@ inline bool is_value_needed_in_reverse( } } + if (!TR.anyFloat(const_cast(inst))) + if (auto IVI = dyn_cast(user)) { + bool inserted = false; + if (auto II = dyn_cast(IVI)) + inserted = II->getInsertedValueOperand() == inst || + II->getAggregateOperand() == inst; + if (auto II = dyn_cast(IVI)) + inserted = II->getAggregateOperand() == inst; + if (auto II = dyn_cast(IVI)) + inserted = II->getOperand(1) == inst || II->getOperand(0) == inst; + if (auto II = dyn_cast(IVI)) + inserted = II->getOperand(0) == inst; + if (inserted) { + SmallVector todo; + todo.push_back(IVI); + while (todo.size()) { + auto cur = todo.pop_back_val(); + for (auto u : cur->users()) { + if (auto IVI2 = dyn_cast(u)) { + todo.push_back(IVI2); + continue; + } + if (auto IVI2 = dyn_cast(u)) { + todo.push_back(IVI2); + continue; + } + if (auto IVI2 = dyn_cast(u)) { + todo.push_back(IVI2); + continue; + } + if (auto IVI2 = dyn_cast(u)) { + todo.push_back(IVI2); + continue; + } + + bool partial = false; + if (!gutils->isConstantValue(const_cast(cur))) { + partial = is_value_needed_in_reverse( + gutils, user, mode, seen, oldUnreachable); + } + if (partial) { + + if (EnzymePrintDiffUse) + llvm::errs() + << " Need (partial) direct " << to_string(VT) << " of " + << *inst << " in reverse from insertelem " << *user + << " via " << *cur << " in " << *u << "\n"; + return seen[idx] = true; + } + } + } + } + } + if (VT != QueryType::Primal) continue; } @@ -332,36 +386,14 @@ inline bool is_value_needed_in_reverse( primalUsedInShadowPointer = false; } } - if (auto IVI = dyn_cast(user)) { - bool valueIsIndex = false; - for (unsigned i = 2; i < IVI->getNumOperands(); ++i) { - if (IVI->getOperand(i) == inst) { - if (inst == IVI->getInsertedValueOperand() && - TR.query( - const_cast(IVI->getInsertedValueOperand()))[{-1}] - .isFloat()) { - continue; - } - valueIsIndex = true; - } - } - primalUsedInShadowPointer = valueIsIndex; - } - if (auto EVI = dyn_cast(user)) { - bool valueIsIndex = false; - for (unsigned i = 1; i < EVI->getNumOperands(); ++i) { - if (EVI->getOperand(i) == inst) { - valueIsIndex = true; - } - } - primalUsedInShadowPointer = valueIsIndex; - } + // No need for insert/extractvalue since indices are unsigned + // not llvm runtime values + if (isa(user) || isa(user)) + primalUsedInShadowPointer = false; if (primalUsedInShadowPointer) if (!user->getType()->isVoidTy() && - TR.query(const_cast(user)) - .Inner0() - .isPossiblePointer()) { + TR.anyPointer(const_cast(user))) { if (is_value_needed_in_reverse( gutils, user, mode, seen, oldUnreachable)) { if (EnzymePrintDiffUse) @@ -433,11 +465,66 @@ void minCut(const llvm::DataLayout &DL, llvm::LoopInfo &OrigLI, const llvm::SetVector &Recomputes, const llvm::SetVector &Intermediates, llvm::SetVector &Required, - llvm::SetVector &MinReq, - const llvm::ValueMap - &rematerializableAllocations, + llvm::SetVector &MinReq, const GradientUtils *gutils, llvm::TargetLibraryInfo &TLI); +__attribute__((always_inline)) static inline void +forEachDirectInsertUser(llvm::function_ref f, + const GradientUtils *gutils, llvm::Instruction *IVI, + llvm::Value *val, bool useCheck) { + using namespace llvm; + if (!gutils->isConstantValue(IVI)) + return; + bool inserted = false; + if (auto II = dyn_cast(IVI)) + inserted = II->getInsertedValueOperand() == val || + II->getAggregateOperand() == val; + if (auto II = dyn_cast(IVI)) + inserted = II->getAggregateOperand() == val; + if (auto II = dyn_cast(IVI)) + inserted = II->getOperand(1) == val || II->getOperand(0) == val; + if (auto II = dyn_cast(IVI)) + inserted = II->getOperand(0) == val; + if (inserted) { + SmallVector todo; + todo.push_back(IVI); + while (todo.size()) { + auto cur = todo.pop_back_val(); + for (auto u : cur->users()) { + if (isa(u) || isa(u) || + isa(u) || isa(u)) { + auto I2 = cast(u); + bool subCheck = useCheck; + if (!subCheck) { + subCheck = is_value_needed_in_reverse( + gutils, I2, gutils->mode, gutils->notForAnalysis); + } + if (subCheck) + f(I2); + todo.push_back(I2); + continue; + } + } + } + } +} + +__attribute__((always_inline)) static inline void +forEachDifferentialUser(llvm::function_ref f, + const GradientUtils *gutils, llvm::Value *V, + bool useCheck = false) { + for (auto V2 : V->users()) { + if (auto Inst = llvm::dyn_cast(V2)) { + for (const auto &pair : gutils->rematerializableAllocations) { + if (pair.second.stores.count(Inst)) { + f(llvm::cast(pair.first)); + } + } + f(Inst); + forEachDirectInsertUser(f, gutils, Inst, V, useCheck); + } + } +} }; // namespace DifferentialUseAnalysis #endif diff --git a/enzyme/Enzyme/Enzyme.cpp b/enzyme/Enzyme/Enzyme.cpp index e402c1656248..4ab84bd6e67e 100644 --- a/enzyme/Enzyme/Enzyme.cpp +++ b/enzyme/Enzyme/Enzyme.cpp @@ -37,10 +37,9 @@ #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/MapVector.h" +#include #if LLVM_VERSION_MAJOR <= 16 #include "llvm/ADT/Optional.h" -#else -#include #endif #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallSet.h" @@ -59,6 +58,7 @@ #include "llvm/Analysis/ScalarEvolution.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" #include "llvm/Transforms/Scalar.h" #include "llvm/Analysis/BasicAliasAnalysis.h" @@ -113,6 +113,12 @@ llvm::cl::opt EnzymeAttributor("enzyme-attributor", cl::init(false), llvm::cl::opt EnzymeOMPOpt("enzyme-omp-opt", cl::init(false), cl::Hidden, cl::desc("Whether to enable openmp opt")); +llvm::cl::opt EnzymeTruncateAll( + "enzyme-truncate-all", cl::init(""), cl::Hidden, + cl::desc( + "Truncate all floating point operations. " + "E.g. \"64to32\" or \"64to-\".")); + #if LLVM_VERSION_MAJOR >= 14 #define addAttribute addAttributeAtIndex #define getAttribute getAttributeAtIndex @@ -430,6 +436,9 @@ std::optional getMetadataName(llvm::Value *res) Optional getMetadataName(llvm::Value *res) #endif { + if (auto S = simplifyLoad(res)) + return getMetadataName(S); + if (auto av = dyn_cast(res)) { return cast(av->getMetadata())->getString(); } else if ((isa(res) || isa(res)) && @@ -458,12 +467,11 @@ Optional getMetadataName(llvm::Value *res) return gv->getName(); } else if (auto gv = dyn_cast(res)) { return gv->getName(); - } else { - if (isa(res)) { - return recursePhiReads(cast(res)); - } - return {}; + } else if (isa(res)) { + return recursePhiReads(cast(res)); } + + return {}; } static Value *adaptReturnedVector(Value *ret, Value *diffret, @@ -1314,25 +1322,58 @@ class EnzymeBase { return type_args; } - bool HandleTruncateFunc(CallInst *CI) { + static FloatRepresentation getDefaultFloatRepr(unsigned width) { + switch (width) { + case 16: + return FloatRepresentation(5, 10); + case 32: + return FloatRepresentation(8, 23); + case 64: + return FloatRepresentation(11, 52); + default: + llvm_unreachable("Invalid float width"); + } + }; + + bool HandleTruncateFunc(CallInst *CI, TruncateMode mode) { IRBuilder<> Builder(CI); Function *F = parseFunctionParameter(CI); if (!F) return false; - if (CI->arg_size() != 3) { + unsigned ArgSize = CI->arg_size(); + if (ArgSize != 4 && ArgSize != 3) { EmitFailure("TooManyArgs", CI->getDebugLoc(), CI, "Had incorrect number of args to __enzyme_truncate_func", *CI, - " - expected 3"); + " - expected 3 or 4"); return false; } - auto Cfrom = cast(CI->getArgOperand(1)); - assert(Cfrom); - auto Cto = cast(CI->getArgOperand(2)); - assert(Cto); + FloatTruncation truncation = [&]() -> FloatTruncation { + if (ArgSize == 3) { + auto Cfrom = cast(CI->getArgOperand(1)); + assert(Cfrom); + auto Cto = cast(CI->getArgOperand(2)); + assert(Cto); + return FloatTruncation( + getDefaultFloatRepr((unsigned)Cfrom->getValue().getZExtValue()), + getDefaultFloatRepr((unsigned)Cto->getValue().getZExtValue())); + } else if (ArgSize == 4) { + auto Cfrom = cast(CI->getArgOperand(1)); + assert(Cfrom); + auto Cto_exponent = cast(CI->getArgOperand(2)); + assert(Cto_exponent); + auto Cto_significand = cast(CI->getArgOperand(3)); + assert(Cto_significand); + return FloatTruncation( + getDefaultFloatRepr((unsigned)Cfrom->getValue().getZExtValue()), + FloatRepresentation( + (unsigned)Cto_exponent->getValue().getZExtValue(), + (unsigned)Cto_significand->getValue().getZExtValue())); + } + llvm_unreachable("??"); + }(); + RequestContext context(CI, &Builder); - llvm::Value *res = Logic.CreateTruncateFunc( - context, F, (unsigned)Cfrom->getValue().getZExtValue(), - (unsigned)Cto->getValue().getZExtValue()); + llvm::Value *res = Logic.CreateTruncateFunc(context, F, truncation, mode); if (!res) return false; res = Builder.CreatePointerCast(res, CI->getType()); @@ -1356,8 +1397,10 @@ class EnzymeBase { auto Addr = CI->getArgOperand(0); RequestContext context(CI, &Builder); bool res = Logic.CreateTruncateValue( - context, Addr, (unsigned)Cfrom->getValue().getZExtValue(), - (unsigned)Cto->getValue().getZExtValue(), isTruncate); + context, Addr, + getDefaultFloatRepr((unsigned)Cfrom->getValue().getZExtValue()), + getDefaultFloatRepr((unsigned)Cto->getValue().getZExtValue()), + isTruncate); if (!res) return false; return true; @@ -1887,15 +1930,9 @@ class EnzymeBase { #endif } -#if LLVM_VERSION_MAJOR >= 16 return HandleAutoDiff(CI, CI->getCallingConv(), ret, retElemType, args, - byVal, constants, fn, mode, options.value(), sizeOnly, + byVal, constants, fn, mode, *options, sizeOnly, calls); -#else - return HandleAutoDiff(CI, CI->getCallingConv(), ret, retElemType, args, - byVal, constants, fn, mode, options.getValue(), - sizeOnly, calls); -#endif } bool HandleProbProg(CallInst *CI, ProbProgMode mode, @@ -2025,23 +2062,86 @@ class EnzymeBase { #endif } -#if LLVM_VERSION_MAJOR >= 16 - bool status = - HandleAutoDiff(CI, CI->getCallingConv(), ret, retElemType, dargs, byVal, - constants, newFunc, DerivativeMode::ReverseModeCombined, - opt.value(), false, calls); -#else - bool status = - HandleAutoDiff(CI, CI->getCallingConv(), ret, retElemType, dargs, byVal, - constants, newFunc, DerivativeMode::ReverseModeCombined, - opt.getValue(), false, calls); -#endif + bool status = HandleAutoDiff( + CI, CI->getCallingConv(), ret, retElemType, dargs, byVal, constants, + newFunc, DerivativeMode::ReverseModeCombined, *opt, false, calls); delete interface; return status; } + bool handleFullModuleTrunc(Function &F) { + typedef std::vector TruncationsTy; + static TruncationsTy FullModuleTruncs = []() -> TruncationsTy { + StringRef ConfigStr(EnzymeTruncateAll); + auto Invalid = [=]() { + // TODO emit better diagnostic + llvm::report_fatal_error("error: invalid format for truncation config"); + }; + + // "64" or "11-52" + auto parseFloatRepr = [&]() -> std::optional { + unsigned Tmp = 0; + if (ConfigStr.consumeInteger(10, Tmp)) + return {}; + if (ConfigStr.consume_front("-")) { + unsigned Tmp2 = 0; + if (ConfigStr.consumeInteger(10, Tmp2)) + Invalid(); + return FloatRepresentation(Tmp, Tmp2); + } + return getDefaultFloatRepr(Tmp); + }; + + // Parse "64to32;32to16;5-10to4-9" + TruncationsTy Tmp; + while (true) { + auto From = parseFloatRepr(); + if (!From && !ConfigStr.empty()) + Invalid(); + if (!From) + break; + if (!ConfigStr.consume_front("to")) + Invalid(); + auto To = parseFloatRepr(); + if (!To) + Invalid(); + Tmp.push_back({*From, *To}); + ConfigStr.consume_front(";"); + } + return Tmp; + }(); + + if (FullModuleTruncs.empty()) + return false; + + // TODO sort truncations (64to32, then 32to16 will make everything 16) + for (auto Truncation : FullModuleTruncs) { + IRBuilder<> Builder(F.getContext()); + RequestContext context(&*F.getEntryBlock().begin(), &Builder); + Function *TruncatedFunc = Logic.CreateTruncateFunc( + context, &F, Truncation, TruncOpFullModuleMode); + + ValueToValueMapTy Mapping; + for (auto &&[Arg, TArg] : llvm::zip(F.args(), TruncatedFunc->args())) + Mapping[&TArg] = &Arg; + + // Move the truncated body into the original function + F.deleteBody(); +#if LLVM_VERSION_MAJOR >= 16 + F.splice(F.begin(), TruncatedFunc); +#else + F.getBasicBlockList().splice(F.begin(), + TruncatedFunc->getBasicBlockList()); +#endif + RemapFunction(F, Mapping, + RF_NoModuleLevelChanges | RF_IgnoreMissingLocals); + TruncatedFunc->deleteBody(); + } + return true; + } + bool lowerEnzymeCalls(Function &F, std::set &done) { if (done.count(&F)) return false; @@ -2050,6 +2150,9 @@ class EnzymeBase { if (F.empty()) return false; + if (handleFullModuleTrunc(F)) + return true; + bool Changed = false; for (BasicBlock &BB : F) @@ -2110,7 +2213,8 @@ class EnzymeBase { MapVector toVirtual; MapVector toSize; SmallVector toBatch; - SmallVector toTruncateFunc; + SmallVector toTruncateFuncMem; + SmallVector toTruncateFuncOp; SmallVector toTruncateValue; SmallVector toExpandValue; MapVector toProbProg; @@ -2422,7 +2526,8 @@ class EnzymeBase { bool virtualCall = false; bool sizeOnly = false; bool batch = false; - bool truncateFunc = false; + bool truncateFuncOp = false; + bool truncateFuncMem = false; bool truncateValue = false; bool expandValue = false; bool probProg = false; @@ -2454,13 +2559,16 @@ class EnzymeBase { } else if (Fn->getName().contains("__enzyme_batch")) { enableEnzyme = true; batch = true; - } else if (Fn->getName().contains("__enzyme_truncate_func")) { + } else if (Fn->getName().contains("__enzyme_truncate_mem_func")) { + enableEnzyme = true; + truncateFuncMem = true; + } else if (Fn->getName().contains("__enzyme_truncate_op_func")) { enableEnzyme = true; - truncateFunc = true; - } else if (Fn->getName().contains("__enzyme_truncate_value")) { + truncateFuncOp = true; + } else if (Fn->getName().contains("__enzyme_truncate_mem_value")) { enableEnzyme = true; truncateValue = true; - } else if (Fn->getName().contains("__enzyme_expand_value")) { + } else if (Fn->getName().contains("__enzyme_expand_mem_value")) { enableEnzyme = true; expandValue = true; } else if (Fn->getName().contains("__enzyme_likelihood")) { @@ -2520,8 +2628,10 @@ class EnzymeBase { toSize[CI] = derivativeMode; else if (batch) toBatch.push_back(CI); - else if (truncateFunc) - toTruncateFunc.push_back(CI); + else if (truncateFuncOp) + toTruncateFuncOp.push_back(CI); + else if (truncateFuncMem) + toTruncateFuncMem.push_back(CI); else if (truncateValue) toTruncateValue.push_back(CI); else if (expandValue) @@ -2619,8 +2729,11 @@ class EnzymeBase { for (auto call : toBatch) { HandleBatch(call); } - for (auto call : toTruncateFunc) { - HandleTruncateFunc(call); + for (auto call : toTruncateFuncMem) { + HandleTruncateFunc(call, TruncMemMode); + } + for (auto call : toTruncateFuncOp) { + HandleTruncateFunc(call, TruncOpMode); } for (auto call : toTruncateValue) { HandleTruncateValue(call, true); @@ -3103,6 +3216,7 @@ AnalysisKey EnzymeNewPM::Key; #include "PreserveNVVM.h" #include "TypeAnalysis/TypeAnalysisPrinter.h" #include "llvm/Passes/PassBuilder.h" +#include "llvm/Transforms/IPO/AlwaysInliner.h" #if LLVM_VERSION_MAJOR >= 15 #include "llvm/Transforms/AggressiveInstCombine/AggressiveInstCombine.h" #include "llvm/Transforms/IPO/CalledValuePropagation.h" @@ -3333,6 +3447,7 @@ void augmentPassBuilder(llvm::PassBuilder &PB) { #else prePass(MPM); #endif + MPM.addPass(llvm::AlwaysInlinerPass()); FunctionPassManager OptimizerPM; FunctionPassManager OptimizerPM2; #if LLVM_VERSION_MAJOR >= 16 diff --git a/enzyme/Enzyme/EnzymeLogic.cpp b/enzyme/Enzyme/EnzymeLogic.cpp index 9ad444e4bd3a..f8fdf3b3124a 100644 --- a/enzyme/Enzyme/EnzymeLogic.cpp +++ b/enzyme/Enzyme/EnzymeLogic.cpp @@ -27,9 +27,18 @@ // primal pass. // //===----------------------------------------------------------------------===// +#include "EnzymeLogic.h" #include "ActivityAnalysis.h" #include "AdjointGenerator.h" +#include "EnzymeLogic.h" +#include "llvm/IR/GlobalValue.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Intrinsics.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/Support/ErrorHandling.h" +#include #if LLVM_VERSION_MAJOR >= 16 #define private public @@ -57,6 +66,8 @@ #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/ValueTracking.h" +#include "llvm/Demangle/Demangle.h" + #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Cloning.h" @@ -77,6 +88,7 @@ #if LLVM_VERSION_MAJOR >= 14 #define addAttribute addAttributeAtIndex +#define getAttribute getAttributeAtIndex #define removeAttribute removeAttributeAtIndex #endif @@ -976,7 +988,7 @@ void calculateUnusedValuesInFunction( if (newMemory) { bool foundStore = false; allInstructionsBetween( - gutils->OrigLI, cast(at), + *gutils->OrigLI, cast(at), const_cast(mti), [&](Instruction *I) -> bool { if (!I->mayWriteToMemory()) @@ -989,7 +1001,7 @@ void calculateUnusedValuesInFunction( } if (writesToMemoryReadBy( - gutils->OrigAA, TLI, + *gutils->OrigAA, TLI, /*maybeReader*/ const_cast(mti), /*maybeWriter*/ I)) { foundStore = true; @@ -1138,7 +1150,7 @@ void calculateUnusedStoresInFunction( if (newMemory) { bool foundStore = false; allInstructionsBetween( - gutils->OrigLI, cast(at), + *gutils->OrigLI, cast(at), const_cast(mti), [&](Instruction *I) -> bool { if (!I->mayWriteToMemory()) return /*earlyBreak*/ false; @@ -1147,7 +1159,7 @@ void calculateUnusedStoresInFunction( // if (I == &MTI) return; if (writesToMemoryReadBy( - gutils->OrigAA, TLI, + *gutils->OrigAA, TLI, /*maybeReader*/ const_cast(mti), /*maybeWriter*/ I)) { foundStore = true; @@ -1277,7 +1289,7 @@ bool shouldAugmentCall(CallInst *op, const GradientUtils *gutils) { } if (!op->getType()->isFPOrFPVectorTy() && !gutils->isConstantValue(op) && - gutils->TR.query(op).Inner0().isPossiblePointer()) { + gutils->TR.anyPointer(op)) { modifyPrimal = true; #ifdef PRINT_AUGCALL @@ -1310,7 +1322,7 @@ bool shouldAugmentCall(CallInst *op, const GradientUtils *gutils) { if (!argType->isFPOrFPVectorTy() && !gutils->isConstantValue(op->getArgOperand(i)) && - gutils->TR.query(op->getArgOperand(i)).Inner0().isPossiblePointer()) { + gutils->TR.anyPointer(op->getArgOperand(i))) { if (!isReadOnly(op, i)) { modifyPrimal = true; #ifdef PRINT_AUGCALL @@ -1547,7 +1559,7 @@ bool legalCombinedForwardReverse( auto consider = [&](Instruction *user) { if (!user->mayReadFromMemory()) return false; - if (writesToMemoryReadBy(gutils->OrigAA, gutils->TLI, + if (writesToMemoryReadBy(*gutils->OrigAA, gutils->TLI, /*maybeReader*/ user, /*maybeWriter*/ inst)) { @@ -1580,7 +1592,7 @@ bool legalCombinedForwardReverse( if (!post->mayWriteToMemory()) return false; - if (writesToMemoryReadBy(gutils->OrigAA, gutils->TLI, + if (writesToMemoryReadBy(*gutils->OrigAA, gutils->TLI, /*maybeReader*/ inst, /*maybeWriter*/ post)) { if (EnzymePrintPerf) { @@ -1956,11 +1968,7 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal( assert(!todiff->getReturnType()->isEmptyTy() && !todiff->getReturnType()->isVoidTy()); - assert(_overwritten_args.size() == todiff->arg_size()); - FnTypeInfo oldTypeInfo = preventTypeAnalysisLoops(oldTypeInfo_, todiff); - - assert(constant_args.size() == todiff->getFunctionType()->getNumParams()); AugmentedCacheKey tup = {todiff, retType, constant_args, _overwritten_args, returnUsed, shadowReturnUsed, @@ -1968,6 +1976,51 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal( AtomicAdd, omp, width}; + if (_overwritten_args.size() != todiff->arg_size()) { + std::string s; + llvm::raw_string_ostream ss(s); + ss << " overwritten_args.size() [" << _overwritten_args.size() + << "] != todiff->arg_size()\n"; + ss << "todiff: " << *todiff << "\n"; + llvm::Value *toshow = todiff; + if (context.req) { + toshow = context.req; + ss << " at context: " << *context.req; + } else { + ss << *todiff << "\n"; + } + if (CustomErrorHandler) { + CustomErrorHandler(ss.str().c_str(), wrap(toshow), + ErrorType::NoDerivative, nullptr, wrap(todiff), + wrap(context.ip)); + auto newFunc = todiff; + std::map returnMapping; + return insert_or_assign( + AugmentedCachedFunctions, tup, + AugmentedReturn(newFunc, nullptr, {}, returnMapping, {}, {}, + constant_args)) + ->second; + } + if (context.req) { + EmitFailure("NoDerivative", context.req->getDebugLoc(), context.req, + ss.str()); + auto newFunc = todiff; + std::map returnMapping; + return insert_or_assign( + AugmentedCachedFunctions, tup, + AugmentedReturn(newFunc, nullptr, {}, returnMapping, {}, {}, + constant_args)) + ->second; + } + llvm::errs() << "mod: " << *todiff->getParent() << "\n"; + llvm::errs() << *todiff << "\n"; + llvm_unreachable( + "attempting to differentiate function with wrong overwritten count"); + } + + assert(_overwritten_args.size() == todiff->arg_size()); + assert(constant_args.size() == todiff->getFunctionType()->getNumParams()); + auto found = AugmentedCachedFunctions.find(tup); if (found != AugmentedCachedFunctions.end()) { return found->second; @@ -2290,7 +2343,18 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal( if (todiff->empty()) { std::string s; llvm::raw_string_ostream ss(s); - ss << "No augmented forward pass found for " + todiff->getName() << "\n"; + ss << "No augmented forward pass found for " + todiff->getName(); + { + std::string demangledName = llvm::demangle(todiff->getName().str()); + // replace all '> >' with '>>' + size_t start = 0; + while ((start = demangledName.find("> >", start)) != std::string::npos) { + demangledName.replace(start, 3, ">>"); + } + if (demangledName != todiff->getName()) + ss << "(" << demangledName << ")"; + } + ss << "\n"; llvm::Value *toshow = todiff; if (context.req) { toshow = context.req; @@ -2341,9 +2405,9 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal( CacheAnalysis CA(gutils->allocationsWithGuaranteedFree, gutils->rematerializableAllocations, gutils->TR, - gutils->OrigAA, gutils->oldFunc, + *gutils->OrigAA, gutils->oldFunc, PPC.FAM.getResult(*gutils->oldFunc), - gutils->OrigLI, gutils->OrigDT, TLI, guaranteedUnreachable, + *gutils->OrigLI, *gutils->OrigDT, TLI, guaranteedUnreachable, _overwritten_argsPP, DerivativeMode::ReverseModePrimal, omp); const std::map> overwritten_args_map = CA.compute_overwritten_args_for_callsites(); @@ -2413,12 +2477,12 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal( } } - AdjointGenerator maker( - DerivativeMode::ReverseModePrimal, gutils, constant_args, retType, - getIndex, overwritten_args_map, &returnuses, - &AugmentedCachedFunctions.find(tup)->second, nullptr, unnecessaryValues, - unnecessaryInstructions, unnecessaryStores, guaranteedUnreachable, - nullptr); + AdjointGenerator maker(DerivativeMode::ReverseModePrimal, gutils, + constant_args, retType, getIndex, overwritten_args_map, + &returnuses, + &AugmentedCachedFunctions.find(tup)->second, nullptr, + unnecessaryValues, unnecessaryInstructions, + unnecessaryStores, guaranteedUnreachable, nullptr); for (BasicBlock &oBB : *gutils->oldFunc) { auto term = oBB.getTerminator(); @@ -2502,7 +2566,7 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal( IRBuilder<> BuilderZ(newri); Value *invertri = nullptr; if (gutils->isConstantValue(orig_oldval)) { - if (!EnzymeRuntimeActivityCheck && CustomErrorHandler && + if (!EnzymeRuntimeActivityCheck && gutils->TR.query(orig_oldval)[{-1}].isPossiblePointer()) { if (!isa(orig_oldval) && !isa(orig_oldval)) { @@ -2510,9 +2574,12 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal( raw_string_ostream ss(str); ss << "Mismatched activity for: " << *ri << " const val: " << *orig_oldval; - invertri = unwrap(CustomErrorHandler( - str.c_str(), wrap(ri), ErrorType::MixedActivityError, - gutils, wrap(orig_oldval), wrap(&BuilderZ))); + if (CustomErrorHandler) + invertri = unwrap(CustomErrorHandler( + str.c_str(), wrap(ri), ErrorType::MixedActivityError, + gutils, wrap(orig_oldval), wrap(&BuilderZ))); + else + EmitWarning("MixedActivityError", *ri, ss.str()); } } } @@ -2724,7 +2791,8 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal( NewF->addParamAttr(attrIndex, Attribute::NoAlias); } for (auto name : {"enzyme_sret", "enzyme_sret_v", "enzymejl_returnRoots", - "enzymejl_returnRoots_v"}) + "enzymejl_returnRoots_v", "enzymejl_parmtype", + "enzymejl_parmtype_ref", "enzyme_type"}) if (nf->getAttributes().hasParamAttr(attrIndex, name)) { NewF->addParamAttr(attrIndex, nf->getAttributes().getParamAttr(attrIndex, name)); @@ -2736,6 +2804,38 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal( ++attrIndex; } +#if LLVM_VERSION_MAJOR >= 14 + for (auto attr : {"enzyme_ta_norecur"}) + if (nf->getAttributes().hasAttributeAtIndex(AttributeList::FunctionIndex, + attr)) { + NewF->addFnAttr( + nf->getAttributes().getAttribute(AttributeList::FunctionIndex, attr)); + } + + for (auto attr : + {"enzyme_type", "enzymejl_parmtype", "enzymejl_parmtype_ref"}) + if (nf->getAttributes().hasAttributeAtIndex(AttributeList::ReturnIndex, + attr)) { + NewF->addAttribute( + AttributeList::ReturnIndex, + nf->getAttributes().getAttribute(AttributeList::ReturnIndex, attr)); + } +#else + for (auto attr : {"enzyme_ta_norecur"}) + if (nf->getAttributes().hasAttribute(AttributeList::FunctionIndex, attr)) { + NewF->addFnAttr( + nf->getAttributes().getAttribute(AttributeList::FunctionIndex, attr)); + } + + for (auto attr : + {"enzyme_type", "enzymejl_parmtype", "enzymejl_parmtype_ref"}) + if (nf->getAttributes().hasAttribute(AttributeList::ReturnIndex, attr)) { + NewF->addAttribute( + AttributeList::ReturnIndex, + nf->getAttributes().getAttribute(AttributeList::ReturnIndex, attr)); + } +#endif + SmallVector Returns; #if LLVM_VERSION_MAJOR >= 13 CloneFunctionInto(NewF, nf, VMap, CloneFunctionChangeType::LocalChangesOnly, @@ -3037,16 +3137,19 @@ void createTerminator(DiffeGradientUtils *gutils, BasicBlock *oBB, if (!ret->getType()->isFPOrFPVectorTy() && TR.getReturnAnalysis().Inner0().isPossiblePointer()) { if (gutils->isConstantValue(ret)) { - if (!EnzymeRuntimeActivityCheck && CustomErrorHandler && + if (!EnzymeRuntimeActivityCheck && TR.query(ret)[{-1}].isPossiblePointer()) { if (!isa(ret) && !isa(ret)) { std::string str; raw_string_ostream ss(str); ss << "Mismatched activity for: " << *inst << " const val: " << *ret; - invertedPtr = unwrap(CustomErrorHandler( - str.c_str(), wrap(inst), ErrorType::MixedActivityError, gutils, - wrap(ret), wrap(&nBuilder))); + if (CustomErrorHandler) + invertedPtr = unwrap(CustomErrorHandler( + str.c_str(), wrap(inst), ErrorType::MixedActivityError, + gutils, wrap(ret), wrap(&nBuilder))); + else + EmitWarning("MixedActivityError", *inst, ss.str()); } } } @@ -3283,7 +3386,7 @@ void createInvertedTerminator(DiffeGradientUtils *gutils, gutils->getNewFromOriginal(orig->getParent()) == loopContext.header && loopContext.exitBlocks.size() == 1) { SmallVector Latches; - gutils->OrigLI.getLoopFor(orig->getParent())->getLoopLatches(Latches); + gutils->OrigLI->getLoopFor(orig->getParent())->getLoopLatches(Latches); bool allIncoming = true; for (auto Latch : Latches) { if (activeUses[0] != orig->getIncomingValueForBlock(Latch)) { @@ -3569,6 +3672,7 @@ Function *EnzymeLogic::CreatePrimalAndGradient( if (hasMetadata(key.todiff, "enzyme_gradient")) { std::set seen; +#ifndef NDEBUG DIFFE_TYPE subretType = whatType(key.todiff->getReturnType(), DerivativeMode::ReverseModeGradient, /*intAreConstant*/ false, seen); @@ -3576,6 +3680,7 @@ Function *EnzymeLogic::CreatePrimalAndGradient( key.todiff->getReturnType()->isEmptyTy()) subretType = DIFFE_TYPE::CONSTANT; assert(subretType == key.retType); +#endif if (key.mode == DerivativeMode::ReverseModeCombined) { auto res = getDefaultFunctionTypeForGradient( @@ -3838,14 +3943,40 @@ Function *EnzymeLogic::CreatePrimalAndGradient( hasTape = false; // res.first.push_back(StructType::get(todiff->getContext(), {})); } else { - llvm::errs() << "expected args: ["; + std::string s; + llvm::raw_string_ostream ss(s); + ss << "Bad function type of custom reverse pass for function " + << key.todiff->getName() << " of type " + << *key.todiff->getFunctionType() << "\n"; + ss << " expected gradient function to have argument types ["; + bool seen = false; for (auto a : res.first) { - llvm::errs() << *a << " "; + if (seen) + ss << ", "; + seen = true; + ss << *a; + } + ss << "]\n"; + ss << " Instead found " << foundcalled->getName() << " of type " + << *foundcalled->getFunctionType() << "\n"; + Value *toshow = key.todiff; + if (context.req) { + toshow = context.req; + ss << " at context: " << *context.req; + } else { + ss << *key.todiff << "\n"; + } + if (CustomErrorHandler) { + CustomErrorHandler(ss.str().c_str(), wrap(toshow), + ErrorType::NoDerivative, nullptr, wrap(key.todiff), + wrap(context.ip)); + } else if (context.req) { + EmitFailure("NoDerivative", context.req->getDebugLoc(), context.req, + ss.str()); + } else { + assert(0 && "bad type for custom gradient"); + llvm_unreachable("bad type for custom gradient"); } - llvm::errs() << "]\n"; - llvm::errs() << *foundcalled << "\n"; - assert(0 && "bad type for custom gradient"); - llvm_unreachable("bad type for custom gradient"); } auto st = dyn_cast(foundcalled->getReturnType()); @@ -4017,9 +4148,9 @@ Function *EnzymeLogic::CreatePrimalAndGradient( gutils->computeGuaranteedFrees(); CacheAnalysis CA(gutils->allocationsWithGuaranteedFree, gutils->rematerializableAllocations, gutils->TR, - gutils->OrigAA, gutils->oldFunc, + *gutils->OrigAA, gutils->oldFunc, PPC.FAM.getResult(*gutils->oldFunc), - gutils->OrigLI, gutils->OrigDT, TLI, guaranteedUnreachable, + *gutils->OrigLI, *gutils->OrigDT, TLI, guaranteedUnreachable, _overwritten_argsPP, key.mode, omp); const std::map> overwritten_args_map = (augmenteddata) ? augmenteddata->overwritten_args_map @@ -4171,12 +4302,12 @@ Function *EnzymeLogic::CreatePrimalAndGradient( } } - AdjointGenerator maker( - key.mode, gutils, key.constant_args, key.retType, getIndex, - overwritten_args_map, - /*returnuses*/ nullptr, augmenteddata, &replacedReturns, - unnecessaryValues, unnecessaryInstructions, unnecessaryStores, - guaranteedUnreachable, dretAlloca); + AdjointGenerator maker(key.mode, gutils, key.constant_args, key.retType, + getIndex, overwritten_args_map, + /*returnuses*/ nullptr, augmenteddata, + &replacedReturns, unnecessaryValues, + unnecessaryInstructions, unnecessaryStores, + guaranteedUnreachable, dretAlloca); for (BasicBlock &oBB : *gutils->oldFunc) { // Don't create derivatives for code that results in termination @@ -4282,13 +4413,8 @@ Function *EnzymeLogic::CreatePrimalAndGradient( } auto store = entryBuilder.CreateStore( Constant::getNullValue(g.getValueType()), &g); -#if LLVM_VERSION_MAJOR >= 16 - if (g.getAlign()) - store->setAlignment(g.getAlign().value()); -#else if (g.getAlign()) - store->setAlignment(g.getAlign().getValue()); -#endif + store->setAlignment(*g.getAlign()); } } if (sharedBlock) { @@ -4432,7 +4558,29 @@ Function *EnzymeLogic::CreateForwardDiff( "unknown derivative for function -- metadata incorrect"); } auto md2 = cast(md); + assert(md2); assert(md2->getNumOperands() == 1); + if (!md2->getOperand(0)) { + std::string s; + llvm::raw_string_ostream ss(s); + ss << "Failed to use custom forward mode derivative for " + << todiff->getName() << "\n"; + ss << " found metadata (but null op0) " << *md2 << "\n"; + EmitFailure("NoDerivative", context.req->getDebugLoc(), context.req, + ss.str()); + return ForwardCachedFunctions[tup] = nullptr; + } + if (!isa(md2->getOperand(0))) { + std::string s; + llvm::raw_string_ostream ss(s); + ss << "Failed to use custom forward mode derivative for " + << todiff->getName() << "\n"; + ss << " found metadata (but not constantasmetadata) " + << *md2->getOperand(0) << "\n"; + EmitFailure("NoDerivative", context.req->getDebugLoc(), context.req, + ss.str()); + return ForwardCachedFunctions[tup] = nullptr; + } auto gvemd = cast(md2->getOperand(0)); auto foundcalled = cast(gvemd->getValue()); @@ -4645,7 +4793,7 @@ Function *EnzymeLogic::CreateForwardDiff( SmallPtrSet unnecessaryInstructions; SmallPtrSet unnecessaryStores; - AdjointGenerator *maker; + AdjointGenerator *maker; std::unique_ptr> can_modref_map; if (mode == DerivativeMode::ForwardModeSplit) { @@ -4654,10 +4802,10 @@ Function *EnzymeLogic::CreateForwardDiff( gutils->computeGuaranteedFrees(); CacheAnalysis CA( gutils->allocationsWithGuaranteedFree, - gutils->rematerializableAllocations, gutils->TR, gutils->OrigAA, + gutils->rematerializableAllocations, gutils->TR, *gutils->OrigAA, gutils->oldFunc, PPC.FAM.getResult(*gutils->oldFunc), - gutils->OrigLI, gutils->OrigDT, TLI, guaranteedUnreachable, + *gutils->OrigLI, *gutils->OrigDT, TLI, guaranteedUnreachable, _overwritten_argsPP, mode, omp); const std::map> overwritten_args_map = CA.compute_overwritten_args_for_callsites(); @@ -4685,7 +4833,7 @@ Function *EnzymeLogic::CreateForwardDiff( calculateUnusedStoresInFunction(*gutils->oldFunc, unnecessaryStores, unnecessaryInstructions, gutils, TLI); - maker = new AdjointGenerator( + maker = new AdjointGenerator( mode, gutils, constant_args, retType, getIndex, overwritten_args_map, /*returnuses*/ nullptr, augmenteddata, nullptr, unnecessaryValues, unnecessaryInstructions, unnecessaryStores, guaranteedUnreachable, @@ -4738,11 +4886,11 @@ Function *EnzymeLogic::CreateForwardDiff( calculateUnusedStoresInFunction(*gutils->oldFunc, unnecessaryStores, unnecessaryInstructions, gutils, TLI); - maker = new AdjointGenerator( - mode, gutils, constant_args, retType, nullptr, {}, - /*returnuses*/ nullptr, nullptr, nullptr, unnecessaryValues, - unnecessaryInstructions, unnecessaryStores, guaranteedUnreachable, - nullptr); + maker = + new AdjointGenerator(mode, gutils, constant_args, retType, nullptr, {}, + /*returnuses*/ nullptr, nullptr, nullptr, + unnecessaryValues, unnecessaryInstructions, + unnecessaryStores, guaranteedUnreachable, nullptr); } for (BasicBlock &oBB : *gutils->oldFunc) { @@ -4813,23 +4961,29 @@ Function *EnzymeLogic::CreateForwardDiff( return nf; } -static Type *getTypeForWidth(LLVMContext &ctx, unsigned width) { - switch (width) { - default: - return llvm::Type::getIntNTy(ctx, width); - case 64: - return llvm::Type::getDoubleTy(ctx); - case 32: - return llvm::Type::getFloatTy(ctx); - case 16: - return llvm::Type::getHalfTy(ctx); - } +static Value *floatValTruncate(IRBuilderBase &B, Value *v, Value *tmpBlock, + FloatTruncation truncation) { + Type *toTy = truncation.getToType(B.getContext()); + if (auto vty = dyn_cast(v->getType())) + toTy = VectorType::get(toTy, vty->getElementCount()); + return B.CreateFPTrunc(v, toTy, "enzyme_trunc"); } -static Value *floatTruncate(IRBuilderBase &B, Value *v, Value *tmpBlock, - unsigned fromwidth, unsigned towidth) { - Type *fromTy = getTypeForWidth(B.getContext(), fromwidth); - Type *toTy = getTypeForWidth(B.getContext(), towidth); +static Value *floatValExpand(IRBuilderBase &B, Value *v, Value *tmpBlock, + FloatTruncation truncation) { + Type *fromTy = truncation.getFromType(B.getContext()); + if (auto vty = dyn_cast(v->getType())) + fromTy = VectorType::get(fromTy, vty->getElementCount()); + return B.CreateFPExt(v, fromTy, "enzyme_exp"); +} + +static Value *floatMemTruncate(IRBuilderBase &B, Value *v, Value *tmpBlock, + FloatTruncation truncation) { + if (isa(v->getType())) + report_fatal_error("vector operations not allowed in mem trunc mode"); + + Type *fromTy = truncation.getFromType(B.getContext()); + Type *toTy = truncation.getToType(B.getContext()); if (!tmpBlock) tmpBlock = B.CreateAlloca(fromTy); B.CreateStore( @@ -4838,13 +4992,16 @@ static Value *floatTruncate(IRBuilderBase &B, Value *v, Value *tmpBlock, toTy, B.CreatePointerCast(tmpBlock, PointerType::getUnqual(toTy))); } -static Value *floatExpand(IRBuilderBase &B, Value *v, Value *tmpBlock, - unsigned fromwidth, unsigned towidth) { - Type *fromTy = getTypeForWidth(B.getContext(), fromwidth); +static Value *floatMemExpand(IRBuilderBase &B, Value *v, Value *tmpBlock, + FloatTruncation truncation) { + if (isa(v->getType())) + report_fatal_error("vector operations not allowed in mem trunc mode"); + + Type *fromTy = truncation.getFromType(B.getContext()); if (!tmpBlock) tmpBlock = B.CreateAlloca(fromTy); - auto c0 = - Constant::getNullValue(llvm::Type::getIntNTy(B.getContext(), fromwidth)); + auto c0 = Constant::getNullValue( + llvm::Type::getIntNTy(B.getContext(), truncation.getFromTypeWidth())); B.CreateStore( c0, B.CreatePointerCast(tmpBlock, PointerType::getUnqual(c0->getType()))); B.CreateStore( @@ -4856,27 +5013,51 @@ static Value *floatExpand(IRBuilderBase &B, Value *v, Value *tmpBlock, class TruncateGenerator : public llvm::InstVisitor { private: ValueToValueMapTy &originalToNewFn; - unsigned fromwidth; - unsigned towidth; + FloatTruncation truncation; Type *fromType; Type *toType; Function *oldFunc; Function *newFunc; AllocaInst *tmpBlock; + TruncateMode mode; EnzymeLogic &Logic; + LLVMContext &ctx; public: - TruncateGenerator(ValueToValueMapTy &originalToNewFn, unsigned fromwidth, - unsigned towidth, Function *oldFunc, Function *newFunc, - EnzymeLogic &Logic) - : originalToNewFn(originalToNewFn), fromwidth(fromwidth), - towidth(towidth), oldFunc(oldFunc), newFunc(newFunc), Logic(Logic) { + TruncateGenerator(ValueToValueMapTy &originalToNewFn, + FloatTruncation truncation, Function *oldFunc, + Function *newFunc, TruncateMode mode, EnzymeLogic &Logic) + : originalToNewFn(originalToNewFn), truncation(truncation), + oldFunc(oldFunc), newFunc(newFunc), mode(mode), Logic(Logic), + ctx(newFunc->getContext()) { IRBuilder<> B(&newFunc->getEntryBlock().front()); - fromType = getTypeForWidth(B.getContext(), fromwidth); - toType = getTypeForWidth(B.getContext(), towidth); + fromType = truncation.getFromType(ctx); + toType = truncation.getToType(ctx); + if (fromType == toType) + assert(truncation.isToMPFR()); + + if (mode == TruncMemMode) + tmpBlock = B.CreateAlloca(fromType); + else + tmpBlock = nullptr; + + if (truncation.isToMPFR()) { + switch (mode) { + case TruncMemMode: + llvm::report_fatal_error( + "truncation to MPFR not supported in memory mode."); + case TruncOpMode: + case TruncOpFullModuleMode: + break; + } + } + } - tmpBlock = B.CreateAlloca(fromType); + void checkHandled(llvm::Instruction &inst) { + // if (all_of(inst.getOperandList(), + // [&](Use *use) { return use->get()->getType() == fromType; })) + // todo(inst); } void visitInstruction(llvm::Instruction &inst) { @@ -4891,7 +5072,7 @@ class TruncateGenerator : public llvm::InstVisitor { break; } - todo(inst); + checkHandled(inst); } Type *getFromType() { return fromType; } @@ -4899,11 +5080,28 @@ class TruncateGenerator : public llvm::InstVisitor { Type *getToType() { return toType; } Value *truncate(IRBuilder<> &B, Value *v) { - return floatTruncate(B, v, tmpBlock, fromwidth, towidth); + switch (mode) { + case TruncMemMode: + assert(!truncation.isToMPFR()); + return floatMemTruncate(B, v, tmpBlock, truncation); + case TruncOpMode: + case TruncOpFullModuleMode: + if (truncation.isToMPFR()) + return v; + return floatValTruncate(B, v, tmpBlock, truncation); + } + llvm_unreachable("Unknown trunc mode"); } Value *expand(IRBuilder<> &B, Value *v) { - return floatExpand(B, v, tmpBlock, fromwidth, towidth); + switch (mode) { + case TruncMemMode: + return floatMemExpand(B, v, tmpBlock, truncation); + case TruncOpMode: + case TruncOpFullModuleMode: + return floatValExpand(B, v, tmpBlock, truncation); + } + llvm_unreachable("Unknown trunc mode"); } void todo(llvm::Instruction &I) { @@ -4949,46 +5147,109 @@ class TruncateGenerator : public llvm::InstVisitor { void visitGetElementPtrInst(llvm::GetElementPtrInst &gep) { return; } void visitPHINode(llvm::PHINode &phi) { return; } void visitCastInst(llvm::CastInst &CI) { - Value *newCI = nullptr; - auto newI = getNewFromOriginal(&CI); - std::string oldName = CI.getName().str(); - newI->setName(""); - if (CI.getSrcTy() == getFromType()) { - IRBuilder<> B(newI); - newCI = B.CreateCast(CI.getOpcode(), getNewFromOriginal(CI.getOperand(0)), - CI.getDestTy(), oldName); - } - if (CI.getDestTy() == getToType()) { + switch (mode) { + case TruncMemMode: { + Value *newCI = nullptr; auto newI = getNewFromOriginal(&CI); - IRBuilder<> B(newI); - newCI = B.CreateCast(CI.getOpcode(), getNewFromOriginal(CI.getOperand(0)), - CI.getDestTy(), oldName); + std::string oldName = CI.getName().str(); + newI->setName(""); + if (CI.getSrcTy() == getFromType()) { + IRBuilder<> B(newI); + newCI = + B.CreateCast(CI.getOpcode(), getNewFromOriginal(CI.getOperand(0)), + CI.getDestTy(), oldName); + } + if (CI.getDestTy() == getToType()) { + auto newI = getNewFromOriginal(&CI); + IRBuilder<> B(newI); + newCI = + B.CreateCast(CI.getOpcode(), getNewFromOriginal(CI.getOperand(0)), + CI.getDestTy(), oldName); + } + if (newCI) { + newI->replaceAllUsesWith(newCI); + newI->eraseFromParent(); + } + return; } - if (newCI) { - newI->replaceAllUsesWith(newCI); - newI->eraseFromParent(); + case TruncOpMode: + case TruncOpFullModuleMode: + return; } - return; } void visitSelectInst(llvm::SelectInst &SI) { - auto newI = getNewFromOriginal(&SI); - IRBuilder<> B(newI); - auto newT = truncate(B, getNewFromOriginal(SI.getTrueValue())); - auto newF = truncate(B, getNewFromOriginal(SI.getFalseValue())); - auto nres = cast( - B.CreateSelect(getNewFromOriginal(SI.getCondition()), newT, newF)); - nres->takeName(newI); - nres->copyIRFlags(newI); - newI->replaceAllUsesWith(expand(B, nres)); - newI->eraseFromParent(); - return; + switch (mode) { + case TruncMemMode: { + auto newI = getNewFromOriginal(&SI); + IRBuilder<> B(newI); + auto newT = truncate(B, getNewFromOriginal(SI.getTrueValue())); + auto newF = truncate(B, getNewFromOriginal(SI.getFalseValue())); + auto nres = cast( + B.CreateSelect(getNewFromOriginal(SI.getCondition()), newT, newF)); + nres->takeName(newI); + nres->copyIRFlags(newI); + newI->replaceAllUsesWith(expand(B, nres)); + newI->eraseFromParent(); + return; + } + case TruncOpMode: + case TruncOpFullModuleMode: + return; + } + llvm_unreachable(""); } void visitExtractElementInst(llvm::ExtractElementInst &EEI) { return; } void visitInsertElementInst(llvm::InsertElementInst &EEI) { return; } void visitShuffleVectorInst(llvm::ShuffleVectorInst &EEI) { return; } void visitExtractValueInst(llvm::ExtractValueInst &EEI) { return; } void visitInsertValueInst(llvm::InsertValueInst &EEI) { return; } + CallInst *createMPFRCall(llvm::IRBuilder<> &B, llvm::Instruction &I, + llvm::Type *RetTy, + SmallVectorImpl &ArgsIn) { + std::string Name; + if (auto BO = dyn_cast(&I)) { + Name = "binop_" + std::string(BO->getOpcodeName()); + } else if (auto II = dyn_cast(&I)) { + auto FOp = II->getCalledFunction(); + assert(FOp); + Name = "intr_" + std::string(FOp->getName()); + for (auto &C : Name) + if (C == '.') + C = '_'; + } else if (auto CI = dyn_cast(&I)) { + if (auto F = CI->getCalledFunction()) + Name = "func_" + std::string(F->getName()); + else + llvm_unreachable( + "Unexpected indirect call inst for conversion to MPFR"); + } else { + llvm_unreachable("Unexpected instruction for conversion to MPFR"); + } + + std::string MangledName = + std::string("__enzyme_mpfr_") + truncation.mangleFrom() + "_" + Name; + auto F = newFunc->getParent()->getFunction(MangledName); + SmallVector Args(ArgsIn.begin(), ArgsIn.end()); + Args.push_back(B.getInt64(truncation.getTo().exponentWidth)); + Args.push_back(B.getInt64(truncation.getTo().significandWidth)); + if (!F) { + SmallVector ArgTypes; + for (auto Arg : Args) + ArgTypes.push_back(Arg->getType()); + FunctionType *FnTy = + FunctionType::get(RetTy, ArgTypes, /*is_vararg*/ false); + F = Function::Create(FnTy, Function::ExternalLinkage, MangledName, + newFunc->getParent()); + } + return cast(B.CreateCall(F, Args)); + } void visitBinaryOperator(llvm::BinaryOperator &BO) { + auto oldLHS = BO.getOperand(0); + auto oldRHS = BO.getOperand(1); + + if (oldLHS->getType() != getFromType() && + oldRHS->getType() != getFromType()) + return; switch (BO.getOpcode()) { default: @@ -5006,60 +5267,25 @@ class TruncateGenerator : public llvm::InstVisitor { case BinaryOperator::And: case BinaryOperator::Or: case BinaryOperator::Xor: + assert(0 && "Invalid binop opcode for float arg"); return; } - if (towidth == 32 || towidth == 16 || towidth == 64) { - auto newI = getNewFromOriginal(&BO); - IRBuilder<> B(newI); - auto newLHS = truncate(B, getNewFromOriginal(BO.getOperand(0))); - auto newRHS = truncate(B, getNewFromOriginal(BO.getOperand(1))); - switch (BO.getOpcode()) { - default: - break; - case BinaryOperator::FMul: { - auto nres = cast(B.CreateFMul(newLHS, newRHS)); - nres->takeName(newI); - nres->copyIRFlags(newI); - newI->replaceAllUsesWith(expand(B, nres)); - newI->eraseFromParent(); - } - return; - case BinaryOperator::FAdd: { - auto nres = cast(B.CreateFAdd(newLHS, newRHS)); - nres->takeName(newI); - nres->copyIRFlags(newI); - newI->replaceAllUsesWith(expand(B, nres)); - newI->eraseFromParent(); - } - return; - case BinaryOperator::FSub: { - auto nres = cast(B.CreateFSub(newLHS, newRHS)); - nres->takeName(newI); - nres->copyIRFlags(newI); - newI->replaceAllUsesWith(expand(B, nres)); - newI->eraseFromParent(); - } - return; - case BinaryOperator::FDiv: { - auto nres = cast(B.CreateFDiv(newLHS, newRHS)); - nres->takeName(newI); - nres->copyIRFlags(newI); - newI->replaceAllUsesWith(expand(B, nres)); - newI->eraseFromParent(); - } - return; - case BinaryOperator::FRem: { - auto nres = cast(B.CreateFRem(newLHS, newRHS)); - nres->takeName(newI); - nres->copyIRFlags(newI); - newI->replaceAllUsesWith(expand(B, nres)); - newI->eraseFromParent(); - } - return; - } + auto newI = getNewFromOriginal(&BO); + IRBuilder<> B(newI); + auto newLHS = truncate(B, getNewFromOriginal(oldLHS)); + auto newRHS = truncate(B, getNewFromOriginal(oldRHS)); + Instruction *nres = nullptr; + if (truncation.isToMPFR()) { + SmallVector Args({newLHS, newRHS}); + nres = createMPFRCall(B, BO, truncation.getToType(ctx), Args); + } else { + nres = cast(B.CreateBinOp(BO.getOpcode(), newLHS, newRHS)); } - todo(BO); + nres->takeName(newI); + nres->copyIRFlags(newI); + newI->replaceAllUsesWith(expand(B, nres)); + newI->eraseFromParent(); return; } void visitMemSetInst(llvm::MemSetInst &MS) { visitMemSetCommon(MS); } @@ -5080,18 +5306,18 @@ class TruncateGenerator : public llvm::InstVisitor { return; } void visitFenceInst(llvm::FenceInst &FI) { return; } - void visitIntrinsicInst(llvm::IntrinsicInst &II) { - SmallVector orig_ops(II.arg_size()); - for (unsigned i = 0; i < II.arg_size(); ++i) - orig_ops[i] = II.getOperand(i); - if (handleAdjointForIntrinsic(II.getIntrinsicID(), II, orig_ops)) - return; - bool hasFromType = false; - auto newI = cast(getNewFromOriginal(&II)); + bool handleIntrinsic(llvm::CallInst &CI, Intrinsic::ID ID) { + auto newI = cast(getNewFromOriginal(&CI)); IRBuilder<> B(newI); - SmallVector new_ops(II.arg_size()); - for (unsigned i = 0; i < II.arg_size(); ++i) { + + SmallVector orig_ops(CI.arg_size()); + for (unsigned i = 0; i < CI.arg_size(); ++i) + orig_ops[i] = CI.getOperand(i); + + bool hasFromType = false; + SmallVector new_ops(CI.arg_size()); + for (unsigned i = 0; i < CI.arg_size(); ++i) { if (orig_ops[i]->getType() == getFromType()) { new_ops[i] = truncate(B, getNewFromOriginal(orig_ops[i])); hasFromType = true; @@ -5099,27 +5325,33 @@ class TruncateGenerator : public llvm::InstVisitor { new_ops[i] = getNewFromOriginal(orig_ops[i]); } } - Type *retTy = II.getType(); - if (II.getType() == getFromType()) { + Type *retTy = CI.getType(); + if (CI.getType() == getFromType()) { hasFromType = true; retTy = getToType(); } if (!hasFromType) - return; - - // TODO check that the intrinsic is overloaded + return false; - CallInst *intr; - Value *nres = intr = createIntrinsicCall(B, II.getIntrinsicID(), retTy, - new_ops, &II, II.getName()); - if (II.getType() == getFromType()) + Instruction *intr = nullptr; + Value *nres = nullptr; + if (truncation.isToMPFR()) { + nres = intr = createMPFRCall(B, CI, retTy, new_ops); + } else { + // TODO check that the intrinsic is overloaded + nres = intr = + createIntrinsicCall(B, ID, retTy, new_ops, &CI, CI.getName()); + } + if (newI->getType() == getFromType()) nres = expand(B, nres); intr->copyIRFlags(newI); newI->replaceAllUsesWith(nres); newI->eraseFromParent(); - - return; + return true; + } + void visitIntrinsicInst(llvm::IntrinsicInst &II) { + handleIntrinsic(II, II.getIntrinsicID()); } void visitReturnInst(llvm::ReturnInst &I) { return; } @@ -5200,40 +5432,55 @@ class TruncateGenerator : public llvm::InstVisitor { Value *GetShadow(RequestContext &ctx, Value *v) { if (auto F = dyn_cast(v)) - return Logic.CreateTruncateFunc(ctx, F, fromwidth, towidth); + return Logic.CreateTruncateFunc(ctx, F, truncation, mode); llvm::errs() << " unknown get truncated func: " << *v << "\n"; llvm_unreachable("unknown get truncated func"); return v; } // Return - void visitCallInst(llvm::CallInst &call) { + void visitCallInst(llvm::CallInst &CI) { + Intrinsic::ID ID; + StringRef funcName = getFuncNameFromCall(const_cast(&CI)); + if (isMemFreeLibMFunction(funcName, &ID)) + if (handleIntrinsic(CI, ID)) + return; + using namespace llvm; - CallInst *const newCall = cast(getNewFromOriginal(&call)); + CallInst *const newCall = cast(getNewFromOriginal(&CI)); IRBuilder<> BuilderZ(newCall); - if (auto called = call.getCalledFunction()) - if (handleKnownCalls(call, called, getFuncNameFromCall(&call), newCall)) + if (auto called = CI.getCalledFunction()) + if (handleKnownCalls(CI, called, getFuncNameFromCall(&CI), newCall)) return; - RequestContext ctx(&call, &BuilderZ); - auto val = GetShadow(ctx, getNewFromOriginal(call.getCalledOperand())); - newCall->setCalledOperand(val); + if (mode != TruncOpFullModuleMode) { + RequestContext ctx(&CI, &BuilderZ); + auto val = GetShadow(ctx, getNewFromOriginal(CI.getCalledOperand())); + newCall->setCalledOperand(val); + } return; } + void visitFPTruncInst(FPTruncInst &I) { return; } + void visitFPExtInst(FPExtInst &I) { return; } + void visitFPToUIInst(FPToUIInst &I) { return; } + void visitFPToSIInst(FPToSIInst &I) { return; } + void visitUIToFPInst(UIToFPInst &I) { return; } + void visitSIToFPInst(SIToFPInst &I) { return; } }; bool EnzymeLogic::CreateTruncateValue(RequestContext context, Value *v, - unsigned fromwidth, unsigned towidth, - bool isTruncate) { + FloatRepresentation from, + FloatRepresentation to, bool isTruncate) { assert(context.req && context.ip); - if (fromwidth == towidth) { + if (from == to) { + context.req->replaceAllUsesWith(context.req->getOperand(0)); context.req->eraseFromParent(); return true; } - if (fromwidth < towidth) { + if (from < to) { std::string s; llvm::raw_string_ostream ss(s); ss << "Cannot truncate into a large width\n"; @@ -5247,16 +5494,16 @@ bool EnzymeLogic::CreateTruncateValue(RequestContext context, Value *v, } IRBuilderBase &B = *context.ip; - Type *fromTy = getTypeForWidth(B.getContext(), fromwidth); - Type *toTy = getTypeForWidth(B.getContext(), towidth); + Type *fromTy = from.getBuiltinType(B.getContext()); + Type *toTy = to.getType(B.getContext()); Value *converted = nullptr; if (isTruncate) - converted = - floatExpand(B, B.CreateFPTrunc(v, toTy), nullptr, fromwidth, towidth); + converted = floatMemExpand(B, B.CreateFPTrunc(v, toTy), nullptr, + FloatTruncation(from, to)); else - converted = - B.CreateFPExt(floatTruncate(B, v, nullptr, fromwidth, towidth), fromTy); + converted = B.CreateFPExt( + floatMemTruncate(B, v, nullptr, FloatTruncation(from, to)), fromTy); assert(converted); context.req->replaceAllUsesWith(converted); @@ -5267,12 +5514,9 @@ bool EnzymeLogic::CreateTruncateValue(RequestContext context, Value *v, llvm::Function *EnzymeLogic::CreateTruncateFunc(RequestContext context, llvm::Function *totrunc, - unsigned fromwidth, - unsigned towidth) { - if (fromwidth == towidth) - return totrunc; - - TruncateCacheKey tup(totrunc, fromwidth, towidth); + FloatTruncation truncation, + TruncateMode mode) { + TruncateCacheKey tup(totrunc, truncation, mode); if (TruncateCachedFunctions.find(tup) != TruncateCachedFunctions.end()) { return TruncateCachedFunctions.find(tup)->second; } @@ -5287,11 +5531,11 @@ llvm::Function *EnzymeLogic::CreateTruncateFunc(RequestContext context, Type *NewTy = totrunc->getReturnType(); FunctionType *FTy = FunctionType::get(NewTy, params, totrunc->isVarArg()); - Function *NewF = - Function::Create(FTy, totrunc->getLinkage(), - "trunc_" + std::to_string(fromwidth) + "_" + - std::to_string(towidth) + totrunc->getName(), - totrunc->getParent()); + std::string truncName = + std::string("__enzyme_done_truncate_") + truncateModeStr(mode) + + "_func_" + truncation.mangleTruncation() + "_" + totrunc->getName().str(); + Function *NewF = Function::Create(FTy, totrunc->getLinkage(), truncName, + totrunc->getParent()); NewF->setLinkage(Function::LinkageTypes::InternalLinkage); @@ -5324,33 +5568,6 @@ llvm::Function *EnzymeLogic::CreateTruncateFunc(RequestContext context, llvm_unreachable("attempting to truncate function without definition"); } - if (fromwidth < towidth) { - std::string s; - llvm::raw_string_ostream ss(s); - ss << "Cannot truncate into a large width\n"; - llvm::Value *toshow = totrunc; - if (context.req) { - toshow = context.req; - ss << " at context: " << *context.req; - } else { - ss << *totrunc << "\n"; - } - if (CustomErrorHandler) { - CustomErrorHandler(ss.str().c_str(), wrap(toshow), - ErrorType::NoDerivative, nullptr, wrap(totrunc), - wrap(context.ip)); - return NewF; - } - if (context.req) { - EmitFailure("NoTruncate", context.req->getDebugLoc(), context.req, - ss.str()); - return NewF; - } - llvm::errs() << "mod: " << *totrunc->getParent() << "\n"; - llvm::errs() << *totrunc << "\n"; - llvm_unreachable("attempting to truncate function without definition"); - } - ValueToValueMapTy originalToNewFn; for (auto i = totrunc->arg_begin(), j = NewF->arg_begin(); @@ -5372,7 +5589,7 @@ llvm::Function *EnzymeLogic::CreateTruncateFunc(RequestContext context, NewF->setLinkage(Function::LinkageTypes::InternalLinkage); - TruncateGenerator handle(originalToNewFn, fromwidth, towidth, totrunc, NewF, + TruncateGenerator handle(originalToNewFn, truncation, totrunc, NewF, mode, *this); for (auto &BB : *totrunc) for (auto &I : BB) @@ -5783,37 +6000,47 @@ llvm::Value *EnzymeLogic::CreateNoFree(RequestContext context, cast(CreateNoFree(context, castinst->getOperand(0)))}; return castinst->getWithOperands(reps); } - if (CustomErrorHandler) { - std::string s; - llvm::raw_string_ostream ss(s); - ss << "No create nofree of unknown value\n"; - ss << *todiff << "\n"; - if (context.req) { - ss << " at context: " << *context.req; + if (EnzymeAssumeUnknownNoFree) { + return todiff; + } + + std::string s; + llvm::raw_string_ostream ss(s); + ss << "No create nofree of unknown value\n"; + ss << *todiff << "\n"; + if (context.req) { + ss << " at context: " << *context.req; + } + if (auto I = dyn_cast(todiff)) { + auto fname = I->getParent()->getParent()->getName(); + if (startsWith(fname, "nofree_")) + fname = fname.substr(7); + std::string demangledName = llvm::demangle(fname.str()); + // replace all '> >' with '>>' + size_t start = 0; + while ((start = demangledName.find("> >", start)) != std::string::npos) { + demangledName.replace(start, 3, ">>"); } + ss << " within func " << fname << " (" << demangledName << ")\n"; + } + if (CustomErrorHandler) { CustomErrorHandler(ss.str().c_str(), wrap(context.req), ErrorType::NoDerivative, nullptr, wrap(todiff), wrap(context.ip)); return todiff; } - if (EnzymeAssumeUnknownNoFree) { - return todiff; - } - if (context.req) { - EmitFailure("IllegalNoFree", context.req->getDebugLoc(), context.req, - "Cannot create nofree of instruction-created value: ", *todiff); + EmitFailure("IllegalNoFree", context.req->getDebugLoc(), context.req, s); return todiff; } if (auto arg = dyn_cast(todiff)) { auto loc = arg->getDebugLoc(); - EmitFailure("IllegalNoFree", loc, arg, - "Cannot create nofree of instruction-created value: ", *todiff); + EmitFailure("IllegalNoFree", loc, arg, s); return todiff; } - llvm::errs() << " unhandled, create no free of: " << *todiff << "\n"; + llvm::errs() << s; llvm_unreachable("unhandled, create no free"); } @@ -5831,73 +6058,140 @@ llvm::Function *EnzymeLogic::CreateNoFree(RequestContext context, Function *F) { if (isAllocationFunction(F->getName(), TLI)) return F; + // clang-format off + StringSet<> NoFreeDemangles = { + "std::basic_ostream>& std::__ostream_insert >(std::basic_ostream >&)", + "std::basic_ostream>::put(char)", + + "std::basic_filebuf>::open(char const*, std::_Ios_Openmode)", + "std::basic_filebuf>::basic_filebuf()", + "std::basic_filebuf>::close()", + + "std::basic_ios>::clear(std::_Ios_Iostate)", + "std::__detail::_Prime_rehash_policy::_M_need_rehash(unsigned long, unsigned long, unsigned long) const", + + "std::basic_streambuf >::xsputn(char const*, long)", + + "std::basic_ios >::init(std::basic_streambuf >*)", + + "std::_Hash_bytes(void const*, unsigned long, unsigned long)", + "unsigned long std::__1::__do_string_hash(char const*, char const*)", + "std::__1::hash::operator()(char const*) const", + + "std::allocator::allocator()", + "std::allocator::~allocator()", + + + "std::__cxx11::basic_string, std::allocator>::basic_string(char const*, std::allocator const&)", + "std::__cxx11::basic_string, std::allocator>::basic_string(std::__cxx11::basic_string, std::allocator>&&)", + "std::__cxx11::basic_string, std::allocator>::_M_construct(unsigned long, char)", + "std::__cxx11::basic_string, std::allocator>::_M_append(char const*, unsigned long)", + "std::__cxx11::basic_string, std::allocator>::_M_assign(std::__cxx11::basic_string, std::allocator> const&)", + "std::__cxx11::basic_string, std::allocator>::_M_replace(unsigned long, unsigned long, char const*, unsigned long)", + "std::__cxx11::basic_string, std::allocator>::_M_replace_aux(unsigned long, unsigned long, unsigned long, char)", + "std::__cxx11::basic_string, std::allocator>::length() const", + "std::__cxx11::basic_string, std::allocator>::data() const", + "std::__cxx11::basic_string, std::allocator>::size() const", + "std::__cxx11::basic_string, std::allocator>::~basic_string()", + "std::__cxx11::basic_string, std::allocator>::compare(char const*) const", + "std::__cxx11::basic_string, std::allocator>::compare(std::__cxx11::basic_string, std::allocator> const&) const", + "std::__cxx11::basic_string, std::allocator>::reserve(unsigned long)", + + "std::__cxx11::basic_string, std::allocator>::~basic_string()", + "std::__cxx11::basic_stringbuf, std::allocator>::overflow(int)", + "std::__cxx11::basic_stringbuf, std::allocator>::pbackfail(int)", + "std::__cxx11::basic_stringbuf, std::allocator>::underflow()", + "std::__cxx11::basic_stringbuf, std::allocator>::_M_sync(char*, unsigned long, unsigned long)", + + "std::__basic_file::~__basic_file()", + + "std::basic_ostream>::flush()", + "std::basic_streambuf>::xsgetn(char*, long)", + + "std::locale::~locale()", + "std::ios_base::ios_base()", + "std::basic_ostream>& " + "std::basic_ostream " + ">::_M_insert(double)", + + // libc++ + "std::__1::basic_string, std::__1::allocator>::basic_string(std::__1::basic_string, std::__1::allocator> const&)", + "std::__1::basic_string, std::__1::allocator>::~basic_string()", + "std::__1::basic_string, std::__1::allocator>::__init(char const*, unsigned long)", + "std::__1::basic_string, std::__1::allocator>::append(char const*, unsigned long)", + "std::__1::basic_string, std::__1::allocator>::data() const", + "std::__1::basic_ostream>::sentry::sentry(std::__1::basic_ostream>&)", + "std::__1::basic_ostream>::sentry::~sentry()", + "std::__1::basic_ostream>::flush()", + "std::__1::ios_base::__set_badbit_and_consider_rethrow()", + "char* std::__1::addressof(char&)", + "char const* std::__1::addressof(char const&)", + "std::__1::random_device::operator()()", + + "std::__1::locale::~locale()", + "std::__1::locale::use_facet(std::__1::locale::id&) const", + "std::__1::ios_base::ios_base()", + "std::__1::ios_base::getloc() const", + "std::__1::ios_base::clear(unsigned int)", + "std::__1::basic_iostream>::~basic_iostream()", + "std::__1::basic_ios>::~basic_ios()", + "std::__1::basic_streambuf>::basic_streambuf()", + "std::__1::basic_streambuf>::~basic_streambuf()", + "std::__1::basic_streambuf>::imbue(std::__1::locale const&)", + "std::__1::basic_streambuf>::setbuf(char*, long)", + "std::__1::basic_streambuf>::sync()", + "std::__1::basic_streambuf>::showmanyc()", + "std::__1::basic_streambuf>::xsgetn(char*, long)", + "std::__1::basic_streambuf>::uflow()", + "std::__1::basic_filebuf>::basic_filebuf()", + "std::__1::basic_filebuf>::~basic_filebuf()", + "std::__1::basic_filebuf>::open(char const*, unsigned int)", + "std::__1::basic_filebuf>::close()", + "std::__1::basic_filebuf>::sync()", + "std::__1::basic_istream>::~basic_istream()", + "virtual thunk to std::__1::basic_istream>::~basic_istream()", + "virtual thunk to std::__1::basic_ostream>::~basic_ostream()", + "std::__1::basic_ifstream>::~basic_ifstream()", + "std::__1::ios_base::init(void*)", + "std::__1::basic_istream>::read(char*, long)", + "std::__1::basic_ostream>::~basic_ostream()", + "std::__1::basic_string, std::__1::allocator>::__init(unsigned long, char)", + "std::__1::basic_ostream>::write(char const*, long)", + }; + const char* NoFreeDemanglesStartsWith[] = { + "std::__1::basic_ostream>::operator<<", + "std::__1::ios_base::imbue", + "std::__1::basic_streambuf>::pubimbue", + "std::__1::basic_stringbuf, std::__1::allocator>::__init_buf_ptrs", + "std::__1::basic_stringbuf, std::__1::allocator>::basic_stringbuf", + "std::__1::basic_string, std::__1::allocator>::operator=", + "std::__1::ctype::widen", + "std::__1::basic_streambuf>::sputn", + }; + // clang-format on + StringSet<> NoFrees = { - "mpfr_greater_p", - "memchr", - "_ZNSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEEC1EPKcRKS3_", - "_ZSt16__ostream_insertIcSt11char_traitsIcEERSt13basic_ostreamIT_T0_ES6_" - "PKS3_l", - "_ZNSo3putEc", - "_ZNSt7__cxx1115basic_stringbufIcSt11char_traitsIcESaIcEE7_M_syncEPcmm", - "_ZNSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEE10_M_" - "replaceEmmPKcm", - "_ZNSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEE9_M_appendEPKcm", - "_ZNSt13basic_filebufIcSt11char_traitsIcEE4openEPKcSt13_Ios_Openmode", - "_ZNSt9basic_iosIcSt11char_traitsIcEE5clearESt12_Ios_Iostate", - "_ZNSt13basic_filebufIcSt11char_traitsIcEE5closeEv", - "_ZNSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEE14_M_replace_" - "auxEmmmc", - "_ZNSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEE12_M_constructEmc", - "_ZNSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEE7reserveEm", - "time", - "strlen", - "_ZNKSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEE7compareERKS4_", - "_ZNKSt8__detail20_Prime_rehash_policy14_M_need_rehashEmmm", - "_ZNSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEEC1EOS4_", - "_ZNKSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEE6lengthEv", - "_ZNKSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEE4dataEv", - "_ZNKSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEE4sizeEv", - "_ZNKSt3__112basic_stringIcNS_11char_traitsIcEENS_9allocatorIcEEE4dataEv" - "_ZNSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEED1Ev", - "_ZNSt3__112basic_stringIcNS_11char_traitsIcEENS_9allocatorIcEEED1Ev", - "_ZNSt3__112basic_stringIcNS_11char_traitsIcEENS_9allocatorIcEEE6__" - "initEPKcm", - "_ZNSt3__112basic_stringIcNS_11char_traitsIcEENS_9allocatorIcEEEC1ERKS5_", - "_ZNSt3__112basic_stringIcNS_11char_traitsIcEENS_" - "9allocatorIcEEE6appendEPKcm", - "_ZNSt12__basic_fileIcED1Ev", - "__cxa_begin_catch", - "__cxa_end_catch", - "_ZNSo5flushEv", - "compress2", - "_ZNSt6localeD1Ev", - "_ZNSt8ios_baseC2Ev", - "_ZNSo9_M_insertIdEERSoT_", - "malloc_usable_size", - "_ZNSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEED1Ev", - "_ZNKSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEE7compareEPKc", - "_ZNSt13basic_filebufIcSt11char_traitsIcEEC1Ev", - "_ZNSt15basic_streambufIcSt11char_traitsIcEE6xsputnEPKcl", - "_ZNSt9basic_iosIcSt11char_traitsIcEE4initEPSt15basic_streambufIcS1_E", - "_ZNSt7__cxx1115basic_stringbufIcSt11char_traitsIcESaIcEE8overflowEi", - "_ZNSt7__cxx1115basic_stringbufIcSt11char_traitsIcESaIcEE9pbackfailEi", - "_ZNSt15basic_streambufIcSt11char_traitsIcEE6xsgetnEPcl", - "_ZNSt7__cxx1115basic_stringbufIcSt11char_traitsIcESaIcEE9underflowEv", - "_ZNSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEE9_M_assignERKS4_", - "_ZNSaIcED1Ev", - "_ZNSaIcEC1Ev", - "_ZSt11_Hash_bytesPKvmm", - "_ZNSt3__116__do_string_hashIPKcEEmT_S3_", - "_ZNKSt3__14hashIPKcEclES2_", - "_ZNSt3__19addressofIcEEPT_RS1_", - "_ZNSt3__19addressofIKcEEPT_RS2_", - "_ZNSt3__113random_deviceclEv", + "mpfr_greater_p", "memchr", "time", "strlen", + "__cxa_begin_catch", "__cxa_end_catch", "compress2", "malloc_usable_size", "MPI_Allreduce", }; if (startsWith(F->getName(), "_ZNSolsE") || NoFrees.count(F->getName())) return F; + std::string demangledName = llvm::demangle(F->getName().str()); + // replace all '> >' with '>>' + size_t start = 0; + while ((start = demangledName.find("> >", start)) != std::string::npos) { + demangledName.replace(start, 3, ">>"); + } + if (NoFreeDemangles.count(demangledName)) + return F; + + for (auto Name : NoFreeDemanglesStartsWith) + if (startsWith(demangledName, Name)) + return F; + switch (F->getIntrinsicID()) { case Intrinsic::lifetime_start: case Intrinsic::lifetime_end: @@ -5915,23 +6209,34 @@ llvm::Function *EnzymeLogic::CreateNoFree(RequestContext context, Function *F) { if (EnzymeEmptyFnInactive) { return F; } - if (CustomErrorHandler) { - std::string s; - llvm::raw_string_ostream ss(s); - ss << "No create nofree of empty function " << F->getName() << "\n"; - if (context.req) { - ss << " at context: " << *context.req; - } else { - ss << *F << "\n"; + std::string s; + llvm::raw_string_ostream ss(s); + ss << "No create nofree of empty function (" << demangledName << ") " + << F->getName() << ")\n"; + if (context.req) { + ss << " at context: " << *context.req; + if (auto CB = dyn_cast(context.req)) { + if (auto F = CB->getCalledFunction()) { + std::string demangleF = llvm::demangle(F->getName().str()); + // replace all '> >' with '>>' + size_t start = 0; + while ((start = demangleF.find("> >", start)) != std::string::npos) { + demangleF.replace(start, 3, ">>"); + } + ss << " (" << demangleF << ")"; + } } + } else { + ss << *F << "\n"; + } + if (CustomErrorHandler) { CustomErrorHandler(ss.str().c_str(), wrap(context.req), ErrorType::NoDerivative, nullptr, wrap(F), wrap(context.ip)); return F; } if (context.req) { - EmitFailure("IllegalNoFree", context.req->getDebugLoc(), context.req, - "Cannot create nofree of empty function: ", *F); + EmitFailure("IllegalNoFree", context.req->getDebugLoc(), context.req, s); return F; } llvm::errs() << " unhandled, create no free of empty function: " << *F diff --git a/enzyme/Enzyme/EnzymeLogic.h b/enzyme/Enzyme/EnzymeLogic.h index 2f1ac9fde496..4bb61e94c8ef 100644 --- a/enzyme/Enzyme/EnzymeLogic.h +++ b/enzyme/Enzyme/EnzymeLogic.h @@ -42,6 +42,7 @@ #include "llvm/Analysis/AliasAnalysis.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/Support/ErrorHandling.h" #include "ActivityAnalysis.h" #include "FunctionUtils.h" @@ -132,6 +133,21 @@ class AugmentedReturn { isComplete(false) {} }; +/// \p todiff is the function to differentiate +/// \p retType is the activity info of the return. +/// Only allowed to be DUP_ARG or CONSTANT. DUP_NONEED is not allowed, +/// set returnValue to false instead. +/// \p constant_args is the activity info of the arguments +/// \p returnValue is whether the primal's return should also be returned. +/// \p dretUsed is whether the shadow return value should also be returned. +/// Only allowed to be true if retType is CDIFFE_TYPE::DUP_ARG. +/// \p additionalArg is the type (or null) of an additional type in the +/// signature to hold the tape. +/// \p typeInfo is the type info information about the calling context +/// \p _overwritten_args marks whether an argument may be overwritten +/// before loads in the generated function (and thus cannot be cached). +/// \p AtomicAdd is whether to perform all adjoint +/// updates to memory in an atomic way struct ReverseCacheKey { llvm::Function *todiff; DIFFE_TYPE retType; @@ -253,6 +269,133 @@ struct RequestContext { : req(req), ip(ip) {} }; +[[maybe_unused]] static llvm::Type * +getTypeForWidth(llvm::LLVMContext &ctx, unsigned width, bool builtinFloat) { + switch (width) { + default: + if (builtinFloat) + llvm::report_fatal_error("Invalid float width requested"); + else + llvm::report_fatal_error( + "Truncation to non builtin float width unsupported"); + case 64: + return llvm::Type::getDoubleTy(ctx); + case 32: + return llvm::Type::getFloatTy(ctx); + case 16: + return llvm::Type::getHalfTy(ctx); + } +} + +enum TruncateMode { TruncMemMode, TruncOpMode, TruncOpFullModuleMode }; +[[maybe_unused]] static const char *truncateModeStr(TruncateMode mode) { + switch (mode) { + case TruncMemMode: + return "mem"; + case TruncOpMode: + return "op"; + case TruncOpFullModuleMode: + return "op_full_module"; + } + llvm_unreachable("Invalid truncation mode"); +} + +struct FloatRepresentation { + // |_|__________|_________________| + // ^ ^ ^ + // sign bit exponent significand + // + // value = (sign) * significand * 2 ^ exponent + unsigned exponentWidth; + unsigned significandWidth; + + FloatRepresentation(unsigned e, unsigned s) + : exponentWidth(e), significandWidth(s) {} + + unsigned getTypeWidth() const { return 1 + exponentWidth + significandWidth; } + + bool canBeBuiltin() const { + unsigned w = getTypeWidth(); + return (w == 16 && significandWidth == 10) || + (w == 32 && significandWidth == 23) || + (w == 64 && significandWidth == 52); + } + + llvm::Type *getBuiltinType(llvm::LLVMContext &ctx) const { + if (!canBeBuiltin()) + return nullptr; + return getTypeForWidth(ctx, getTypeWidth(), /*builtinFloat=*/true); + } + + llvm::Type *getType(llvm::LLVMContext &ctx) const { + llvm::Type *builtinType = getBuiltinType(ctx); + if (builtinType) + return builtinType; + llvm_unreachable("TODO MPFR"); + } + + bool operator==(const FloatRepresentation &other) const { + return other.exponentWidth == exponentWidth && + other.significandWidth == significandWidth; + } + bool operator<(const FloatRepresentation &other) const { + return std::tuple(exponentWidth, significandWidth) < + std::tuple(other.exponentWidth, other.significandWidth); + } + std::string to_string() const { + return std::to_string(getTypeWidth()) + "_" + + std::to_string(significandWidth); + } +}; + +struct FloatTruncation { +private: + FloatRepresentation from, to; + +public: + FloatTruncation(FloatRepresentation From, FloatRepresentation To) + : from(From), to(To) { + if (!From.canBeBuiltin()) + llvm::report_fatal_error("Float truncation `from` type is not builtin."); + if (From.exponentWidth < To.exponentWidth) + llvm::report_fatal_error("Float truncation `from` type must have " + "a wider exponent than `to`."); + if (From.significandWidth < To.significandWidth) + llvm::report_fatal_error("Float truncation `from` type must have " + "a wider wsignificand than `to`."); + if (From == To) + llvm::report_fatal_error( + "Float truncation `from` and `to` type must not be the same."); + } + FloatRepresentation getTo() { return to; } + unsigned getFromTypeWidth() { return from.getTypeWidth(); } + unsigned getToTypeWidth() { return to.getTypeWidth(); } + llvm::Type *getFromType(llvm::LLVMContext &ctx) { + return from.getBuiltinType(ctx); + } + bool isToMPFR() { return !to.canBeBuiltin(); } + llvm::Type *getToType(llvm::LLVMContext &ctx) { + if (to.canBeBuiltin()) { + return to.getBuiltinType(ctx); + } else { + assert(isToMPFR()); + // Currently we do not support TruncMemMode for MPFR, and we provide + // runtime wrappers around MPFR for each builtin `from` type + return from.getBuiltinType(ctx); + } + } + bool operator==(const FloatTruncation &other) const { + return from == other.from && to == other.to; + } + bool operator<(const FloatTruncation &other) const { + return std::tuple(from, to) < std::tuple(other.from, other.to); + } + std::string mangleTruncation() const { + return from.to_string() + "to" + to.to_string(); + } + std::string mangleFrom() const { return from.to_string(); } +}; + class EnzymeLogic { public: PreProcessCache PPC; @@ -359,9 +502,10 @@ class EnzymeLogic { /// \p returnUsed is whether the primal's return should also be returned /// \p typeInfo is the type info information about the calling context /// \p _overwritten_args marks whether an argument may be rewritten before - /// loads in the generated function (and thus cannot be cached). \p - /// forceAnonymousTape forces the tape to be an i8* rather than the true tape - /// structure \p AtomicAdd is whether to perform all adjoint updates to + /// loads in the generated function (and thus cannot be cached). + /// \p forceAnonymousTape forces the tape to be an i8* rather than the true + /// tape structure + /// \p AtomicAdd is whether to perform all adjoint updates to /// memory in an atomic way const AugmentedReturn &CreateAugmentedPrimal( RequestContext context, llvm::Function *todiff, DIFFE_TYPE retType, @@ -453,20 +597,8 @@ class EnzymeLogic { /// Create the reverse pass, or combined forward+reverse derivative function. /// \p context the instruction which requested this derivative (or null). - /// \p todiff is the function to differentiate - /// \p retType is the activity info of the return - /// \p constant_args is the activity info of the arguments - /// \p returnValue is whether the primal's return should also be returned - /// \p dretUsed is whether the shadow return value should also be returned - /// \p additionalArg is the type (or null) of an additional type in the - /// signature to hold the tape. - /// \p typeInfo is the type info information about the calling context - /// \p _overwritten_args marks whether an argument may be rewritten - /// before loads in the generated function (and thus cannot be cached). /// \p augmented is the data structure created by prior call to an /// augmented forward pass - /// \p AtomicAdd is whether to perform all adjoint - /// updates to memory in an atomic way llvm::Function *CreatePrimalAndGradient(RequestContext context, const ReverseCacheKey &&key, TypeAnalysis &TA, @@ -511,13 +643,15 @@ class EnzymeLogic { llvm::ArrayRef arg_types, BATCH_TYPE ret_type); - using TruncateCacheKey = std::tuple; + using TruncateCacheKey = + std::tuple; std::map TruncateCachedFunctions; llvm::Function *CreateTruncateFunc(RequestContext context, llvm::Function *tobatch, - unsigned fromwidth, unsigned towidth); + FloatTruncation truncation, + TruncateMode mode); bool CreateTruncateValue(RequestContext context, llvm::Value *addr, - unsigned fromwidth, unsigned towidth, + FloatRepresentation from, FloatRepresentation to, bool isTruncate); /// Create a traced version of a function diff --git a/enzyme/Enzyme/FunctionUtils.cpp b/enzyme/Enzyme/FunctionUtils.cpp index 4c3370cf4cf9..3e2ade41b1c9 100644 --- a/enzyme/Enzyme/FunctionUtils.cpp +++ b/enzyme/Enzyme/FunctionUtils.cpp @@ -754,6 +754,16 @@ void PreProcessCache::AlwaysInline(Function *NewF) { for (auto CI : ToInline) { InlineFunctionInfo IFI; +#if LLVM_VERSION_MAJOR >= 18 + auto F = CI->getCalledFunction(); + if (CI->getParent()->IsNewDbgInfoFormat != F->IsNewDbgInfoFormat) { + if (CI->getParent()->IsNewDbgInfoFormat) { + F->convertToNewDbgValues(); + } else { + F->convertFromNewDbgValues(); + } + } +#endif InlineFunction(*CI, IFI); } } @@ -2170,8 +2180,24 @@ Function *PreProcessCache::CloneFunctionWithReturns( VMapO->getMDMap() = VMap.getMDMap(); } + for (auto attr : {"enzyme_ta_norecur"}) + if (F->getAttributes().hasAttribute(AttributeList::FunctionIndex, attr)) { + NewF->addAttribute( + AttributeList::FunctionIndex, + F->getAttributes().getAttribute(AttributeList::FunctionIndex, attr)); + } + + for (auto attr : + {"enzyme_type", "enzymejl_parmtype", "enzymejl_parmtype_ref"}) + if (F->getAttributes().hasAttribute(AttributeList::ReturnIndex, attr)) { + NewF->addAttribute( + AttributeList::ReturnIndex, + F->getAttributes().getAttribute(AttributeList::ReturnIndex, attr)); + } + bool hasPtrInput = false; unsigned ii = 0, jj = 0; + for (auto i = F->arg_begin(), j = NewF->arg_begin(); i != F->arg_end();) { if (F->hasParamAttribute(ii, Attribute::StructRet)) { NewF->addParamAttr(jj, Attribute::get(F->getContext(), "enzyme_sret")); @@ -2194,10 +2220,15 @@ Function *PreProcessCache::CloneFunctionWithReturns( // Attribute::ElementType)); #endif } - for (auto ty : PrimalParamAttrsToPreserve) - if (F->getAttributes().hasParamAttr(ii, ty)) { - auto attr = F->getAttributes().getParamAttr(ii, ty); - NewF->addParamAttr(jj, attr); + for (auto attr : + {"enzymejl_parmtype", "enzymejl_parmtype_ref", "enzyme_type"}) + if (F->getAttributes().hasParamAttr(ii, attr)) { + NewF->addParamAttr(jj, F->getAttributes().getParamAttr(ii, attr)); + for (auto ty : PrimalParamAttrsToPreserve) + if (F->getAttributes().hasParamAttr(ii, ty)) { + auto attr = F->getAttributes().getParamAttr(ii, ty); + NewF->addParamAttr(jj, attr); + } } if (constant_args[ii] == DIFFE_TYPE::CONSTANT) { if (!i->hasByValAttr()) @@ -2212,8 +2243,8 @@ Function *PreProcessCache::CloneFunctionWithReturns( << " nonconstant arg " << *j << "\n"; } - // Always remove nonnull/noundef since the caller may choose to pass undef - // as an arg if provably it will not be used in the reverse pass + // Always remove nonnull/noundef since the caller may choose to pass + // undef as an arg if provably it will not be used in the reverse pass if (constant_args[ii] == DIFFE_TYPE::DUP_NONEED || mode == DerivativeMode::ReverseModeGradient) { if (F->hasParamAttribute(ii, Attribute::NonNull)) { @@ -2236,6 +2267,14 @@ Function *PreProcessCache::CloneFunctionWithReturns( NewF->addParamAttr(jj + 1, attr); } + for (auto attr : + {"enzymejl_parmtype", "enzymejl_parmtype_ref", "enzyme_type"}) + if (F->getAttributes().hasParamAttr(ii, attr)) { + if (width == 1) + NewF->addParamAttr(jj + 1, + F->getAttributes().getParamAttr(ii, attr)); + } + if (F->getAttributes().hasParamAttr(ii, "enzymejl_returnRoots")) { if (width == 1) { NewF->addParamAttr(jj + 1, F->getAttributes().getParamAttr( @@ -2247,7 +2286,8 @@ Function *PreProcessCache::CloneFunctionWithReturns( #if LLVM_VERSION_MAJOR >= 13 // TODO // NewF->addParamAttr(jj + 1, - // F->getParamAttribute(ii, Attribute::ElementType)); + // F->getParamAttribute(ii, + // Attribute::ElementType)); #endif } @@ -2266,7 +2306,8 @@ Function *PreProcessCache::CloneFunctionWithReturns( // jj + 1, // Attribute::get(F->getContext(), // Attribute::AttrKind::ElementType, - // F->getParamAttribute(ii, Attribute::StructRet) + // F->getParamAttribute(ii, + // Attribute::StructRet) // .getValueAsType())); #endif } else { @@ -2283,7 +2324,8 @@ Function *PreProcessCache::CloneFunctionWithReturns( // jj + 1, // Attribute::get(F->getContext(), // Attribute::AttrKind::ElementType, - // F->getParamAttribute(ii, Attribute::StructRet) + // F->getParamAttribute(ii, + // Attribute::StructRet) // .getValueAsType())); #endif } @@ -2495,19 +2537,23 @@ void ReplaceFunctionImplementation(Module &M) { } void PreProcessCache::optimizeIntermediate(Function *F) { - PromotePass().run(*F, FAM); + PreservedAnalyses PA; + PA = PromotePass().run(*F, FAM); + FAM.invalidate(*F, PA); #if LLVM_VERSION_MAJOR >= 14 && !defined(FLANG) - GVNPass().run(*F, FAM); + PA = GVNPass().run(*F, FAM); #else - GVN().run(*F, FAM); + PA = GVN().run(*F, FAM); #endif + FAM.invalidate(*F, PA); #if LLVM_VERSION_MAJOR >= 16 && !defined(FLANG) - SROAPass(llvm::SROAOptions::PreserveCFG).run(*F, FAM); + PA = SROAPass(llvm::SROAOptions::PreserveCFG).run(*F, FAM); #elif LLVM_VERSION_MAJOR >= 14 && !defined(FLANG) - SROAPass().run(*F, FAM); + PA = SROAPass().run(*F, FAM); #else - SROA().run(*F, FAM); + PA = SROA().run(*F, FAM); #endif + FAM.invalidate(*F, PA); if (EnzymeSelectOpt) { #if LLVM_VERSION_MAJOR >= 12 @@ -2518,8 +2564,10 @@ void PreProcessCache::optimizeIntermediate(Function *F) { /*bool SwitchToLookup=*/false, /*bool CanonicalLoops=*/true, /*bool SinkCommon=*/true, /*AssumptionCache *AssumpCache=*/nullptr); #endif - SimplifyCFGPass(scfgo).run(*F, FAM); - CorrelatedValuePropagationPass().run(*F, FAM); + PA = SimplifyCFGPass(scfgo).run(*F, FAM); + FAM.invalidate(*F, PA); + PA = CorrelatedValuePropagationPass().run(*F, FAM); + FAM.invalidate(*F, PA); SelectOptimization(F); } // EarlyCSEPass(/*memoryssa*/ true).run(*F, FAM); @@ -2529,8 +2577,10 @@ void PreProcessCache::optimizeIntermediate(Function *F) { ReplaceFunctionImplementation(*F->getParent()); - PreservedAnalyses PA; - FAM.invalidate(*F, PA); + { + PreservedAnalyses PA; + FAM.invalidate(*F, PA); + } #if LLVM_VERSION_MAJOR < 14 using OptimizationLevel = llvm::PassBuilder::OptimizationLevel; @@ -3646,14 +3696,13 @@ std::optional fixSparse_inner(Instruction *cur, llvm::Function &F, } /* - // add (ext (x == expr )), ( ext (x == expr + 1)) -> -expr == c2 ) and c1 != - c2 -> false if (cur->getOpcode() == Instruction::Add) for (int j=0; j<2; j++) - if (auto c0 = dyn_cast(cur->getOperand(j))) - if (auto cmp0 = dyn_cast(c0->getOperand(0))) - if (auto c1 = dyn_cast(cur->getOperand(1-j))) - if (auto cmp1 = dyn_cast(c0->getOperand(0))) - if (cmp0->getPredicate() == ICmpInst::ICMP_EQ && - cmp1->getPredicate() == ICmpInst::ICMP_EQ) + // add (ext (x == expr )), ( ext (x == expr + 1)) -> -expr == c2 ) and c1 + != c2 -> false if (cur->getOpcode() == Instruction::Add) for (int j=0; j<2; + j++) if (auto c0 = dyn_cast(cur->getOperand(j))) if (auto cmp0 = + dyn_cast(c0->getOperand(0))) if (auto c1 = + dyn_cast(cur->getOperand(1-j))) if (auto cmp1 = + dyn_cast(c0->getOperand(0))) if (cmp0->getPredicate() == + ICmpInst::ICMP_EQ && cmp1->getPredicate() == ICmpInst::ICMP_EQ) { for (size_t i0 = 0; i0 < 2; i0++) for (size_t i1 = 0; i1 < 2; i1++) @@ -3686,7 +3735,8 @@ std::optional fixSparse_inner(Instruction *cur, llvm::Function &F, if (auto C = dyn_cast(fcmp->getOperand(i))) { if (C->isZero()) { // (a1*a2*...an) == 0 -> (a1 == 0) || (a2 == 0) || ... (a2 == 0) - // (a1*a2*...an) != 0 -> ![ (a1 == 0) || (a2 == 0) || ... (a2 == 0) + // (a1*a2*...an) != 0 -> ![ (a1 == 0) || (a2 == 0) || ... (a2 == + // 0) // ] if (auto P = isProduct(fcmp->getOperand(1 - i))) { Value *res = nullptr; @@ -3877,8 +3927,8 @@ std::optional fixSparse_inner(Instruction *cur, llvm::Function &F, cur->isExact()) if (auto C2 = dyn_cast(cur->getOperand(1))) if (auto mul = dyn_cast(cur->getOperand(0))) { - // (lshr exact (mul a, C1), C2), C -> mul a, (lhsr exact C1, C2) if C2 - // divides C1 + // (lshr exact (mul a, C1), C2), C -> mul a, (lhsr exact C1, C2) if + // C2 divides C1 if (mul->getOpcode() == Instruction::Mul) for (int i0 = 0; i0 < 2; i0++) if (auto C1 = dyn_cast(mul->getOperand(i0))) { @@ -3905,7 +3955,8 @@ std::optional fixSparse_inner(Instruction *cur, llvm::Function &F, return "IMulDivConst"; } } - // (lshr exact (add a, C1), C2), C -> add a, (lhsr exact C1, C2) if C2 + // (lshr exact (add a, C1), C2), C -> add a, (lhsr exact C1, C2) if + // C2 if (mul->getOpcode() == Instruction::Add) for (int i0 = 0; i0 < 2; i0++) if (auto C1 = dyn_cast(mul->getOperand(i0))) { @@ -4115,8 +4166,8 @@ std::optional fixSparse_inner(Instruction *cur, llvm::Function &F, // (a * b) != (c * b) -> (a != c) && b != 0 // auto S1 = SE.getSCEV(cur->getOperand(0)); // auto S2 = SE.getSCEV(cur->getOperand(1)); - // llvm::errs() <<" attempting push: " << *cur << " S1: " << *S1 << " S2: - // " << *S2 << " and " << *cur->getOperand(0) << " " << + // llvm::errs() <<" attempting push: " << *cur << " S1: " << *S1 << " + // S2: " << *S2 << " and " << *cur->getOperand(0) << " " << // *cur->getOperand(1) << "\n"; if (auto mul1 = dyn_cast(cur->getOperand(0))) if (auto mul2 = dyn_cast(cur->getOperand(1))) { @@ -4406,10 +4457,10 @@ std::optional fixSparse_inner(Instruction *cur, llvm::Function &F, for (int j=0; j<2; j++) if (auto CI = dyn_cast(SI->getOperand(1+j))) if (CI->isZero()) { - auto tval = (j == 0) ? CI : pushcse(B.CreateMul(SI->getTrueValue(), - cur->getOperand(1-i), "tval." + cur->getName(), cur->hasNoUnsignedWrap(), - cur->hasNoSignedWrap())); - auto fval = (j == 1) ? CI : pushcse(B.CreateMul(SI->getFalseValue(), + auto tval = (j == 0) ? CI : + pushcse(B.CreateMul(SI->getTrueValue(), cur->getOperand(1-i), "tval." + + cur->getName(), cur->hasNoUnsignedWrap(), cur->hasNoSignedWrap())); auto + fval = (j == 1) ? CI : pushcse(B.CreateMul(SI->getFalseValue(), cur->getOperand(1-i), "fval." + cur->getName(), cur->hasNoUnsignedWrap(), cur->hasNoSignedWrap())); @@ -4480,9 +4531,10 @@ std::optional fixSparse_inner(Instruction *cur, llvm::Function &F, (and1->getType()->isIntegerTy(1) && and2->getType()->isIntegerTy(1) && and1->getOpcode() == Instruction::And && and2->getOpcode() == Instruction::And) { bool done = false; for (int i1=0; i1<2; i1++) for (int - i2=0; i2<2; i2++) if (and1->getOperand(i1) == and2->getOperand(i2)) { auto c1 - = and1->getOperand(i1); auto x = and1->getOperand(1-i1); x = - pushcse(B.CreateZExt(x, inst1->getType())); auto y = and2->getOperand(1-i2); + i2=0; i2<2; i2++) if (and1->getOperand(i1) == and2->getOperand(i2)) { auto + c1 = and1->getOperand(i1); auto x = and1->getOperand(1-i1); x = + pushcse(B.CreateZExt(x, inst1->getType())); auto y = + and2->getOperand(1-i2); y = pushcse(B.CreateZExt(y, inst2->getType())); @@ -5173,8 +5225,8 @@ std::optional fixSparse_inner(Instruction *cur, llvm::Function &F, } } - // fmul a, (sitofp (imul c:const, b)) -> fmul (fmul (a, (sitofp c))), (sitofp - // b) + // fmul a, (sitofp (imul c:const, b)) -> fmul (fmul (a, (sitofp c))), + // (sitofp b) if (cur->getOpcode() == Instruction::FMul && cur->isFast()) { for (int i = 0; i < 2; i++) @@ -5399,8 +5451,8 @@ std::optional fixSparse_inner(Instruction *cur, llvm::Function &F, } Value *sel = pushcse( - B.CreateSelect(condition, ConstantFP::get(cur->getType(), 0.0), - fmul, "mulcsi." + cur->getName())); + B.CreateSelect(condition, ConstantFP::get(cur->getType(), + 0.0), fmul, "mulcsi." + cur->getName())); replaceAndErase(cur, sel); return "FMulSIToFPProp"; @@ -6261,6 +6313,7 @@ class Constraints : public std::enable_shared_from_this { assert(t != Type::None); assert(c.size() != 0); assert(c.size() != 1); +#ifndef NDEBUG SmallVector tmp(c.begin(), c.end()); for (unsigned i = 0; i < tmp.size(); i++) for (unsigned j = 0; j < i; j++) @@ -6283,6 +6336,7 @@ class Constraints : public std::enable_shared_from_this { if (auto s = dyn_cast(tmp[j]->node)) assert(s->getLoop() != tmp[i]->Loop); } +#endif } bool operator==(const Constraints &rhs) const { @@ -6527,8 +6581,8 @@ return true; auto div = ctx.SE.getUDivExpr(MinusX, Y); auto div_e = ctx.SE.getUDivExactExpr(MinusX, Y); - // in case of inexact division, check that these exactly equal for - // replacement + // in case of inexact division, check that these exactly equal + // for replacement if (div == div_e) { if (isEqual) { @@ -6799,8 +6853,8 @@ return true; if (rhs->ty == Type::Intersect || rhs->ty == Type::Compare) { return rhs->andB(shared_from_this(), ctx); } - // (m or a or b or d) and (m or a or c or e ...) -> m or a or ( (b or d) and - // (c or e)) + // (m or a or b or d) and (m or a or c or e ...) -> m or a or ( (b or d) + // and (c or e)) if (ty == Type::Union && rhs->ty == Type::Union) { if (*this == *rhs->notB(ctx)) { return Constraints::none(); @@ -7173,11 +7227,13 @@ getSparseConditions(bool &legal, Value *val, } } if (scope) - EmitFailure("NoSparsification", I->getDebugLoc(), I, - "F: ", *I->getParent()->getParent(), "\n", + EmitWarning("NoSparsification", *I, " No sparsification: not sparse solvable(icmp): ", *I, " via ", *sub1); - legal = false; + if (SparseDebug) { + llvm::errs() << " getSparse(icmp_dflt, " << *I + << ") = " << *defaultFloat << "\n"; + } return defaultFloat; } diff --git a/enzyme/Enzyme/GradientUtils.cpp b/enzyme/Enzyme/GradientUtils.cpp index 1d6b823d3721..a8a2cb8bbbdc 100644 --- a/enzyme/Enzyme/GradientUtils.cpp +++ b/enzyme/Enzyme/GradientUtils.cpp @@ -173,19 +173,19 @@ GradientUtils::GradientUtils( : CacheUtility(TLI_, newFunc_), Logic(Logic), mode(mode), oldFunc(oldFunc_), invertedPointers(), OrigDT(oldFunc_->empty() - ? *((DominatorTree *)nullptr) - : Logic.PPC.FAM.getResult( + ? ((DominatorTree *)nullptr) + : &Logic.PPC.FAM.getResult( *oldFunc_)), OrigPDT(oldFunc_->empty() - ? *((PostDominatorTree *)nullptr) - : Logic.PPC.FAM.getResult( + ? ((PostDominatorTree *)nullptr) + : &Logic.PPC.FAM.getResult( *oldFunc_)), OrigLI(oldFunc_->empty() - ? *((LoopInfo *)nullptr) - : Logic.PPC.FAM.getResult(*oldFunc_)), + ? ((LoopInfo *)nullptr) + : &Logic.PPC.FAM.getResult(*oldFunc_)), OrigSE(oldFunc_->empty() - ? *((ScalarEvolution *)nullptr) - : Logic.PPC.FAM.getResult( + ? ((ScalarEvolution *)nullptr) + : &Logic.PPC.FAM.getResult( *oldFunc_)), notForAnalysis(getGuaranteedUnreachable(oldFunc_)), ATA(oldFunc_->empty() @@ -195,8 +195,8 @@ GradientUtils::GradientUtils( notForAnalysis, TLI_, constantvalues_, activevals_, ReturnActivity)), overwritten_args_map_ptr(nullptr), tid(nullptr), numThreads(nullptr), - OrigAA(oldFunc_->empty() ? *((AAResults *)nullptr) - : Logic.PPC.getAAResultsFromFunction(oldFunc_)), + OrigAA(oldFunc_->empty() ? ((AAResults *)nullptr) + : &Logic.PPC.getAAResultsFromFunction(oldFunc_)), TA(TA_), TR(TR_), omp(omp), width(width), ArgDiffeTypes(ArgDiffeTypes_) { if (oldFunc_->empty()) return; @@ -246,7 +246,7 @@ GradientUtils::GradientUtils( for (BasicBlock &BB : *oldFunc) { bool legal = true; for (auto BRet : ReturningBlocks) { - if (!(BRet == &BB || OrigDT.dominates(&BB, BRet))) { + if (!(BRet == &BB || OrigDT->dominates(&BB, BRet))) { legal = false; break; } @@ -543,7 +543,7 @@ Value *GradientUtils::getOrInsertConditionalIndex(Value *val, LoopContext &lc, bool GradientUtils::assumeDynamicLoopOfSizeOne(Loop *L) const { if (!EnzymeInactiveDynamic) return false; - auto OL = OrigLI.getLoopFor(isOriginal(L->getHeader())); + auto OL = OrigLI->getLoopFor(isOriginal(L->getHeader())); assert(OL); for (auto OB : OL->getBlocks()) { for (auto &OI : *OB) { @@ -569,17 +569,10 @@ DebugLoc GradientUtils::getNewFromOriginal(const DebugLoc L) const { return L; assert(originalToNewFn.hasMD()); auto opt = originalToNewFn.getMappedMD(L.getAsMDNode()); -#if LLVM_VERSION_MAJOR >= 16 - if (!opt.has_value()) - return L; - assert(opt.has_value()); - return DebugLoc(cast(opt.value())); -#else - if (!opt.hasValue()) + if (!opt) return L; - assert(opt.hasValue()); - return DebugLoc(cast(*opt.getPointer())); -#endif + assert(opt); + return DebugLoc(cast(*opt)); } Value *GradientUtils::getNewFromOriginal(const Value *originst) const { @@ -649,12 +642,14 @@ BasicBlock *GradientUtils::getOriginalFromNew(const BasicBlock *newinst) const { Value *GradientUtils::isOriginal(const Value *newinst) const { if (isa(newinst) || isa(newinst)) return const_cast(newinst); +#ifndef NDEBUG if (auto arg = dyn_cast(newinst)) { assert(arg->getParent() == newFunc); } if (auto inst = dyn_cast(newinst)) { assert(inst->getParent()->getParent() == newFunc); } +#endif auto found = newToOriginalFn.find(newinst); if (found == newToOriginalFn.end()) return nullptr; @@ -2561,11 +2556,13 @@ Value *GradientUtils::cacheForReverse(IRBuilder<> &BuilderQ, Value *malloc, return malloc; } +#ifndef NDEBUG if (auto CI = dyn_cast(malloc)) { if (auto F = CI->getCalledFunction()) { assert(F->getName() != "omp_get_thread_num"); } } +#endif if (malloc->getType()->isTokenTy()) { llvm::errs() << " oldFunc: " << *oldFunc << "\n"; @@ -3322,7 +3319,7 @@ BasicBlock *GradientUtils::prepRematerializedLoopEntry(LoopContext &lc) { auto &DL = newFunc->getParent()->getDataLayout(); bool constantval = isConstantValue(orig_val) || - parseTBAA(I, DL, nullptr).Inner0().isIntegral(); + parseTBAA(I, DL, nullptr)[{-1}].isIntegral(); // TODO allow recognition of other types that could contain // pointers [e.g. {void*, void*} or <2 x i64> ] @@ -3467,8 +3464,9 @@ BasicBlock *GradientUtils::prepRematerializedLoopEntry(LoopContext &lc) { lctx, placeholder->getType(), placeholder->getName(), /*shouldFree*/ true); assert(cache); + Value *placeholder_tmp = placeholder; found = insert_or_assign( - scopeMap, (Value *&)placeholder, + scopeMap, placeholder_tmp, std::pair, LimitContext>(cache, lctx)); } auto cache = found->second.first; @@ -3830,7 +3828,7 @@ bool GradientUtils::legalRecompute(const Value *val, struct { Function *func; const LoopInfo &FLI; - } options[2] = {{newFunc, LI}, {oldFunc, OrigLI}}; + } options[2] = {{newFunc, LI}, {oldFunc, *OrigLI}}; for (const auto &tup : options) { if (parent->getParent() == tup.func) { for (auto &val : phi->incoming_values()) { @@ -3970,7 +3968,7 @@ bool GradientUtils::legalRecompute(const Value *val, const_cast(orig), [&](Instruction *I) -> bool { if (I->mayWriteToMemory() && writesToMemoryReadBy( - OrigAA, TLI, + *OrigAA, TLI, /*maybeReader*/ const_cast(orig), /*maybeWriter*/ I)) { failed = true; @@ -3993,7 +3991,7 @@ bool GradientUtils::legalRecompute(const Value *val, } origStart = origStart->getNextNode(); } while (true); - if (OrigDT.dominates(origStart, const_cast(orig))) { + if (OrigDT->dominates(origStart, const_cast(orig))) { bool failed = false; allInstructionsBetween( @@ -4001,7 +3999,7 @@ bool GradientUtils::legalRecompute(const Value *val, const_cast(orig), [&](Instruction *I) -> bool { if (I->mayWriteToMemory() && writesToMemoryReadBy( - OrigAA, TLI, + *OrigAA, TLI, /*maybeReader*/ const_cast(orig), /*maybeWriter*/ I)) { failed = true; @@ -4363,8 +4361,7 @@ DIFFE_TYPE GradientUtils::getReturnDiffeType(llvm::Value *orig, subretType = DIFFE_TYPE::DUP_ARG; shadowReturnUsed = true; } else { - if (!orig->getType()->isFPOrFPVectorTy() && - TR.query(orig).Inner0().isPossiblePointer()) { + if (!orig->getType()->isFPOrFPVectorTy() && TR.anyPointer(orig)) { if (DifferentialUseAnalysis::is_value_needed_in_reverse< QueryType::Shadow>(this, orig, cmode, notForAnalysis)) { subretType = DIFFE_TYPE::DUP_ARG; @@ -4401,8 +4398,7 @@ DIFFE_TYPE GradientUtils::getDiffeType(Value *v, bool foreignFunction) const { auto argType = v->getType(); - if (!argType->isFPOrFPVectorTy() && - (TR.query(v).Inner0().isPossiblePointer() || foreignFunction)) { + if (!argType->isFPOrFPVectorTy() && (TR.anyPointer(v) || foreignFunction)) { if (argType->isPointerTy()) { auto at = getBaseObject(v); if (auto arg = dyn_cast(at)) { @@ -4798,12 +4794,14 @@ void GradientUtils::setPtrDiffe(Instruction *orig, Value *ptr, Value *newval, SyncScope::ID syncScope, Value *mask, ArrayRef noAlias, ArrayRef scopes) { +#ifndef NDEBUG if (auto inst = dyn_cast(ptr)) { assert(inst->getParent()->getParent() == oldFunc); } if (auto arg = dyn_cast(ptr)) { assert(arg->getParent() == oldFunc); } +#endif Value *origptr = ptr; @@ -5078,12 +5076,14 @@ llvm::Value *GradientUtils::recursiveFAdd(llvm::IRBuilder<> &B, Value *GradientUtils::invertPointerM(Value *const oval, IRBuilder<> &BuilderM, bool nullShadow) { assert(oval); +#ifndef NDEBUG if (auto inst = dyn_cast(oval)) { assert(inst->getParent()->getParent() == oldFunc); } if (auto arg = dyn_cast(oval)) { assert(arg->getParent() == oldFunc); } +#endif if (isa(oval)) { return applyChainRule(oval->getType(), BuilderM, [&]() { return oval; }); @@ -5147,9 +5147,21 @@ Value *GradientUtils::invertPointerM(Value *const oval, IRBuilder<> &BuilderM, return applyChainRule(oval->getType(), BuilderM, rule); } - if (isConstantValue(oval) && !isa(oval) && - !isa(oval) && !isa(oval) && - !isa(oval)) { + bool shouldNullShadow = isConstantValue(oval); + if (shouldNullShadow) { + if (isa(oval) || isa(oval) || + isa(oval) || isa(oval)) { + shouldNullShadow = false; + auto orig = cast(oval); + if (knownRecomputeHeuristic.count(orig)) { + if (!knownRecomputeHeuristic[orig]) { + shouldNullShadow = true; + } + } + } + } + + if (shouldNullShadow) { // NOTE, this is legal and the correct resolution, however, our activity // analysis honeypot no longer exists @@ -5322,7 +5334,7 @@ Value *GradientUtils::invertPointerM(Value *const oval, IRBuilder<> &BuilderM, if (F && isMemFreeLibMFunction(F->getName())) { continue; } - if (llvm::isModOrRefSet(OrigAA.getModRefInfo(CI, Loc))) { + if (llvm::isModOrRefSet(OrigAA->getModRefInfo(CI, Loc))) { seen = true; llvm::errs() << " cannot shadow-inline global " << *oval << " due to " << *CI << "\n"; @@ -5477,10 +5489,21 @@ Value *GradientUtils::invertPointerM(Value *const oval, IRBuilder<> &BuilderM, if (!isa(md)) { llvm::errs() << *arg << "\n"; llvm::errs() << *md << "\n"; - assert(0 && "cannot compute with global variable that doesn't have " - "marked shadow global"); - report_fatal_error("cannot compute with global variable that doesn't " - "have marked shadow global (metadata incorrect type)"); + std::string s; + llvm::raw_string_ostream ss(s); + ss << "cannot compute with global variable that doesn't have marked " + "shadow global as mdtuple\n"; + ss << *arg << "\n"; + ss << " md: " << *md << "\n"; + if (CustomErrorHandler) { + return unwrap(CustomErrorHandler(ss.str().c_str(), wrap(arg), + ErrorType::NoShadow, this, nullptr, + wrap(&BuilderM))); + } else { + EmitFailure("InvertGlobal", BuilderM.getCurrentDebugLocation(), oldFunc, + ss.str()); + } + return UndefValue::get(getShadowType(arg->getType())); } auto md2 = cast(md); assert(md2->getNumOperands() == 1); @@ -5716,15 +5739,18 @@ Value *GradientUtils::invertPointerM(Value *const oval, IRBuilder<> &BuilderM, Value *itval = nullptr; { auto tval = arg->getTrueValue(); - if (!EnzymeRuntimeActivityCheck && CustomErrorHandler && + if (!EnzymeRuntimeActivityCheck && TR.query(arg)[{-1}].isPossiblePointer() && !isa(tval) && !isa(tval) && isConstantValue(tval)) { std::string str; raw_string_ostream ss(str); ss << "Mismatched activity for: " << *arg << " const val: " << *tval; - itval = unwrap(CustomErrorHandler(str.c_str(), wrap(arg), - ErrorType::MixedActivityError, this, - wrap(tval), wrap(&bb))); + if (CustomErrorHandler) + itval = unwrap(CustomErrorHandler(str.c_str(), wrap(arg), + ErrorType::MixedActivityError, this, + wrap(tval), wrap(&bb))); + else + EmitWarning("MixedActivityError", *arg, ss.str()); } if (!itval) { itval = invertPointerM(tval, bb, nullShadow); @@ -5733,15 +5759,18 @@ Value *GradientUtils::invertPointerM(Value *const oval, IRBuilder<> &BuilderM, Value *ifval = nullptr; { auto fval = arg->getFalseValue(); - if (!EnzymeRuntimeActivityCheck && CustomErrorHandler && + if (!EnzymeRuntimeActivityCheck && TR.query(arg)[{-1}].isPossiblePointer() && !isa(fval) && !isa(fval) && isConstantValue(fval)) { std::string str; raw_string_ostream ss(str); ss << "Mismatched activity for: " << *arg << " const val: " << *fval; - ifval = unwrap(CustomErrorHandler(str.c_str(), wrap(arg), - ErrorType::MixedActivityError, this, - wrap(fval), wrap(&bb))); + if (CustomErrorHandler) + ifval = unwrap(CustomErrorHandler(str.c_str(), wrap(arg), + ErrorType::MixedActivityError, this, + wrap(fval), wrap(&bb))); + else + EmitWarning("MixedActivityError", *arg, ss.str()); } if (!ifval) { ifval = invertPointerM(fval, bb, nullShadow); @@ -6085,7 +6114,7 @@ Value *GradientUtils::invertPointerM(Value *const oval, IRBuilder<> &BuilderM, Value *preval = phi->getIncomingValue(j); Value *val = nullptr; - if (!EnzymeRuntimeActivityCheck && CustomErrorHandler && + if (!EnzymeRuntimeActivityCheck && TR.query(phi)[{-1}].isPossiblePointer() && !isa(preval) && !isa(preval) && isConstantValue(preval)) { @@ -6093,9 +6122,12 @@ Value *GradientUtils::invertPointerM(Value *const oval, IRBuilder<> &BuilderM, raw_string_ostream ss(str); ss << "Mismatched activity for: " << *phi << " const val: " << *preval; - val = unwrap(CustomErrorHandler(str.c_str(), wrap(phi), - ErrorType::MixedActivityError, this, - wrap(preval), wrap(&pre))); + if (CustomErrorHandler) + val = unwrap(CustomErrorHandler(str.c_str(), wrap(phi), + ErrorType::MixedActivityError, + this, wrap(preval), wrap(&pre))); + else + EmitWarning("MixedActivityError", *phi, ss.str()); } if (!val) { val = invertPointerM(preval, pre, nullShadow); @@ -6145,7 +6177,7 @@ Value *GradientUtils::invertPointerM(Value *const oval, IRBuilder<> &BuilderM, Value *preval = phi->getIncomingValue(i); Value *val = nullptr; - if (!EnzymeRuntimeActivityCheck && CustomErrorHandler && + if (!EnzymeRuntimeActivityCheck && TR.query(phi)[{-1}].isPossiblePointer() && !isa(preval) && !isa(preval) && isConstantValue(preval)) { @@ -6153,9 +6185,12 @@ Value *GradientUtils::invertPointerM(Value *const oval, IRBuilder<> &BuilderM, raw_string_ostream ss(str); ss << "Mismatched activity for: " << *phi << " const val: " << *preval; - val = unwrap(CustomErrorHandler(str.c_str(), wrap(phi), - ErrorType::MixedActivityError, this, - wrap(preval), wrap(&pre))); + if (CustomErrorHandler) + val = unwrap(CustomErrorHandler(str.c_str(), wrap(phi), + ErrorType::MixedActivityError, + this, wrap(preval), wrap(&pre))); + else + EmitWarning("MixedActivityError", *phi, ss.str()); } if (!val) { val = invertPointerM(preval, pre, nullShadow); @@ -6345,9 +6380,9 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM, // + or because the loop nests share no ancestry bool loopLegal = true; - for (Loop *idx = OrigLI.getLoopFor(orig); idx != nullptr; + for (Loop *idx = OrigLI->getLoopFor(orig); idx != nullptr; idx = idx->getParentLoop()) { - for (Loop *fdx = OrigLI.getLoopFor(forwardBlock); fdx != nullptr; + for (Loop *fdx = OrigLI->getLoopFor(forwardBlock); fdx != nullptr; fdx = fdx->getParentLoop()) { if (idx == fdx) { loopLegal = false; @@ -6551,9 +6586,9 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM, // << "\n"; allInstructionsBetween( - OrigLI, orig2, origInst, [&](Instruction *I) -> bool { + *OrigLI, orig2, origInst, [&](Instruction *I) -> bool { if (I->mayWriteToMemory() && - writesToMemoryReadBy(OrigAA, TLI, + writesToMemoryReadBy(*OrigAA, TLI, /*maybeReader*/ origInst, /*maybeWriter*/ I)) { failed = true; @@ -6567,12 +6602,12 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM, if (auto ar1 = dyn_cast(scev1)) { if (auto ar2 = dyn_cast(scev2)) { - if (ar1->getStart() != OrigSE.getCouldNotCompute() && + if (ar1->getStart() != OrigSE->getCouldNotCompute() && ar1->getStart() == ar2->getStart() && - ar1->getStepRecurrence(OrigSE) != - OrigSE.getCouldNotCompute() && - ar1->getStepRecurrence(OrigSE) == - ar2->getStepRecurrence(OrigSE)) { + ar1->getStepRecurrence(*OrigSE) != + OrigSE->getCouldNotCompute() && + ar1->getStepRecurrence(*OrigSE) == + ar2->getStepRecurrence(*OrigSE)) { LoopContext l1; getContext(ar1->getLoop()->getHeader(), l1); @@ -6600,7 +6635,7 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM, } } - auto scev1 = OrigSE.getSCEV(origInst->getPointerOperand()); + auto scev1 = OrigSE->getSCEV(origInst->getPointerOperand()); auto Arch = llvm::Triple(newFunc->getParent()->getTargetTriple()).getArch(); @@ -6608,7 +6643,7 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM, Arch == Triple::amdgcn ? (int)AMDGPU::HSAMD::AddressSpaceQualifier::Local : 3; - if (EnzymeSharedForward && scev1 != OrigSE.getCouldNotCompute() && + if (EnzymeSharedForward && scev1 != OrigSE->getCouldNotCompute() && cast(orig_liobj->getType())->getAddressSpace() == SharedAddrSpace) { Value *resultValue = nullptr; @@ -6617,7 +6652,7 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM, assert(pair.first->getType() == pair.second->getType()); newavail[pair.first] = pair.second; } - allDomPredecessorsOf(origInst, OrigDT, [&](Instruction *pred) { + allDomPredecessorsOf(origInst, *OrigDT, [&](Instruction *pred) { if (auto SI = dyn_cast(pred)) { // auto NewSI = cast(getNewFromOriginal(SI)); auto si2obj = getBaseObject(SI->getPointerOperand()); @@ -6628,10 +6663,10 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM, bool lastStore = true; bool interveningSync = false; allInstructionsBetween( - OrigLI, SI, origInst, [&](Instruction *potentialAlias) { + *OrigLI, SI, origInst, [&](Instruction *potentialAlias) { if (!potentialAlias->mayWriteToMemory()) return false; - if (!writesToMemoryReadBy(OrigAA, TLI, origInst, + if (!writesToMemoryReadBy(*OrigAA, TLI, origInst, potentialAlias)) return false; @@ -6649,7 +6684,7 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM, if (mid == SI) return false; - if (!writesToMemoryReadBy(OrigAA, TLI, origInst, + if (!writesToMemoryReadBy(*OrigAA, TLI, origInst, mid)) { return false; } @@ -6676,16 +6711,16 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM, if (!lastStore) return false; - auto scev2 = OrigSE.getSCEV(SI->getPointerOperand()); + auto scev2 = OrigSE->getSCEV(SI->getPointerOperand()); bool legal = scev1 == scev2; if (auto ar2 = dyn_cast(scev2)) { if (auto ar1 = dyn_cast(scev1)) { - if (ar2->getStart() != OrigSE.getCouldNotCompute() && + if (ar2->getStart() != OrigSE->getCouldNotCompute() && ar1->getStart() == ar2->getStart() && - ar2->getStepRecurrence(OrigSE) != - OrigSE.getCouldNotCompute() && - ar1->getStepRecurrence(OrigSE) == - ar2->getStepRecurrence(OrigSE)) { + ar2->getStepRecurrence(*OrigSE) != + OrigSE->getCouldNotCompute() && + ar1->getStepRecurrence(*OrigSE) == + ar2->getStepRecurrence(*OrigSE)) { LoopContext l1; getContext(getNewFromOriginal(ar1->getLoop()->getHeader()), @@ -6738,15 +6773,15 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM, ValueToValueMapTy ThreadLookup; bool legal = true; for (size_t i = 0; i < svals.size(); i++) { - auto ss = OrigSE.getSCEV(svals[i]); - auto ls = OrigSE.getSCEV(lvals[i]); + auto ss = OrigSE->getSCEV(svals[i]); + auto ls = OrigSE->getSCEV(lvals[i]); if (cast(ss->getType())->getBitWidth() > cast(ls->getType())->getBitWidth()) { - ls = OrigSE.getZeroExtendExpr(ls, ss->getType()); + ls = OrigSE->getZeroExtendExpr(ls, ss->getType()); } if (cast(ss->getType())->getBitWidth() < cast(ls->getType())->getBitWidth()) { - ls = OrigSE.getTruncateExpr(ls, ss->getType()); + ls = OrigSE->getTruncateExpr(ls, ss->getType()); } if (ls != ss) { if (auto II = dyn_cast(svals[i])) { @@ -6833,23 +6868,23 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM, auto origPH = cast_or_null(isOriginal(ctx)); assert(origPH); - if (OrigPDT.dominates(origPH, origInst->getParent())) { + if (OrigPDT->dominates(origPH, origInst->getParent())) { goto noSpeedCache; } Instruction *origTerm = origPH->getTerminator(); if (!origTerm) - llvm::errs() << *origTerm << "\n"; + llvm::errs() << *origPH << "\n"; assert(origTerm); IRBuilder<> OB(origTerm); LoadInst *tmpload = OB.CreateLoad(AT, orig_liobj, "'tmpload"); bool failed = false; allInstructionsBetween( - OrigLI, &*origTerm, origInst, + *OrigLI, &*origTerm, origInst, [&](Instruction *I) -> bool { if (I->mayWriteToMemory() && - writesToMemoryReadBy(OrigAA, TLI, + writesToMemoryReadBy(*OrigAA, TLI, /*maybeReader*/ tmpload, /*maybeWriter*/ I)) { failed = true; @@ -6867,15 +6902,15 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM, bool failed = false; auto origPH = cast_or_null(isOriginal(nctx)); assert(origPH); - if (OrigPDT.dominates(origPH, origInst->getParent())) { + if (OrigPDT->dominates(origPH, origInst->getParent())) { break; } Instruction *origTerm = origPH->getTerminator(); allInstructionsBetween( - OrigLI, &*origTerm, origInst, + *OrigLI, &*origTerm, origInst, [&](Instruction *I) -> bool { if (I->mayWriteToMemory() && - writesToMemoryReadBy(OrigAA, TLI, + writesToMemoryReadBy(*OrigAA, TLI, /*maybeReader*/ tmpload, /*maybeWriter*/ I)) { failed = true; @@ -6967,7 +7002,7 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM, } } - auto scev1 = OrigSE.getSCEV(origInst->getPointerOperand()); + auto scev1 = OrigSE->getSCEV(origInst->getPointerOperand()); // Store in memcpy opt Value *lim = nullptr; BasicBlock *ctx = nullptr; @@ -6975,7 +7010,7 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM, Value *offset = nullptr; if (auto ar1 = dyn_cast(scev1)) { if (auto step = - dyn_cast(ar1->getStepRecurrence(OrigSE))) { + dyn_cast(ar1->getStepRecurrence(*OrigSE))) { if (step->getAPInt() != loadSize) goto noSpeedCache; @@ -6992,7 +7027,7 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM, auto origPH = cast_or_null(isOriginal(ctx)); assert(origPH); - if (OrigPDT.dominates(origPH, origInst->getParent())) { + if (OrigPDT->dominates(origPH, origInst->getParent())) { goto noSpeedCache; } @@ -7011,7 +7046,7 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM, SmallVector InsertedInstructions; { SCEVExpander OrigExp( - OrigSE, ctx->getParent()->getParent()->getDataLayout(), + *OrigSE, ctx->getParent()->getParent()->getDataLayout(), "enzyme"); OrigExp.setInsertPoint( @@ -7032,7 +7067,7 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM, // instructions. llvm::stable_sort(InsertedInstructions, [this](Instruction *A, Instruction *B) { - return OrigDT.dominates(A, B); + return OrigDT->dominates(A, B); }); for (auto a : InsertedInstructions) { assert(!isa(a)); @@ -7040,6 +7075,7 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM, unwrapM(a, v, available, UnwrapMode::AttemptSingleUnwrap, /*scope*/ nullptr, /*cache*/ false)); assert(uw->getType() == a->getType()); +#ifndef NDEBUG for (size_t i = 0; i < uw->getNumOperands(); i++) { auto op = uw->getOperand(i); if (auto arg = dyn_cast(op)) @@ -7047,6 +7083,7 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM, else if (auto inst = dyn_cast(op)) assert(inst->getParent()->getParent() == newFunc); } +#endif available[a] = uw; unwrappedLoads.erase(cast(uw)); } @@ -7063,7 +7100,7 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM, available.clear(); for (auto I : llvm::reverse(InsertedInstructions)) { assert(I->getNumUses() == 0); - OrigSE.forgetValue(I); + OrigSE->forgetValue(I); I->eraseFromParent(); } #endif @@ -7076,9 +7113,9 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM, bool failed = false; allInstructionsBetween( - OrigLI, &*origTerm, origInst, [&](Instruction *I) -> bool { + *OrigLI, &*origTerm, origInst, [&](Instruction *I) -> bool { if (I->mayWriteToMemory() && - writesToMemoryReadBy(OrigAA, TLI, + writesToMemoryReadBy(*OrigAA, TLI, /*maybeReader*/ origInst, /*maybeWriter*/ I)) { failed = true; @@ -7100,14 +7137,14 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM, bool failed = false; auto origPH = cast_or_null(isOriginal(nctx)); assert(origPH); - if (OrigPDT.dominates(origPH, origInst->getParent())) { + if (OrigPDT->dominates(origPH, origInst->getParent())) { break; } Instruction *origTerm = origPH->getTerminator(); allInstructionsBetween( - OrigLI, &*origTerm, origInst, [&](Instruction *I) -> bool { + *OrigLI, &*origTerm, origInst, [&](Instruction *I) -> bool { if (I->mayWriteToMemory() && - writesToMemoryReadBy(OrigAA, TLI, + writesToMemoryReadBy(*OrigAA, TLI, /*maybeReader*/ origInst, /*maybeWriter*/ I)) { failed = true; @@ -7268,7 +7305,8 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM, AllocaInst *cache = createCacheForScope( lctx, inst->getType(), inst->getName(), /*shouldFree*/ true); assert(cache); - insert_or_assign(scopeMap, (Value *&)inst, + Value *inst_tmp = inst; + insert_or_assign(scopeMap, inst_tmp, std::pair, LimitContext>( cache, lctx)); } @@ -7342,9 +7380,11 @@ void GradientUtils::branchToCorrespondingTarget( if (replacePHIs->size() == 0) return; +#ifndef NDEBUG for (auto x : *replacePHIs) { assert(targetToPreds.find(x.first) != targetToPreds.end()); } +#endif } if (targetToPreds.size() == 1) { @@ -7900,7 +7940,7 @@ void GradientUtils::computeMinCache() { for (BasicBlock &BB : *oldFunc) { if (notForAnalysis.count(&BB)) continue; - auto L = OrigLI.getLoopFor(&BB); + auto L = OrigLI->getLoopFor(&BB); auto invariant = [&](Value *V) { if (isa(V)) @@ -7908,20 +7948,20 @@ void GradientUtils::computeMinCache() { if (isa(V)) return true; if (auto I = dyn_cast(V)) { - if (!L->contains(OrigLI.getLoopFor(I->getParent()))) + if (!L->contains(OrigLI->getLoopFor(I->getParent()))) return true; } return false; }; for (Instruction &I : BB) { if (auto PN = dyn_cast(&I)) { - if (!OrigLI.isLoopHeader(&BB)) + if (!OrigLI->isLoopHeader(&BB)) continue; if (PN->getType()->isIntegerTy()) { bool legal = true; SmallPtrSet Increment; for (auto B : PN->blocks()) { - if (OrigLI.getLoopFor(B) == L) { + if (OrigLI->getLoopFor(B) == L) { if (auto BO = dyn_cast( PN->getIncomingValueForBlock(B))) { if (BO->getOpcode() == BinaryOperator::Add) { @@ -8009,7 +8049,7 @@ void GradientUtils::computeMinCache() { ValueToValueMapTy Available2; for (auto a : Available) Available2[a.first] = a.second; - for (Loop *L = OrigLI.getLoopFor(&BB); L != nullptr; + for (Loop *L = OrigLI->getLoopFor(&BB); L != nullptr; L = L->getParentLoop()) { for (auto v : LoopAvail[L]) { Available2[v] = v; @@ -8042,16 +8082,16 @@ void GradientUtils::computeMinCache() { todo.pop_front(); if (Intermediates.count(V)) continue; - if (!DifferentialUseAnalysis::is_value_needed_in_reverse< - QueryType::Primal>(this, V, minCutMode, FullSeen, - notForAnalysis)) { + bool multiLevel = DifferentialUseAnalysis::is_value_needed_in_reverse< + QueryType::Primal>(this, V, minCutMode, FullSeen, notForAnalysis); + if (!multiLevel) { continue; } if (!Recomputes.count(V)) { ValueToValueMapTy Available2; for (auto a : Available) Available2[a.first] = a.second; - for (Loop *L = OrigLI.getLoopFor(cast(V)->getParent()); + for (Loop *L = OrigLI->getLoopFor(cast(V)->getParent()); L != nullptr; L = L->getParentLoop()) { for (auto v : LoopAvail[L]) { Available2[v] = v; @@ -8064,27 +8104,21 @@ void GradientUtils::computeMinCache() { } } Intermediates.insert(V); - if (DifferentialUseAnalysis::is_value_needed_in_reverse< - QueryType::Primal, /*OneLevel*/ true>( - this, V, minCutMode, OneLevelSeen, notForAnalysis)) { + bool singleLevel = DifferentialUseAnalysis::is_value_needed_in_reverse< + QueryType::Primal, /*OneLevel*/ true>(this, V, minCutMode, + OneLevelSeen, notForAnalysis); + if (singleLevel) { Required.insert(V); } else { - for (auto V2 : V->users()) { - if (auto Inst = dyn_cast(V2)) - for (auto pair : rematerializableAllocations) { - if (pair.second.stores.count(Inst)) { - todo.push_back(pair.first); - } - } - todo.push_back(V2); - } + DifferentialUseAnalysis::forEachDifferentialUser( + [&](Value *V2) { todo.push_back(V2); }, this, V); } } SetVector MinReq; DifferentialUseAnalysis::minCut(oldFunc->getParent()->getDataLayout(), - OrigLI, Recomputes, Intermediates, Required, - MinReq, rematerializableAllocations, TLI); + *OrigLI, Recomputes, Intermediates, + Required, MinReq, this, TLI); SmallPtrSet NeedGraph; for (Value *V : MinReq) NeedGraph.insert(V); @@ -8113,7 +8147,7 @@ void GradientUtils::computeMinCache() { ValueToValueMapTy Available2; for (auto a : Available) Available2[a.first] = a.second; - for (Loop *L = OrigLI.getLoopFor(cast(V)->getParent()); + for (Loop *L = OrigLI->getLoopFor(cast(V)->getParent()); L != nullptr; L = L->getParentLoop()) { for (auto v : LoopAvail[L]) { Available2[v] = v; @@ -8192,11 +8226,13 @@ void GradientUtils::forceActiveDetection() { bool GradientUtils::isConstantValue(Value *val) const { if (auto inst = dyn_cast(val)) { + (void)inst; assert(inst->getParent()->getParent() == oldFunc); return ATA->isConstantValue(TR, val); } if (auto arg = dyn_cast(val)) { + (void)arg; assert(arg->getParent() == oldFunc); return ATA->isConstantValue(TR, val); } @@ -8761,13 +8797,13 @@ void GradientUtils::computeForwardingProperties(Instruction *V) { } // Find the outermost loop of all stores, and the allocation/lifetime - Loop *outer = OrigLI.getLoopFor(V->getParent()); + Loop *outer = OrigLI->getLoopFor(V->getParent()); if (LifetimeStarts.size() == 1) { - outer = OrigLI.getLoopFor((*LifetimeStarts.begin())->getParent()); + outer = OrigLI->getLoopFor((*LifetimeStarts.begin())->getParent()); } for (auto S : stores) { - outer = getAncestor(outer, OrigLI.getLoopFor(S->getParent())); + outer = getAncestor(outer, OrigLI->getLoopFor(S->getParent())); } // May now read pointers for storing into other pointers. Therefore we @@ -8781,8 +8817,8 @@ void GradientUtils::computeForwardingProperties(Instruction *V) { SmallVector results; mayExecuteAfter(results, LI, storingOps, outer); for (auto res : results) { - if (overwritesToMemoryReadBy(OrigAA, TLI, SE, OrigLI, OrigDT, LI, res, - outer)) { + if (overwritesToMemoryReadBy(*OrigAA, TLI, SE, *OrigLI, *OrigDT, LI, + res, outer)) { EmitWarning("NotPromotable", *LI, " Could not promote shadow allocation ", *V, " due to pointer load ", *LI, @@ -8836,7 +8872,7 @@ void GradientUtils::computeForwardingProperties(Instruction *V) { SmallVector results; mayExecuteAfter(results, LI, storingOps, outer); for (auto res : results) { - if (overwritesToMemoryReadBy(OrigAA, TLI, SE, OrigLI, OrigDT, LI, res, + if (overwritesToMemoryReadBy(*OrigAA, TLI, SE, *OrigLI, *OrigDT, LI, res, outer)) { EmitWarning("NotPromotable", *LI, " Could not promote allocation ", *V, " due to load ", *LI, @@ -8852,8 +8888,8 @@ void GradientUtils::computeForwardingProperties(Instruction *V) { SmallVector results; mayExecuteAfter(results, LI.loadCall, storingOps, outer); for (auto res : results) { - if (overwritesToMemoryReadBy(OrigAA, TLI, SE, OrigLI, OrigDT, LI.loadCall, - res, outer)) { + if (overwritesToMemoryReadBy(*OrigAA, TLI, SE, *OrigLI, *OrigDT, + LI.loadCall, res, outer)) { EmitWarning("NotPromotable", *LI.loadCall, " Could not promote allocation ", *V, " due to load-like call ", *LI.loadCall, @@ -8907,6 +8943,7 @@ void GradientUtils::replaceAWithB(Value *A, Value *B, bool storeInCache) { // Check that the replacement doesn't already exist in the mapping // thereby resulting in a conflict. +#ifndef NDEBUG { auto found = newToOriginalFn.find(A); if (found != newToOriginalFn.end()) { @@ -8914,6 +8951,7 @@ void GradientUtils::replaceAWithB(Value *A, Value *B, bool storeInCache) { assert(foundB == newToOriginalFn.end()); } } +#endif CacheUtility::replaceAWithB(A, B, storeInCache); } @@ -9064,7 +9102,7 @@ void GradientUtils::computeGuaranteedFrees() { bool hasPDFree = false; if (dc->getParent() == CI->getParent() || - OrigPDT.dominates(CI->getParent(), dc->getParent())) { + OrigPDT->dominates(CI->getParent(), dc->getParent())) { hasPDFree = true; } @@ -9177,10 +9215,12 @@ llvm::CallInst *freeKnownAllocation(llvm::IRBuilder<> &builder, tofree = builder.CreateIntToPtr(tofree, getInt8PtrTy(tofree->getContext())); llvm::LibFunc libfunc; - if (allocationfn == "calloc" || allocationfn == "malloc") { + if (allocationfn == "calloc" || allocationfn == "malloc" || + allocationfn == "_mlir_memref_to_llvm_alloc") { libfunc = LibFunc_malloc; } else { bool res = TLI.getLibFunc(allocationfn, libfunc); + (void)res; assert(res && "ought find known allocation fn"); } @@ -9247,6 +9287,8 @@ llvm::CallInst *freeKnownAllocation(llvm::IRBuilder<> &builder, if (freename != "free") llvm_unreachable("illegal free"); } + if (allocationfn == "_mlir_memref_to_llvm_alloc") + freename = "_mlir_memref_to_llvm_free"; Type *VoidTy = Type::getVoidTy(tofree->getContext()); Type *IntPtrTy = getInt8PtrTy(tofree->getContext()); @@ -9316,6 +9358,16 @@ bool GradientUtils::needsCacheWholeAllocation( if (idx < CI->getNumArgOperands()) #endif { + + // Calling a non-empty function with a julia base object, this is fine. + // as GC will deal with any issues with. + if (auto PT = dyn_cast(CI->getArgOperand(idx)->getType())) + if (PT->getAddressSpace() == 10) + if (EnzymeJuliaAddrLoad) + if (auto F = getFunctionFromCall(CI)) + if (!F->empty()) + continue; + if (isNoCapture(CI, idx)) continue; diff --git a/enzyme/Enzyme/GradientUtils.h b/enzyme/Enzyme/GradientUtils.h index 6a8c9fd61b9e..98cae4145639 100644 --- a/enzyme/Enzyme/GradientUtils.h +++ b/enzyme/Enzyme/GradientUtils.h @@ -128,10 +128,10 @@ class GradientUtils : public CacheUtility { DerivativeMode mode; llvm::Function *oldFunc; llvm::ValueMap invertedPointers; - llvm::DominatorTree &OrigDT; - llvm::PostDominatorTree &OrigPDT; - llvm::LoopInfo &OrigLI; - llvm::ScalarEvolution &OrigSE; + llvm::DominatorTree *OrigDT; + llvm::PostDominatorTree *OrigPDT; + llvm::LoopInfo *OrigLI; + llvm::ScalarEvolution *OrigSE; /// (Original) Blocks which dominate all returns llvm::SmallPtrSet BlocksDominatingAllReturns; @@ -353,7 +353,7 @@ class GradientUtils : public CacheUtility { } public: - llvm::AAResults &OrigAA; + llvm::AAResults *OrigAA; TypeAnalysis &TA; TypeResults TR; bool omp; @@ -601,11 +601,13 @@ class GradientUtils : public CacheUtility { llvm::ArrayRef diffs, llvm::IRBuilder<> &Builder, Func rule) { if (width > 1) { +#ifndef NDEBUG for (auto diff : diffs) { assert(diff); assert(llvm::cast(diff->getType())->getNumElements() == width); } +#endif llvm::Type *wrappedType = llvm::ArrayType::get(diffType, width); llvm::Value *res = llvm::UndefValue::get(wrappedType); for (unsigned int i = 0; i < getWidth(); ++i) { diff --git a/enzyme/Enzyme/InstructionDerivatives.td b/enzyme/Enzyme/InstructionDerivatives.td index 53fad8160993..e309055c34d2 100644 --- a/enzyme/Enzyme/InstructionDerivatives.td +++ b/enzyme/Enzyme/InstructionDerivatives.td @@ -162,12 +162,20 @@ def CFNeg : SubRoutine<(Op (Op $re, $im):$z), (FNeg $re), (FNeg $im) )>; + +def Conj : SubRoutine<(Op (Op $re, $im):$z), + (ArrayRet + $re, + (FNeg $im) + )>; + def CFExp : SubRoutine<(Op (Op $re, $im):$z), (ArrayRet (FMul (FExp $re):$exp, (FCos $im)), (FMul $exp, (FSin $im)) )>; + // Same function as the one being called def SameFunc { } @@ -186,6 +194,12 @@ class PrependArgTypesFunc pretys_> { list pretys = pretys_; } +// Set return to arg[0] +// Same argument types +class ArgAsRetTypesFunc { + string name = name_; +} + // Specify that a given argument is inactive, aka not differentiable // By default this argument tells Enzyme that it must always be inactive // from the function semantics. @@ -204,7 +218,9 @@ def AssertingInactiveArg : InactiveArgSpec { class GlobalExpr : Operation{ string value = val; } -def MantissaMaskOfReturn : GlobalExprisX86_FP80Ty()) {\n" " tsize = 80;\n" " high = tsize - 1;\n" -" low = high - 15;\n" +" low = high - 16;\n" + // x86_fp80 has only 15 exponent bits, but we must also + // retain the most-significant bit of the mantissa as + // there is no implicit leading 1. " } else if (ty->isFP128Ty()) {\n" " tsize = 128;\n" " high = tsize - 1;\n" @@ -311,13 +330,18 @@ def : CallPattern<(Op $x), (ForwardFromSummedReverse), [ReadNone, NoUnwind] >; - def : CallPattern<(Op $x), ["tanhf"], [(FDiv (DiffeRet), (FMul(Call<(SameTypesFunc<"coshf">), [ReadNone,NoUnwind]> $x):$c, $c))], (ForwardFromSummedReverse), [ReadNone, NoUnwind] >; +def : CallPattern<(Op $x), + ["tanhl"], + [(FDiv (DiffeRet), (FMul(Call<(SameTypesFunc<"coshl">), [ReadNone,NoUnwind]> $x):$c, $c))], + (ForwardFromSummedReverse), + [ReadNone, NoUnwind] + >; def : CallPattern<(Op $x), ["cosh"], @@ -331,6 +355,12 @@ def : CallPattern<(Op $x), (ForwardFromSummedReverse), [ReadNone, NoUnwind] >; +def : CallPattern<(Op $x), + ["coshl"], + [(FMul (DiffeRet), (Call<(SameTypesFunc<"sinhl">), [ReadNone,NoUnwind]> $x))], + (ForwardFromSummedReverse), + [ReadNone, NoUnwind] + >; def : CallPattern<(Op $x), ["sinh"], @@ -344,6 +374,33 @@ def : CallPattern<(Op $x), (ForwardFromSummedReverse), [ReadNone, NoUnwind] >; +def : CallPattern<(Op $x), + ["sinhl"], + [(FMul (DiffeRet), (Call<(SameTypesFunc<"coshl">), [ReadNone,NoUnwind]> $x))], + (ForwardFromSummedReverse), + [ReadNone, NoUnwind] + >; + +def : CallPattern<(Op $x), + ["asinh", "asinhf", "asinhl", "__nv_asinh", "__nv_asinhf"], + [(FDiv (DiffeRet), (Intrinsic<"sqrt"> (FAdd (FMul $x, $x), (ConstantFP<"1.0"> $x))) )] , + (ForwardFromSummedReverse), + [ReadNone, NoUnwind] + >; + +def : CallPattern<(Op $x), + ["acosh", "acoshf", "acoshl", "__nv_acosh", "__nv_acoshf"], + [(FDiv (DiffeRet), (Intrinsic<"sqrt"> (FSub (FMul $x, $x), (ConstantFP<"1.0"> $x))) )] , + (ForwardFromSummedReverse), + [ReadNone, NoUnwind] + >; + +def : CallPattern<(Op $x), + ["atanh", "atanhf", "atanhl", "__nv_atanh", "__nv_atanhf"], + [(FDiv (DiffeRet), (FSub (ConstantFP<"1.0"> $x), (FMul $x, $x)) )] , + (ForwardFromSummedReverse), + [ReadNone, NoUnwind] + >; def : CallPattern<(Op $x), ["exp10"], @@ -428,7 +485,7 @@ def : CallPattern<(Op $x), >; def : CallPattern<(Op $x, $y), - ["fmod", "fmodf", "fmodl"], + ["fmod", "fmodf", "fmodl", "__nv_fmod", "__nv_fmodf", "__nv_fmodl"], [ (DiffeRet), (CheckedMul (DiffeRet), (FNeg (Intrinsic<"copysign"> (Intrinsic<"floor"> (Intrinsic<"fabs"> (FDiv $x, $y):$div)), $div))) @@ -437,6 +494,16 @@ def : CallPattern<(Op $x, $y), [ReadNone, NoUnwind] >; +def : CallPattern<(Op $x, $integral_part_ptr), + ["modf", "modff", "modfl"], + [ + (DiffeRet), + (InactiveArg) + ], + (ForwardFromSummedReverse), + [ReadOnly, NoUnwind] + >; + def : CallPattern<(Op $x), ["__fd_sincos_1", "__fd_sincos_1f", "__fd_sincos_1l"], [ @@ -461,6 +528,30 @@ def : CallPattern<(Op $x), [ReadNone, NoUnwind] >; +def : CallPattern<(Op $x), + ["logabsgamma"], + [ + ( + ArrayRet (FMul (Call<(ArgAsRetTypesFunc<"digamma">), [ReadNone,NoUnwind]> $x), (DiffeRet) ), + (InactiveArg) + ) + ], + (ForwardFromSummedReverse), + [ReadNone, NoUnwind] + >; + +def : CallPattern<(Op $x), + ["logabsgammaf"], + [ + ( + ArrayRet (FMul (Call<(ArgAsRetTypesFunc<"digammaf">), [ReadNone,NoUnwind]> $x), (DiffeRet) ), + (InactiveArg) + ) + ], + (ForwardFromSummedReverse), + [ReadNone, NoUnwind] + >; + def : CallPattern<(Op $x), ["sinpi", "sinpif", "sinpil", "cospi", "cospif", "cospil"], [ @@ -534,7 +625,7 @@ def : CallPattern<(Op $n, $x), >; def : CallPattern<(Op $x), - ["erf"], + ["erf","erff","erfl"], [ (FMul (DiffeRet), (FMul (ConstantFP<"1.1283791670955125738961589031215451716881012586580"> $x), (Intrinsic<"exp"> (FNeg (FMul $x, $x))))) ], @@ -550,7 +641,7 @@ def : CallPattern<(Op $x), [ReadNone, NoUnwind] >; def : CallPattern<(Op $x), - ["erfc"], + ["erfc","erfcf","erfcl"], [ (FMul (DiffeRet), (FMul (ConstantFP<"-1.1283791670955125738961589031215451716881012586580"> $x), (Intrinsic<"exp"> (FNeg (FMul $x, $x))))) ], @@ -564,30 +655,30 @@ def ToStruct2 : SubRoutine<(Op (Op $re, $im):$z), def : CallPattern<(Op $x, $tbd), ["Faddeeva_erf"], [ - (ToStruct2 (CFMul (DiffeRet), (CFMul (ConstantCFP<"1.1283791670955125738961589031215451716881012586580","0"> $x), (CFExp (CFNeg (CFMul $x, $x)))))), + (ToStruct2 (Conj (CFMul (Conj (DiffeRet)), (CFMul (ConstantCFP<"1.1283791670955125738961589031215451716881012586580","0"> $x), (CFExp (CFNeg (CFMul $x, $x))))))), (InactiveArg) // relerr ], - (ForwardFromSummedReverse), + (ToStruct2 (CFMul (Shadow $x), (CFMul (ConstantCFP<"1.1283791670955125738961589031215451716881012586580","0"> $x), (CFExp (CFNeg (CFMul $x, $x)))))), [ReadNone, NoUnwind] >; def : CallPattern<(Op $x, $tbd), ["Faddeeva_erfi"], [ - (ToStruct2 (CFMul (DiffeRet), (CFMul (ConstantCFP<"1.1283791670955125738961589031215451716881012586580","0"> $x), (CFExp (CFMul $x, $x))))), + (ToStruct2 (Conj (CFMul (Conj (DiffeRet)), (CFMul (ConstantCFP<"1.1283791670955125738961589031215451716881012586580","0"> $x), (CFExp (CFMul $x, $x)))))), (InactiveArg) // relerr ], - (ForwardFromSummedReverse), + (ToStruct2 (CFMul (Shadow $x), (CFMul (ConstantCFP<"1.1283791670955125738961589031215451716881012586580","0"> $x), (CFExp (CFMul $x, $x))))), [ReadNone, NoUnwind] >; def : CallPattern<(Op $x, $tbd), ["Faddeeva_erfc"], [ - (ToStruct2 (CFMul (DiffeRet), (CFMul (ConstantCFP<"-1.1283791670955125738961589031215451716881012586580","0"> $x), (CFExp (CFNeg (CFMul $x, $x)))))), + (ToStruct2 (Conj (CFMul (Conj (DiffeRet)), (CFMul (ConstantCFP<"-1.1283791670955125738961589031215451716881012586580","0"> $x), (CFExp (CFNeg (CFMul $x, $x))))))), (InactiveArg) // relerr ], - (ForwardFromSummedReverse), + (ToStruct2 (CFMul (Shadow $x), (CFMul (ConstantCFP<"-1.1283791670955125738961589031215451716881012586580","0"> $x), (CFExp (CFNeg (CFMul $x, $x)))))), [ReadNone, NoUnwind] >; @@ -692,7 +783,7 @@ def : CallPattern<(Op $x, $expout), (DiffeRet), (FMul (BitCast - (And (MantissaMaskOfReturn):$mask, (BitCast $x, (TypeOf $mask)) ), + (And (MantissaMaskOfReturnForFrexp):$mask, (BitCast $x, (TypeOf $mask)) ), (TypeOf $x) ), (ConstantFP<"2"> $x) @@ -740,6 +831,16 @@ def : CallPattern<(Op (Op $x, $y):$z), [ReadNone, NoUnwind] >; +def : CallPattern<(Op (Op $x, $y):$z), + ["cmplx_inv"], + [ + // Reverse mode needs to return the conjugate + (Conj (CFDiv (CFNeg (Conj (DiffeRet))), (CFMul $z, $z))), + ], + (CFDiv (CFNeg (Shadow $z)), (CFMul $z, $z)), + [ReadNone, NoUnwind] + >; + def : IntrPattern<(Op $x), [["sin"]], [(FMul (DiffeRet), (Intrinsic<"cos"> $x))] , @@ -809,6 +910,16 @@ def : IntrPattern<(Op $x, $y), (Select (FCmpOLT $x, $y), (SelectIfActive $y, (Shadow $y), (Zero $y)), (SelectIfActive $x, (Shadow $x), (Zero $x))) >; +def : CallPattern<(Op $x, $y), + ["fdim", "fdimf", "fdiml"], + [ + (Select (FCmpOLT $x, $y), (ConstantFP<"0"> $x), (DiffeRet)), + (Select (FCmpOLT $x, $y), (ConstantFP<"0"> $y), (FNeg (DiffeRet))) + ], + (ForwardFromSummedReverse), + [ReadNone, NoUnwind] + >; + def : IntrPattern<(Op $x), [["fabs"]], [ diff --git a/enzyme/Enzyme/LibraryFuncs.h b/enzyme/Enzyme/LibraryFuncs.h index 5fd3992cfad5..d18ecc346802 100644 --- a/enzyme/Enzyme/LibraryFuncs.h +++ b/enzyme/Enzyme/LibraryFuncs.h @@ -49,6 +49,8 @@ static inline bool isAllocationFunction(const llvm::StringRef name, return true; if (name == "calloc" || name == "malloc") return true; + if (name == "_mlir_memref_to_llvm_alloc") + return true; if (name == "swift_allocObject") return true; if (name == "__rust_alloc" || name == "__rust_alloc_zeroed") @@ -123,6 +125,8 @@ static inline bool isDeallocationFunction(const llvm::StringRef name, if (!TLI.getLibFunc(name, libfunc)) { if (name == "free") return true; + if (name == "_mlir_memref_to_llvm_free") + return true; if (name == "__rust_dealloc") return true; if (name == "swift_release") @@ -209,11 +213,7 @@ static inline void zeroKnownAllocation(llvm::IRBuilder<> &bb, } if (funcName == "enzyme_allocator") { auto index = getAllocationIndexFromCall(orig); -#if LLVM_VERSION_MAJOR >= 16 - allocSize = argValues[index.value()]; -#else - allocSize = argValues[index.getValue()]; -#endif + allocSize = argValues[*index]; } Value *dst_arg = toZero; diff --git a/enzyme/Enzyme/MLIR/Analysis/ActivityAnalysis.cpp b/enzyme/Enzyme/MLIR/Analysis/ActivityAnalysis.cpp index c1a471277915..1415559b3f6b 100644 --- a/enzyme/Enzyme/MLIR/Analysis/ActivityAnalysis.cpp +++ b/enzyme/Enzyme/MLIR/Analysis/ActivityAnalysis.cpp @@ -3,7 +3,6 @@ #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/LLVMTypes.h" -#include "mlir/Dialect/LLVMIR/NVVMDialect.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/Matchers.h" @@ -19,7 +18,9 @@ #include "llvm/IR/Intrinsics.h" #include "llvm/Support/ModRef.h" -const char *KnownInactiveFunctionsStartingWith[] = { +#include "Interfaces/AutoDiffOpInterface.h" + +static const char *KnownInactiveFunctionsStartingWith[] = { "f90io", "$ss5print", "_ZTv0_n24_NSoD", //"1Ev, 0Ev @@ -28,11 +29,11 @@ const char *KnownInactiveFunctionsStartingWith[] = { "_ZNSaIcEC1Ev", }; -const char *KnownInactiveFunctionsContains[] = { +static const char *KnownInactiveFunctionsContains[] = { "__enzyme_float", "__enzyme_double", "__enzyme_integer", "__enzyme_pointer"}; -const std::set InactiveGlobals = { +static const std::set InactiveGlobals = { "ompi_request_null", "ompi_mpi_double", "ompi_mpi_comm_world", "stderr", "stdout", "stdin", "_ZSt3cin", "_ZSt4cout", "_ZSt5wcout", "_ZSt4cerr", "_ZTVNSt7__cxx1115basic_stringbufIcSt11char_traitsIcESaIcEEE", @@ -56,7 +57,7 @@ const std::set InactiveGlobals = { "_ZTVN10__cxxabiv117__class_type_infoE", "_ZTVN10__cxxabiv121__vmi_class_type_infoE"}; -const std::map MPIInactiveCommAllocators = { +static const std::map MPIInactiveCommAllocators = { {"MPI_Graph_create", 5}, {"MPI_Comm_split", 2}, {"MPI_Intercomm_create", 6}, @@ -74,7 +75,7 @@ const std::map MPIInactiveCommAllocators = { // Instructions which themselves are inactive // the returned value, however, may still be active -const std::set KnownInactiveFunctionInsts = { +static const std::set KnownInactiveFunctionInsts = { "__dynamic_cast", "_ZSt18_Rb_tree_decrementPKSt18_Rb_tree_node_base", "_ZSt18_Rb_tree_incrementPKSt18_Rb_tree_node_base", @@ -83,7 +84,7 @@ const std::set KnownInactiveFunctionInsts = { "jl_ptr_to_array", "jl_ptr_to_array_1d"}; -const std::set KnownInactiveFunctions = { +static const std::set KnownInactiveFunctions = { "abort", "time", "memcmp", @@ -164,7 +165,7 @@ const std::set KnownInactiveFunctions = { "logbl", }; -const char *DemangledKnownInactiveFunctionsStartingWith[] = { +static const char *DemangledKnownInactiveFunctionsStartingWith[] = { // TODO this returns allocated memory and thus can be an active value // "std::allocator", "std::string", @@ -247,6 +248,8 @@ static Operation *getFunctionFromCall(CallOpInterface iface) { return SymbolTable::lookupNearestSymbolFrom(iface.getOperation(), symbol); } +constexpr bool EnzymePrintActivity = false; + /// Is the use of value val as an argument of call CI known to be inactive /// This tool can only be used when in DOWN mode bool mlir::enzyme::ActivityAnalyzer::isFunctionArgumentConstant( @@ -464,12 +467,17 @@ bool mlir::enzyme::ActivityAnalyzer::isConstantOperation(MTypeResults const &TR, // The return instruction doesn't impact activity (handled specifically // during adjoint generation) - if (isa(I)) + if (I->hasTrait()) return true; + if (auto ifaceOp = dyn_cast(I)) { + if (ifaceOp.isInactive()) { + return true; + } + } + // Branch, unreachable, and previously computed constants are inactive - if (isa(I) /*|| isa(I)*/ || - ConstantOperations.contains(I)) { + if (/*|| isa(I)*/ ConstantOperations.contains(I)) { return true; } @@ -479,9 +487,9 @@ bool mlir::enzyme::ActivityAnalyzer::isConstantOperation(MTypeResults const &TR, } if (notForAnalysis.count(I->getBlock())) { - // if (EnzymePrintActivity) - // llvm::errs() << " constant instruction as dominates unreachable " << *I - // << "\n"; + if (EnzymePrintActivity) + llvm::errs() << " constant instruction as dominates unreachable " << *I + << "\n"; InsertConstantOperation(TR, I); return true; } @@ -489,14 +497,14 @@ bool mlir::enzyme::ActivityAnalyzer::isConstantOperation(MTypeResults const &TR, if (auto CI = dyn_cast(I)) { // TODO(PR #904): This needs to be put into the enzyme dialect if (CI->hasAttr("enzyme_active")) { - // if (EnzymePrintActivity) - // llvm::errs() << "forced active " << *I << "\n"; + if (EnzymePrintActivity) + llvm::errs() << "forced active " << *I << "\n"; ActiveOperations.insert(I); return false; } if (CI->hasAttr("enzyme_inactive")) { - // if (EnzymePrintActivity) - // llvm::errs() << "forced inactive " << *I << "\n"; + if (EnzymePrintActivity) + llvm::errs() << "forced inactive " << *I << "\n"; InsertConstantOperation(TR, I); return true; } @@ -504,14 +512,14 @@ bool mlir::enzyme::ActivityAnalyzer::isConstantOperation(MTypeResults const &TR, if (called) { if (called->hasAttr("enzyme_active")) { - // if (EnzymePrintActivity) - // llvm::errs() << "forced active " << *I << "\n"; + if (EnzymePrintActivity) + llvm::errs() << "forced active " << *I << "\n"; ActiveOperations.insert(I); return false; } if (called->hasAttr("enzyme_inactive")) { - // if (EnzymePrintActivity) - // llvm::errs() << "forced inactive " << *I << "\n"; + if (EnzymePrintActivity) + llvm::errs() << "forced inactive " << *I << "\n"; InsertConstantOperation(TR, I); return true; } @@ -592,12 +600,6 @@ bool mlir::enzyme::ActivityAnalyzer::isConstantOperation(MTypeResults const &TR, // *I // << "\n"; - if (isa(I)) { - InsertConstantOperation(TR, I); - } - // if (auto II = dyn_cast(I)) { // switch (II->getIntrinsicID()) { // case Intrinsic::nvvm_barrier0: @@ -683,11 +685,10 @@ bool mlir::enzyme::ActivityAnalyzer::isConstantOperation(MTypeResults const &TR, // If all returned values constant otherwise, the operation is inactive if (llvm::all_of(I->getResults(), [&](Value v) { return isConstantValue(TR, v); })) { - // if (EnzymePrintActivity) - // llvm::errs() << " constant instruction from known constant - // non-writing " - // "instruction " - // << *I << "\n"; + if (EnzymePrintActivity) + llvm::errs() << " constant instruction from known constant non-writing " + "instruction " + << *I << "\n"; InsertConstantOperation(TR, I); return true; } @@ -710,9 +711,9 @@ bool mlir::enzyme::ActivityAnalyzer::isConstantOperation(MTypeResults const &TR, if (llvm::all_of(I->getResults(), [&](Value val) { return isValueInactiveFromUsers(TR, val, UseActivity::None); })) { - // if (EnzymePrintActivity) - // llvm::errs() << " constant instruction[" << (int)directions - // << "] from users instruction " << *I << "\n"; + if (EnzymePrintActivity) + llvm::errs() << " constant instruction[" << (int)directions + << "] from users instruction " << *I << "\n"; InsertConstantOperation(TR, I); return true; } @@ -724,9 +725,9 @@ bool mlir::enzyme::ActivityAnalyzer::isConstantOperation(MTypeResults const &TR, return DownHypothesis->isValueInactiveFromUsers( TR, val, UseActivity::None); })) { - // if (EnzymePrintActivity) - // llvm::errs() << " constant instruction[" << (int)directions - // << "] from users instruction " << *I << "\n"; + if (EnzymePrintActivity) + llvm::errs() << " constant instruction[" << (int)directions + << "] from users instruction " << *I << "\n"; InsertConstantOperation(TR, I); insertConstantsFrom(TR, *DownHypothesis); return true; @@ -748,65 +749,42 @@ bool mlir::enzyme::ActivityAnalyzer::isConstantOperation(MTypeResults const &TR, new mlir::enzyme::ActivityAnalyzer(*this, UP)); UpHypothesis->ConstantOperations.insert(I); assert(directions & UP); - if (UpHypothesis->isOperationInactiveFromOrigin(TR, I)) { - // if (EnzymePrintActivity) - // llvm::errs() << " constant instruction from origin " - // "instruction " - // << *I << "\n"; + SmallPtrSet toredo; + if (UpHypothesis->isOperationInactiveFromOrigin(TR, I, std::nullopt, + &toredo)) { + if (EnzymePrintActivity) + llvm::errs() << " constant instruction from origin " + "instruction " + << *I << "\n"; InsertConstantOperation(TR, I); insertConstantsFrom(TR, *UpHypothesis); if (DownHypothesis) insertConstantsFrom(TR, *DownHypothesis); return true; } else if (directions == (UP | DOWN)) { - // TODO: what does this mean for interfaces? - if (isa< - // clang-format off - LLVM::LoadOp, - LLVM::StoreOp, - // Integer binary ops. - LLVM::AddOp, - LLVM::SubOp, - LLVM::MulOp, - LLVM::UDivOp, - LLVM::SDivOp, - LLVM::URemOp, - LLVM::SRemOp, - LLVM::AndOp, - LLVM::OrOp, - LLVM::XOrOp, - LLVM::ShlOp, - LLVM::LShrOp, - LLVM::AShrOp, - // Float binary ops. - LLVM::FAddOp, - LLVM::FSubOp, - LLVM::FMulOp, - LLVM::FDivOp, - LLVM::FRemOp, - LLVM::FNegOp - // clang-format on - >(I)) { - for (Value operand : I->getOperands()) { - if (!UpHypothesis->isConstantValue(TR, operand)) { - ReEvaluateOpIfInactiveValue[operand].insert(I); - } - } + for (Value operand : toredo) { + ReEvaluateOpIfInactiveValue[operand].insert(I); } } } // Otherwise we must fall back and assume this instruction to be active. ActiveOperations.insert(I); - // if (EnzymePrintActivity) - // llvm::errs() << "couldnt decide fallback as nonconstant instruction(" - // << (int)directions << "):" << *I << "\n"; + if (EnzymePrintActivity) + llvm::errs() << "couldnt decide fallback as nonconstant instruction(" + << (int)directions << "):" << *I << "\n"; if (noActiveWrite && (directions == (UP | DOWN))) for (Value result : I->getResults()) ReEvaluateOpIfInactiveValue[result].insert(I); return false; } +static bool isFunctionReturn(Operation *op) { + if (!op->hasTrait()) + return false; + return dyn_cast(op->getParentOp()); +} + static bool isValuePotentiallyUsedAsPointer(Value val) { std::deque todo = {val}; SmallPtrSet seen; @@ -817,7 +795,39 @@ static bool isValuePotentiallyUsedAsPointer(Value val) { continue; seen.insert(cur); for (Operation *user : cur.getUsers()) { - if (isa(user)) + if (isa(user->getParentOp())) + if (auto termIface = + dyn_cast(user)) { + SmallVector successors; + termIface.getSuccessorRegions( + SmallVector(termIface->getNumOperands(), Attribute()), + successors); + + auto parentOp = termIface->getParentOp(); + for (auto &successor : successors) { + OperandRange operandRange = + termIface.getSuccessorOperands(successor); + ValueRange targetValues = successor.isParent() + ? parentOp->getResults() + : successor.getSuccessorInputs(); + assert(operandRange.size() == targetValues.size()); + for (auto &&[prev, post] : llvm::zip(operandRange, targetValues)) { + if (prev == cur) { + todo.push_back(post); + } + } + } + continue; + } + if (auto iface = dyn_cast(user)) { + for (auto &op : user->getOpOperands()) + if (op.get() == cur) + if (auto blk = + iface.getSuccessorBlockArgument(op.getOperandNumber())) + todo.push_back(*blk); + continue; + } + if (isFunctionReturn(user)) return true; // The operation is known not to read or write memory. if (isa(user) && @@ -828,10 +838,9 @@ static bool isValuePotentiallyUsedAsPointer(Value val) { } continue; } - // if (EnzymePrintActivity) - // llvm::errs() << " VALUE potentially used as pointer " << *val << " by - // " - // << *u << "\n"; + if (EnzymePrintActivity) + llvm::errs() << " VALUE potentially used as pointer " << val << " by " + << *user << "\n"; return true; } } @@ -889,7 +898,134 @@ static FunctionOpInterface getFunctionIfArgument(Value value) { return dyn_cast(block->getParentOp()); } -// TODO: move the extraction based on dataflow here. +// For a given instruction, determine whether it is a terminator which +// controls dataflow out, and if so return all users either in results +// or blockarguments +static std::optional> +getPotentialTerminatorUsers(Operation *op, Value parent) { + auto block = op->getBlock(); + + if (block->getTerminator() != op) + return {}; + if (isFunctionReturn(op)) + return {}; + + SmallVector results; + + if (isa(op->getParentOp())) + if (auto termIface = dyn_cast(op)) { + SmallVector successors; + termIface.getSuccessorRegions( + SmallVector(termIface->getNumOperands(), Attribute()), + successors); + + auto parentOp = termIface->getParentOp(); + SmallVector results; + for (auto &successor : successors) { + OperandRange operandRange = termIface.getSuccessorOperands(successor); + ValueRange targetValues = successor.isParent() + ? parentOp->getResults() + : successor.getSuccessorInputs(); + assert(operandRange.size() == targetValues.size()); + for (auto &&[prev, post] : llvm::zip(operandRange, targetValues)) { + if (prev == parent) { + results.push_back(post); + } + } + } + return std::move(results); + } + if (auto iface = dyn_cast(op)) { + for (auto &operand : op->getOpOperands()) + if (operand.get() == parent) + if (auto blk = + iface.getSuccessorBlockArgument(operand.getOperandNumber())) { + results.push_back(*blk); + return std::move(results); + } + } + + // assume all terminator operands potentially flow into all op results + for (auto res : op->getParentOp()->getResults()) + results.push_back(res); + + // assume all terminator operands potentially flow into all blockArgs in + // region + for (auto &blk : *block->getParent()) + for (auto arg : blk.getArguments()) + results.push_back(arg); + + // assume all terminator operands potentially flow into all other region + // entries + for (auto ® : op->getParentOp()->getRegions()) + for (auto arg : reg.front().getArguments()) + results.push_back(arg); + + return std::move(results); +} + +// For a result of an op, find all values which could flow into this result +static SmallVector getPotentialIncomingValues(OpResult res) { + Operation *owner = res.getOwner(); + SmallVector potentialSources; + + auto resultNo = res.getResultNumber(); + + if (auto iface = dyn_cast(owner)) { + SmallVector successors; + iface.getSuccessorRegions(RegionBranchPoint::parent(), successors); + for (auto &succ : successors) { + if (!succ.isParent()) + continue; + auto successorOperands = + llvm::to_vector(iface.getEntrySuccessorOperands(succ)); + + if (successorOperands.size() != owner->getNumResults()) { + llvm::errs() << *owner << "\n"; + } + assert(successorOperands.size() == owner->getNumResults() && + "expected all results to be populated with incoming operands"); + + potentialSources.push_back(successorOperands[resultNo]); + } + } else { + // assume all inputs potentially flow into all op results + for (auto operand : owner->getOperands()) { + potentialSources.push_back(operand); + } + } + + for (Region ®ion : owner->getRegions()) { + for (Block &block : region) { + // TODO: MLIR blocks without terminator? + if (auto iface = dyn_cast( + block.getTerminator())) { + // TODO: the interface may also tell us which regions are allowed to + // yield parent op results, and which only branch to other regions. + auto successorOperands = llvm::to_vector( + iface.getSuccessorOperands(RegionBranchPoint::parent())); + // TODO: understand/document the assumption of how operands flow. + + if (successorOperands.size() != owner->getNumResults()) { + llvm::errs() << *owner << "\n"; + } + assert(successorOperands.size() == owner->getNumResults() && + "expected all results to be populated with yielded " + "terminator operands"); + potentialSources.push_back(successorOperands[resultNo]); + } else { + // assume all terminator operands potentially flow into op results + for (Value v : block.getTerminator()->getOperands()) + potentialSources.push_back(v); + } + } + } + + return potentialSources; +} + +// For a blockargument, find all non-operand values which could flow into +// this result static SmallVector getPotentialIncomingValues(BlockArgument arg) { SetVector potentialSources; @@ -986,6 +1122,10 @@ static SmallVector getPotentialIncomingValues(BlockArgument arg) { potentialSources.insert(v); } } + + // and also any operand to the parent + for (auto op : parent->getOperands()) + potentialSources.insert(op); } return potentialSources.takeVector(); @@ -1121,13 +1261,6 @@ bool mlir::enzyme::ActivityAnalyzer::isConstantValue(MTypeResults const &TR, if (Val.getType().isa()) return true; - // All function pointers are considered active in case an augmented primal - // or reverse is needed - if (Val.getDefiningOp() && - isa(Val.getDefiningOp())) { - return false; - } - /// If we've already shown this value to be inactive if (ConstantValues.find(Val) != ConstantValues.end()) { return true; @@ -1142,50 +1275,25 @@ bool mlir::enzyme::ActivityAnalyzer::isConstantValue(MTypeResults const &TR, if (matchPattern(Val, m_Constant())) return true; - // if (auto CD = dyn_cast(Val)) { - // // inductively assume inactive - // ConstantValues.insert(CD); - // for (size_t i = 0, len = CD->getNumElements(); i < len; i++) { - // if (!isConstantValue(TR, CD->getElementAsConstant(i))) { - // ConstantValues.erase(CD); - // ActiveValues.insert(CD); - // return false; - // } - // } - // return true; - // } - // if (auto CD = dyn_cast(Val)) { - // // inductively assume inactive - // ConstantValues.insert(CD); - // for (size_t i = 0, len = CD->getNumOperands(); i < len; i++) { - // if (!isConstantValue(TR, CD->getOperand(i))) { - // ConstantValues.erase(CD); - // ActiveValues.insert(CD); - // return false; - // } - // } - // return true; - // } - if (Operation *definingOp = Val.getDefiningOp()) { - // Undef and non-global constants are inactive. - if (isa(definingOp)) { - return true; - } - - // Ops derived from intrinsics. - // NOTE: this was written with the assumption that Value is-a Operation, - // which is not the case in MLIR. - if (isa(definingOp)) { - return true; + if (auto ifaceOp = dyn_cast(definingOp)) { + if (ifaceOp.isInactive()) { + return true; + } } } if (auto arg = Val.dyn_cast()) { - auto funcIface = dyn_cast_or_null( - arg.getParentBlock()->getParentOp()); + // All arguments must be marked constant/nonconstant ahead of time + if (auto funcIface = dyn_cast_or_null( + arg.getParentBlock()->getParentOp())) + if (funcIface && arg.getOwner()->isEntryBlock() && + !funcIface.getArgAttr(arg.getArgNumber(), + LLVM::LLVMDialect::getByValAttrName())) { + llvm::errs() << funcIface << "\n"; + llvm::errs() << Val << "\n"; + assert(0 && "must've put arguments in constant/nonconstant"); + } // if (!funcIface || !arg.getOwner()->isEntryBlock()) { // TODO: we want a more advanced analysis based on MLIR interfaces here // For now, conservatively assume all block arguments are active @@ -1214,25 +1322,15 @@ bool mlir::enzyme::ActivityAnalyzer::isConstantValue(MTypeResults const &TR, // } // } // } - - // All arguments must be marked constant/nonconstant ahead of time - if (funcIface && arg.getOwner()->isEntryBlock() && - !funcIface.getArgAttr(arg.getArgNumber(), - LLVM::LLVMDialect::getByValAttrName())) { - llvm::errs() << funcIface << "\n"; - llvm::errs() << Val << "\n"; - assert(0 && "must've put arguments in constant/nonconstant"); - } } // This value is certainly an integer (and only and integer, not a pointer or // float). Therefore its value is constant if (TR.intType(1, Val, /*errIfNotFound*/ false).isIntegral()) { - // if (EnzymePrintActivity) - // llvm::errs() << " Value const as integral " << (int)directions << " " - // << *Val << " " - // << TR.intType(1, Val, /*errIfNotFound*/ false).str() << - // "\n"; + if (EnzymePrintActivity) + llvm::errs() << " Value const as integral " << (int)directions << " " + << Val << " " + << TR.intType(1, Val, /*errIfNotFound*/ false).str() << "\n"; InsertConstantValue(TR, Val); return true; } @@ -1408,28 +1506,28 @@ bool mlir::enzyme::ActivityAnalyzer::isConstantValue(MTypeResults const &TR, if (auto CI = Val.getDefiningOp()) { if (CI->hasAttr("enzyme_active")) { - // if (EnzymePrintActivity) - // llvm::errs() << "forced active val " << *Val << "\n"; + if (EnzymePrintActivity) + llvm::errs() << "forced active val " << Val << "\n"; ActiveValues.insert(Val); return false; } if (CI->hasAttr("enzyme_inactive")) { - // if (EnzymePrintActivity) - // llvm::errs() << "forced inactive val " << *Val << "\n"; + if (EnzymePrintActivity) + llvm::errs() << "forced inactive val " << Val << "\n"; InsertConstantValue(TR, Val); return true; } Operation *called = getFunctionFromCall(CI); if (called) { if (called->hasAttr("enzyme_active")) { - // if (EnzymePrintActivity) - // llvm::errs() << "forced active val " << *Val << "\n"; + if (EnzymePrintActivity) + llvm::errs() << "forced active val " << Val << "\n"; ActiveValues.insert(Val); return false; } if (called->hasAttr("enzyme_inactive")) { - // if (EnzymePrintActivity) - // llvm::errs() << "forced inactive val " << *Val << "\n"; + if (EnzymePrintActivity) + llvm::errs() << "forced inactive val " << Val << "\n"; InsertConstantValue(TR, Val); return true; } @@ -1480,20 +1578,31 @@ bool mlir::enzyme::ActivityAnalyzer::isConstantValue(MTypeResults const &TR, LLVM::LLVMDialect::getByValAttrName())) { bool res = isConstantValue(TR, TmpOrig); if (res) { - // if (EnzymePrintActivity) - // llvm::errs() << " arg const from orig val=" << *Val - // << " orig=" << *TmpOrig << "\n"; + if (EnzymePrintActivity) + llvm::errs() << " arg const from orig val=" << Val + << " orig=" << TmpOrig << "\n"; InsertConstantValue(TR, Val); } else { - // if (EnzymePrintActivity) - // llvm::errs() << " arg active from orig val=" << *Val - // << " orig=" << *TmpOrig << "\n"; + if (EnzymePrintActivity) + llvm::errs() << " arg active from orig val=" << Val + << " orig=" << TmpOrig << "\n"; ActiveValues.insert(Val); } return res; } } + if (auto op = TmpOrig.getDefiningOp()) + if (auto ifaceOp = dyn_cast(op)) { + if (ifaceOp.isInactive()) { + InsertConstantValue(TR, Val); + if (TmpOrig != Val) { + InsertConstantValue(TR, TmpOrig); + } + return true; + } + } + UpHypothesis = std::shared_ptr( new mlir::enzyme::ActivityAnalyzer(*this, UP)); UpHypothesis->ConstantValues.insert(Val); @@ -1751,10 +1860,10 @@ bool mlir::enzyme::ActivityAnalyzer::isConstantValue(MTypeResults const &TR, // argument if (TmpOrig != Val) { if (isConstantValue(TR, TmpOrig)) { - // if (EnzymePrintActivity) - // llvm::errs() << " Potential Pointer(" << (int)directions << ") " - // << *Val << " inactive from inactive origin " - // << *TmpOrig << "\n"; + if (EnzymePrintActivity) + llvm::errs() << " Potential Pointer(" << (int)directions << ") " + << Val << " inactive from inactive origin " << TmpOrig + << "\n"; InsertConstantValue(TR, Val); return true; } @@ -1766,11 +1875,17 @@ bool mlir::enzyme::ActivityAnalyzer::isConstantValue(MTypeResults const &TR, if (!op || (!mayReadFromMemory(op) && !mayAllocateMemory(op))) { if (directions == UP && !Val.isa()) { if (isValueInactiveFromOrigin(TR, Val)) { + if (EnzymePrintActivity) + llvm::errs() << " Non-function value inactive from origin(" + << (int)directions << ") " << Val << "\n"; InsertConstantValue(TR, Val); return true; } } else { if (UpHypothesis->isValueInactiveFromOrigin(TR, Val)) { + if (EnzymePrintActivity) + llvm::errs() << " Non-function value_v2 inactive from origin(" + << (int)directions << ") " << Val << "\n"; InsertConstantValue(TR, Val); insertConstantsFrom(TR, *UpHypothesis); return true; @@ -1784,9 +1899,9 @@ bool mlir::enzyme::ActivityAnalyzer::isConstantValue(MTypeResults const &TR, // can be loaded/stored cannot be assesed and therefore we default to assume // it to be active if (directions != (UP | DOWN)) { - // if (EnzymePrintActivity) - // llvm::errs() << " " << *Val << "\n"; + if (EnzymePrintActivity) + llvm::errs() << " " << Val << "\n"; ActiveValues.insert(Val); return false; } @@ -1814,13 +1929,12 @@ bool mlir::enzyme::ActivityAnalyzer::isConstantValue(MTypeResults const &TR, } if (UpHypothesis->isValueInactiveFromOrigin(TR, Val)) { Hypothesis->DeducingPointers.insert(Val); - // if (EnzymePrintActivity) - // llvm::errs() << " constant instruction hypothesis: " << *VI << - // "\n"; + if (EnzymePrintActivity) + llvm::errs() << " constant instruction hypothesis: " << Val << "\n"; } else { - // if (EnzymePrintActivity) - // llvm::errs() << " cannot show constant instruction hypothesis: " - // << *VI << "\n"; + if (EnzymePrintActivity) + llvm::errs() << " cannot show constant instruction hypothesis: " + << Val << "\n"; } } @@ -1828,16 +1942,12 @@ bool mlir::enzyme::ActivityAnalyzer::isConstantValue(MTypeResults const &TR, if (notForAnalysis.count(op->getBlock())) return false; - if (auto iasm = dyn_cast(op)) { - if (iasm.getAsmString().contains("exit") || - iasm.getAsmString().contains("cpuid")) - return false; - } - if (isa(op)) { - return true; - } + if (auto op = TmpOrig.getDefiningOp()) + if (auto ifaceOp = dyn_cast(op)) { + if (ifaceOp.isInactive()) { + return false; + } + } // If this is a malloc or free, this doesn't impact the activity if (auto CI = dyn_cast(op)) { @@ -1953,8 +2063,8 @@ bool mlir::enzyme::ActivityAnalyzer::isConstantValue(MTypeResults const &TR, // If we haven't already shown a potentially active load // check if this loads the given value and is active if (!potentiallyActiveLoad && isRefSet(modRef)) { - // if (EnzymePrintActivity) - // llvm::errs() << "potential active load: " << *I << "\n"; + if (EnzymePrintActivity) + llvm::errs() << "potential active load: " << *op << "\n"; if (isa(op)) { // TODO: this assumption should be built into the MLIR interface // verifier, or alternatively we should relax it. @@ -1978,12 +2088,11 @@ bool mlir::enzyme::ActivityAnalyzer::isConstantValue(MTypeResults const &TR, for (Operation *user : V.getUsers()) { if (mayWriteToMemory(user)) { if (!Hypothesis->isConstantOperation(TR, user)) { - // if (EnzymePrintActivity) - // llvm::errs() - // << "potential active store via " - // "pointer in load: " - // << *I << " of " << *Val << " via " << *U << - // "\n"; + if (EnzymePrintActivity) + llvm::errs() << "potential active store via " + "pointer in load: " + << *op << " of " << Val << " via " + << *user << "\n"; potentiallyActiveStore = true; return true; } @@ -2064,10 +2173,10 @@ bool mlir::enzyme::ActivityAnalyzer::isConstantValue(MTypeResults const &TR, !Hypothesis->isConstantValue(TR, V) && TR.query(V)[{-1}].isPossiblePointer(); })) { - // if (EnzymePrintActivity) - // llvm::errs() << "potential active store via pointer in " - // "unknown inst: " - // << *I << " of " << *Val << "\n"; + if (EnzymePrintActivity) + llvm::errs() << "potential active store via pointer in " + "unknown inst: " + << *op << " of " << Val << "\n"; potentiallyActiveStore = true; } } @@ -2075,25 +2184,25 @@ bool mlir::enzyme::ActivityAnalyzer::isConstantValue(MTypeResults const &TR, } } if ((!potentiallyActiveStore || !potentialStore) && isModSet(modRef)) { - // if (EnzymePrintActivity) - // llvm::errs() << "potential active store: " << *I << " Val=" << *Val - // << "\n"; + if (EnzymePrintActivity) + llvm::errs() << "potential active store: " << *op << " Val=" << Val + << "\n"; if (auto SI = dyn_cast(op)) { bool cop = !Hypothesis->isConstantValue(TR, SI.getValue()); - // if (EnzymePrintActivity) - // llvm::errs() << " -- store potential activity: " << (int)cop - // << " - " << *SI << " of " - // << " Val=" << *Val << "\n"; + if (EnzymePrintActivity) + llvm::errs() << " -- store potential activity: " << (int)cop + << " - " << *SI << " of " + << " Val=" << Val << "\n"; potentialStore = true; if (cop) potentiallyActiveStore = true; } else if (auto SI = dyn_cast(op)) { // FIXME: this is a copy-pasta form above to work with MLIR memrefs. bool cop = !Hypothesis->isConstantValue(TR, SI.getValueToStore()); - // if (EnzymePrintActivity) - // llvm::errs() << " -- store potential activity: " << (int)cop - // << " - " << *SI << " of " - // << " Val=" << *Val << "\n"; + if (EnzymePrintActivity) + llvm::errs() << " -- store potential activity: " << (int)cop + << " - " << *SI << " of " + << " Val=" << Val << "\n"; potentialStore = true; if (cop) potentiallyActiveStore = true; @@ -2110,11 +2219,10 @@ bool mlir::enzyme::ActivityAnalyzer::isConstantValue(MTypeResults const &TR, // TODO: note that this can be optimized (especially for function // calls) auto cop = !Hypothesis->isConstantOperation(TR, op); - // if (EnzymePrintActivity) - // llvm::errs() << " -- unknown store potential activity: " << - // (int)cop - // << " - " << *I << " of " - // << " Val=" << *Val << "\n"; + if (EnzymePrintActivity) + llvm::errs() << " -- unknown store potential activity: " << (int)cop + << " - " << *op << " of " + << " Val=" << Val << "\n"; potentialStore = true; if (cop) potentiallyActiveStore = true; @@ -2179,11 +2287,11 @@ bool mlir::enzyme::ActivityAnalyzer::isConstantValue(MTypeResults const &TR, } activeLoadAndStore:; - // if (EnzymePrintActivity) - // llvm::errs() << " " << *Val - // << " potentiallyActiveLoad=" << potentiallyActiveLoad - // << " potentiallyActiveStore=" << potentiallyActiveStore - // << " potentialStore=" << potentialStore << "\n"; + if (EnzymePrintActivity) + llvm::errs() << " " << Val + << " potentiallyActiveLoad=" << potentiallyActiveLoad + << " potentiallyActiveStore=" << potentiallyActiveStore + << " potentialStore=" << potentialStore << "\n"; if (potentiallyActiveLoad && potentiallyActiveStore) { insertAllFrom(TR, *Hypothesis, Val); // TODO have insertall dependence on this @@ -2245,15 +2353,14 @@ bool mlir::enzyme::ActivityAnalyzer::isConstantValue(MTypeResults const &TR, new mlir::enzyme::ActivityAnalyzer(*DownHypothesis, DOWN)); DownHypothesis2->ConstantValues.insert(TmpOrig); if (DownHypothesis2->isValueActivelyStoredOrReturned(TR, TmpOrig)) { - // if (EnzymePrintActivity) - // llvm::errs() << " active from ivasor: " << *TmpOrig << "\n"; + if (EnzymePrintActivity) + llvm::errs() << " active from ivasor: " << TmpOrig << "\n"; ActiveDown = true; } } else { // unknown origin that could've been stored/returned/etc - // if (EnzymePrintActivity) - // llvm::errs() << " active from unknown origin: " << *TmpOrig << - // "\n"; + if (EnzymePrintActivity) + llvm::errs() << " active from unknown origin: " << TmpOrig << "\n"; ActiveDown = true; } } @@ -2290,13 +2397,12 @@ bool mlir::enzyme::ActivityAnalyzer::isConstantValue(MTypeResults const &TR, // If we go to an active return and only load it, however, that doesnt // transfer derivatives and we can say this memory is inactive - // if (EnzymePrintActivity) - // llvm::errs() << " @@MEMSEARCH" << (int)directions << ">" << *Val - // << " potentiallyActiveLoad=" << potentiallyActiveLoad - // << " potentialStore=" << potentialStore - // << " ActiveUp=" << ActiveUp << " ActiveDown=" << - // ActiveDown - // << " ActiveMemory=" << ActiveMemory << "\n"; + if (EnzymePrintActivity) + llvm::errs() << " @@MEMSEARCH" << (int)directions << ">" << Val + << " potentiallyActiveLoad=" << potentiallyActiveLoad + << " potentialStore=" << potentialStore + << " ActiveUp=" << ActiveUp << " ActiveDown=" << ActiveDown + << " ActiveMemory=" << ActiveMemory << "\n"; if (ActiveMemory) { ActiveValues.insert(Val); @@ -2326,39 +2432,21 @@ bool mlir::enzyme::ActivityAnalyzer::isConstantValue(MTypeResults const &TR, // this value is inactive, we are inactive Since we won't look at uses to // prove, we can inductively assume this is inactive if (directions & UP) { - if (directions == UP && !Val.isa()) { - if (isValueInactiveFromOrigin(TR, Val)) { - InsertConstantValue(TR, Val); - return true; - } else if (Operation *op = Val.getDefiningOp()) { - if (directions == (UP | DOWN)) { - for (Value operand : op->getOperands()) { - if (!UpHypothesis->isConstantValue(TR, operand)) { - for (Value result : op->getResults()) { - ReEvaluateValueIfInactiveValue[operand].insert(result); - } - } - } - } - } + UpHypothesis = std::shared_ptr( + new mlir::enzyme::ActivityAnalyzer(*this, UP)); + UpHypothesis->ConstantValues.insert(Val); + SmallPtrSet toredo; + if (UpHypothesis->isValueInactiveFromOrigin(TR, Val, &toredo)) { + insertConstantsFrom(TR, *UpHypothesis); + InsertConstantValue(TR, Val); + if (EnzymePrintActivity) + llvm::errs() << " Value constant from origin [" << (int)directions + << "]" << Val << "\n"; + return true; } else { - UpHypothesis = std::shared_ptr( - new mlir::enzyme::ActivityAnalyzer(*this, UP)); - UpHypothesis->ConstantValues.insert(Val); - if (UpHypothesis->isValueInactiveFromOrigin(TR, Val)) { - insertConstantsFrom(TR, *UpHypothesis); - InsertConstantValue(TR, Val); - return true; - } else if (Operation *op = Val.getDefiningOp()) { - if (directions == (UP | DOWN)) { - for (Value operand : op->getOperands()) { - if (!UpHypothesis->isConstantValue(TR, operand)) { - for (Value result : op->getResults()) { - ReEvaluateValueIfInactiveValue[operand].insert(result); - } - } - } - } + for (Value result : toredo) { + if (result != Val) + ReEvaluateValueIfInactiveValue[result].insert(Val); } } } @@ -2368,164 +2456,54 @@ bool mlir::enzyme::ActivityAnalyzer::isConstantValue(MTypeResults const &TR, // If all users are inactive, this is therefore inactive. // Since we won't look at origins to prove, we can inductively assume this // is inactive - - // As an optimization if we are going down already - // and we won't use ourselves (done by PHI's), we - // dont need to inductively assume we're true - // and can instead use this object! - if (directions == DOWN && !Val.isa()) { - if (isValueInactiveFromUsers(TR, Val, UseActivity::None)) { - if (UpHypothesis) - insertConstantsFrom(TR, *UpHypothesis); - InsertConstantValue(TR, Val); - return true; - } - } else { - auto DownHypothesis = std::shared_ptr( - new mlir::enzyme::ActivityAnalyzer(*this, DOWN)); - DownHypothesis->ConstantValues.insert(Val); - if (DownHypothesis->isValueInactiveFromUsers(TR, Val, - UseActivity::None)) { - insertConstantsFrom(TR, *DownHypothesis); - if (UpHypothesis) - insertConstantsFrom(TR, *UpHypothesis); - InsertConstantValue(TR, Val); - return true; - } + auto DownHypothesis = std::shared_ptr( + new mlir::enzyme::ActivityAnalyzer(*this, DOWN)); + DownHypothesis->ConstantValues.insert(Val); + if (DownHypothesis->isValueInactiveFromUsers(TR, Val, UseActivity::None)) { + insertConstantsFrom(TR, *DownHypothesis); + if (UpHypothesis) + insertConstantsFrom(TR, *UpHypothesis); + InsertConstantValue(TR, Val); + return true; } } - // if (EnzymePrintActivity) - // llvm::errs() << " Value nonconstant (couldn't disprove)[" << - // (int)directions - // << "]" << *Val << "\n"; + if (EnzymePrintActivity) + llvm::errs() << " Value nonconstant (couldn't disprove)[" << (int)directions + << "]" << Val << "\n"; ActiveValues.insert(Val); return false; } /// Is the value guaranteed to be inactive because of how it's produced. bool mlir::enzyme::ActivityAnalyzer::isValueInactiveFromOrigin( - MTypeResults const &TR, Value val) { + MTypeResults const &TR, Value val, SmallPtrSetImpl *inactArg) { // Must be an analyzer only searching up assert(directions == UP); - // TODO: use getPotentialIncomingValues here to avoid duplciation. - if (auto arg = val.dyn_cast()) { - if (arg.getOwner()->isEntryBlock()) { - Operation *parent = arg.getOwner()->getParentOp(); - Region *parentRegion = arg.getOwner()->getParent(); - SetVector potentialSources; - // Use region interface to find the values flowing into the entry block. - if (auto iface = dyn_cast(parent)) { - auto isRegionSucessorOf = [arg](RegionBranchOpInterface iface, - Region *region, - RegionBranchPoint predecessor, - SetVector &potentialSources) { - SmallVector successors; - iface.getSuccessorRegions(predecessor, successors); - for (const RegionSuccessor &successor : successors) { - if (successor.getSuccessor() != region) - continue; - - unsigned operandOffset = static_cast(-1); - for (const auto &en : - llvm::enumerate(successor.getSuccessorInputs())) { - if (en.value() != arg) - continue; - operandOffset = en.index(); - } - assert(operandOffset != static_cast(-1) && - "could not locate the position of the argument in the " - "successor input list"); - - // Find the values that are forwarded to entry block arguments of - // the current region. - if (predecessor.isParent()) { - // XXX: this assumes a contiguous slice of operands is mapped 1-1 - // without swaps to a contiguous slice of entry block arguments. - assert(iface.getEntrySuccessorOperands(region).size() == - successor.getSuccessorInputs().size()); - potentialSources.insert( - iface.getEntrySuccessorOperands(region)[operandOffset]); - } else { - // Find all block terminators in the predecessor region that - // may be branching to this region, and get the operands they - // forward. - for (Block &block : *predecessor.getRegionOrNull()) { - // TODO: MLIR block without terminator - if (auto terminator = - dyn_cast( - block.getTerminator())) { - // XXX: this assumes a contiguous slice of operands is mapped - // 1-1 without swaps to a contiguous slice of entry block - // arguments. - assert(terminator.getSuccessorOperands(region).size() == - successor.getSuccessorInputs().size()); - potentialSources.insert( - terminator.getSuccessorOperands(region)[operandOffset]); - } else { - for (Value v : block.getTerminator()->getOperands()) - potentialSources.insert(v); - } - } - } - } - }; - - // Find all possible source regions for the current region. - isRegionSucessorOf(iface, parentRegion, RegionBranchPoint::parent(), - potentialSources); - for (Region ®ion : parent->getRegions()) - isRegionSucessorOf(iface, parentRegion, region, potentialSources); - - } else { - // Conservatively assume any op operand and any terminator operand of - // any region can flow into any block argument. - for (Region ®ion : parent->getRegions()) { - for (Block &block : region) { - // TODO: MLIR blocks without terminator? - for (Value v : block.getTerminator()->getOperands()) - potentialSources.insert(v); - } - } - } - - return llvm::all_of(potentialSources, [&](Value value) { - return isConstantValue(TR, value); - }); - } - - // Look at values flowing into block arguments. - for (Block *predecessor : arg.getOwner()->getPredecessors()) { - Operation *terminator = predecessor->getTerminator(); - if (auto iface = dyn_cast(terminator)) { - for (const auto &en : llvm::enumerate(predecessor->getSuccessors())) { - if (en.value() != arg.getOwner()) - continue; - - Value inflow = iface.getSuccessorOperands(en.index()) - .getForwardedOperands()[arg.getArgNumber()]; - if (!isConstantValue(TR, inflow)) - return false; - } - } else { - for (Value operand : terminator->getOperands()) { - if (!isConstantValue(TR, operand)) - return false; + for (auto v : getPotentialIncomingValues(arg)) { + if (!isConstantValue(TR, v)) { + if (EnzymePrintActivity) { + llvm::errs() << " blockarg: " << arg + << " may be active due to inflow from " << v << "\n"; } + if (inactArg) + inactArg->insert(v); + return false; } } - return true; } return isOperationInactiveFromOrigin(TR, val.getDefiningOp(), - val.cast().getResultNumber()); + val.cast().getResultNumber(), + inactArg); } bool mlir::enzyme::ActivityAnalyzer::isOperationInactiveFromOrigin( - MTypeResults const &TR, Operation *op, std::optional resultNo) { + MTypeResults const &TR, Operation *op, std::optional resultNo, + SmallPtrSetImpl *inactArg) { // Must be an analyzer only searching up assert(directions == UP); @@ -2537,30 +2515,28 @@ bool mlir::enzyme::ActivityAnalyzer::isOperationInactiveFromOrigin( return false; } - // if (EnzymePrintActivity) - // llvm::errs() << " < UPSEARCH" << (int)directions << ">" << *inst << - // "\n"; - - // cpuid is explicitly an inactive instruction - if (auto iasm = dyn_cast(op)) { - if (iasm.getAsmString().contains("cpuid")) { - // if (EnzymePrintActivity) - // llvm::errs() << " constant instruction from known cpuid instruction - // " - // << *inst << "\n"; + if (auto ifaceOp = dyn_cast(op)) { + if (ifaceOp.isInactive()) { return true; } } + if (EnzymePrintActivity) + llvm::errs() << " < UPSEARCH" << (int)directions << ">" << *op << "\n"; + if (auto store = dyn_cast(op)) { if (isConstantValue(TR, store.getValue()) || isConstantValue(TR, store.getAddr())) { - // if (EnzymePrintActivity) - // llvm::errs() << " constant instruction as store operand is inactive - // " - // << *inst << "\n"; + if (EnzymePrintActivity) + llvm::errs() << " constant instruction as store operand is inactive" + << *op << "\n"; return true; } + if (inactArg) { + inactArg->insert(store.getValue()); + inactArg->insert(store.getAddr()); + } + return false; } if (isa(op)) { @@ -2568,11 +2544,15 @@ bool mlir::enzyme::ActivityAnalyzer::isOperationInactiveFromOrigin( // values and thus the store is inactive if (isConstantValue(TR, op->getOperand(0)) || isConstantValue(TR, op->getOperand(1))) { - // if (EnzymePrintActivity) - // llvm::errs() << " constant instruction as memtransfer " << *inst - // << "\n"; + if (EnzymePrintActivity) + llvm::errs() << " constant instruction as memtransfer " << *op << "\n"; return true; } + if (inactArg) { + inactArg->insert(op->getOperand(0)); + inactArg->insert(op->getOperand(1)); + } + return false; } if (auto call = dyn_cast(op)) { @@ -2614,9 +2594,9 @@ bool mlir::enzyme::ActivityAnalyzer::isOperationInactiveFromOrigin( if (KnownInactiveFunctions.count(funcName.str()) || MPIInactiveCommAllocators.find(funcName.str()) != MPIInactiveCommAllocators.end()) { - // if (EnzymePrintActivity) - // llvm::errs() << "constant(" << (int)directions - // << ") up-knowninactivecall " << *inst << "\n"; + if (EnzymePrintActivity) + llvm::errs() << "constant(" << (int)directions + << ") up-knowninactivecall " << *op << "\n"; return true; } @@ -2636,32 +2616,26 @@ bool mlir::enzyme::ActivityAnalyzer::isOperationInactiveFromOrigin( // return true; // } Value callVal = call.getCallableForCallee().dyn_cast(); - if (isConstantValue(TR, callVal)) { - // if (EnzymePrintActivity) - // llvm::errs() << "constant(" << (int)directions << ") up-constfn " - // << *inst << " - " << *callVal << "\n"; - return true; - } - } - // Intrinsics known always to be inactive - if (isa(op)) { - // if (EnzymePrintActivity) - // llvm::errs() << "constant(" << (int)directions << ") up-intrinsic " - // << *inst << "\n"; - return true; + if (callVal) + if (isConstantValue(TR, callVal)) { + if (EnzymePrintActivity) + llvm::errs() << "constant(" << (int)directions << ") up-constfn " + << *op << " - " << callVal << "\n"; + return true; + } } if (auto gep = dyn_cast(op)) { // A gep's only args that could make it active is the pointer operand if (isConstantValue(TR, gep.getBase())) { - // if (EnzymePrintActivity) - // llvm::errs() << "constant(" << (int)directions << ") up-gep " << - // *inst - // << "\n"; + if (EnzymePrintActivity) + llvm::errs() << "constant(" << (int)directions << ") up-gep " << *op + << "\n"; return true; } + if (inactArg) { + inactArg->insert(gep.getBase()); + } return false; } @@ -2722,103 +2696,79 @@ bool mlir::enzyme::ActivityAnalyzer::isOperationInactiveFromOrigin( if (isConstantValue(TR, si.getTrueValue()) && isConstantValue(TR, si.getFalseValue())) { - // if (EnzymePrintActivity) - // llvm::errs() << "constant(" << (int)directions << ") up-sel:" << - // *inst - // << "\n"; + if (EnzymePrintActivity) + llvm::errs() << "constant(" << (int)directions << ") up-sel:" << *op + << "\n"; return true; } + if (inactArg) { + inactArg->insert(si.getTrueValue()); + inactArg->insert(si.getFalseValue()); + } return false; } - if (isa(op)) { - // if (EnzymePrintActivity) - // llvm::errs() << "constant(" << (int)directions << ") up-fpcst:" << - // *inst - // << "\n"; - return true; - } else { - bool seenuse = false; - //! TODO does not consider reading from global memory that is active and not - //! an argument + if (!resultNo) { for (Value a : op->getOperands()) { bool hypval = isConstantValue(TR, a); if (!hypval) { - // if (EnzymePrintActivity) - // llvm::errs() << "nonconstant(" << (int)directions << ") up-inst " - // << *inst << " op " << *a << "\n"; - seenuse = true; - break; - } - } - if (!resultNo) { - // Conservatively check all top-level operations nested in the region, - // there is recursion there. - for (Region ®ion : op->getRegions()) { - for (Block &block : region) { - // XXX: We think that we don't need to check block arguments here - // because values flow into them either from operands of the parent op - // or from the op itself. - if (llvm::any_of(block, [&](Operation &nested) { - // No need to check the results, even if they may be active, - // because in absence of resultNo, we are checking for the - // entire op being inactive not individual values. - // - // // The loop _operation_ is inactive, but the result is, just - // // like the GEP inside it. - // %r = scf.for %i.. { - // // The GEP operation is not active, but the result is. - // %active_r = llvm.gep ... %active_operand - // scf.yield %active_r - // } - return !isConstantOperation(TR, &nested); - })) { - seenuse = true; - break; - } + if (EnzymePrintActivity) + llvm::errs() << "nonconstant(" << (int)directions << ") up-inst " + << *op << " op " << a << "\n"; + if (inactArg) { + inactArg->insert(a); } - if (seenuse) - break; + return false; } - } else { - SetVector potentialSources; - for (Region ®ion : op->getRegions()) { - for (Block &block : region) { - // TODO: MLIR blocks without terminator? - if (auto iface = dyn_cast( - block.getTerminator())) { - // TODO: the interface may also tell us which regions are allowed to - // yield parent op results, and which only branch to other regions. - auto successorOperands = llvm::to_vector( - iface.getSuccessorOperands(RegionBranchPoint::parent())); - // TODO: understand/document the assumption of how operands flow. - assert(successorOperands.size() == op->getNumResults() && - "expected all results to be populated with yielded " - "terminator operands"); - potentialSources.insert(successorOperands[*resultNo]); - } else { - // assume all terminator operands potentially flow into op results - for (Value v : block.getTerminator()->getOperands()) - potentialSources.insert(v); + } + // Conservatively check all top-level operations nested in the region, + // there is recursion there. + for (Region ®ion : op->getRegions()) { + for (Block &block : region) { + // XXX: We think that we don't need to check block arguments here + // because values flow into them either from operands of the parent op + // or from the op itself. + for (Operation &nested : block) { + // No need to check the results, even if they may be active, + // because in absence of resultNo, we are checking for the + // entire op being inactive not individual values. + // + // // The loop _operation_ is inactive, but the result is, just + // // like the GEP inside it. + // %r = scf.for %i.. { + // // The GEP operation is not active, but the result is. + // %active_r = llvm.gep ... %active_operand + // scf.yield %active_r + // } + if (!isConstantOperation(TR, &nested)) { + // TODO set inactArg here, except with constant operand. + // assert(!inactArg); + if (EnzymePrintActivity) + llvm::errs() << "nonconstant(" << (int)directions + << ") up-inst-op " << *op << " sub-op " << nested + << "\n"; + return false; } } } - if (llvm::any_of(potentialSources, [&](Value value) { - return !isConstantValue(TR, value); - })) { - seenuse = true; - } } - - if (!seenuse) { - // if (EnzymePrintActivity) - // llvm::errs() << "constant(" << (int)directions << ") up-inst:" << - // *inst - // << "\n"; - return true; + } else { + for (auto value : getPotentialIncomingValues(op->getResult(*resultNo))) { + if (!isConstantValue(TR, value)) { + if (EnzymePrintActivity) + llvm::errs() << "nonconstant(" << (int)directions << ") up-inst " + << *op << " value " << value << "\n"; + if (inactArg) + inactArg->insert(value); + return false; + } } - return false; } + + if (EnzymePrintActivity) + llvm::errs() << "constant(" << (int)directions << ") up-inst:" << *op + << "\n"; + return true; } /// Is the value free of any active uses @@ -2830,9 +2780,9 @@ bool mlir::enzyme::ActivityAnalyzer::isValueInactiveFromUsers( // To ensure we can call down - // if (EnzymePrintActivity) - // llvm::errs() << " " << *val - // << " UA=" << (int)PUA << "\n"; + if (EnzymePrintActivity) + llvm::errs() << " " << val + << " UA=" << (int)PUA << "\n"; bool seenuse = false; // user, predecessor @@ -2871,9 +2821,23 @@ bool mlir::enzyme::ActivityAnalyzer::isValueInactiveFromUsers( // } } - // if (EnzymePrintActivity) - // llvm::errs() << " considering use of " << *val << " - " << *a - // << "\n"; + if (UA != UseActivity::AllStores) { + if (auto ifaceOp = dyn_cast(a)) { + bool allInactive = true; + for (OpOperand &operand : a->getOpOperands()) { + if (parent == operand.get() && + !ifaceOp.isArgInactive(operand.getOperandNumber())) { + allInactive = false; + break; + } + } + if (allInactive) + continue; + } + } + + if (EnzymePrintActivity) + llvm::errs() << " considering use of " << val << " - " << *a << "\n"; // Only ignore stores to the operand, not storing the operand // somewhere @@ -2952,25 +2916,23 @@ bool mlir::enzyme::ActivityAnalyzer::isValueInactiveFromUsers( // vtodo.push_back(TmpOrig_2); // continue; // } - // if (EnzymePrintActivity) - // llvm::errs() << " -- cannot continuing indirect store from - // " - // << *val << " due to " << *TmpOrig << "\n"; + if (EnzymePrintActivity) + llvm::errs() << " -- cannot continuing indirect store from" + << val << " due to " << TmpOrig << "\n"; shouldContinue = false; break; } if (shouldContinue) { - // if (EnzymePrintActivity) - // llvm::errs() << " -- continuing indirect store from " << - // *val - // << " into:\n"; + if (EnzymePrintActivity) + llvm::errs() << " -- continuing indirect store from " << val + << " into:\n"; done.insert(std::make_tuple(SI.getOperation(), SI.getValue(), UA)); for (Value TmpOrig : newAllocaSet) { for (Operation *a : TmpOrig.getUsers()) { todo.push_back(std::make_tuple(a, TmpOrig, UA)); - // if (EnzymePrintActivity) - // llvm::errs() << " ** " << *a << "\n"; + if (EnzymePrintActivity) + llvm::errs() << " ** " << *a << "\n"; } AllocaSet.insert(TmpOrig); shouldContinue = true; @@ -3030,10 +2992,9 @@ bool mlir::enzyme::ActivityAnalyzer::isValueInactiveFromUsers( break; } if (shouldContinue) { - // if (EnzymePrintActivity) - // llvm::errs() << " -- continuing indirect store2 from " << - // *val - // << " via " << *TmpOrig << "\n"; + if (EnzymePrintActivity) + llvm::errs() << " -- continuing indirect store2 from " << val + << " via " << TmpOrig << "\n"; continue; } } @@ -3071,18 +3032,9 @@ bool mlir::enzyme::ActivityAnalyzer::isValueInactiveFromUsers( // } if (isa(a)) { - // if (EnzymePrintActivity) - // llvm::errs() << "found constant(" << (int)directions - // << ") allocainst use:" << *val << " user " << *a << - // "\n"; - continue; - } - - if (isa( - a)) { - // if (EnzymePrintActivity) - // llvm::errs() << "found constant(" << (int)directions - // << ") si-fp use:" << *val << " user " << *a << "\n"; + if (EnzymePrintActivity) + llvm::errs() << "found constant(" << (int)directions + << ") allocainst use:" << val << " user " << *a << "\n"; continue; } @@ -3111,12 +3063,21 @@ bool mlir::enzyme::ActivityAnalyzer::isValueInactiveFromUsers( // continue; // This use is only active if specified - if (isa(a)) { - if (ActiveReturns == DIFFE_TYPE::CONSTANT && - UA != UseActivity::AllStores) { + if (UA != UseActivity::AllStores) { + if (auto termUsers = getPotentialTerminatorUsers(a, parent)) { + for (auto post : *termUsers) { + for (Operation *postUser : post.getUsers()) { + todo.push_back(std::make_tuple(postUser, post, UA)); + } + } continue; - } else { - return false; + } + if (isFunctionReturn(a)) { + if (ActiveReturns == DIFFE_TYPE::CONSTANT) { + continue; + } else { + return false; + } } } @@ -3211,10 +3172,10 @@ bool mlir::enzyme::ActivityAnalyzer::isValueInactiveFromUsers( if (auto call = dyn_cast(a)) { bool ConstantArg = isFunctionArgumentConstant(call, parent); if (ConstantArg && UA != UseActivity::AllStores) { - // if (EnzymePrintActivity) { - // llvm::errs() << "Value found constant callinst use:" << *val - // << " user " << *call << "\n"; - // } + if (EnzymePrintActivity) { + llvm::errs() << "Value found constant callinst use:" << val + << " user " << *call << "\n"; + } continue; } @@ -3235,10 +3196,8 @@ bool mlir::enzyme::ActivityAnalyzer::isValueInactiveFromUsers( if (operand.getDefiningOp()) { bool legal = true; - for (unsigned i = 0; i < call.getArgOperands().size() + 1; ++i) { - // FIXME: this is based on an assumption that the callee operand - // precedes arg operands. - Value a = call->getOperand(i); + for (unsigned i = 0; i < call.getArgOperands().size(); ++i) { + Value a = call.getArgOperands()[i]; // FIXME: yet another ingrained assumption that integers cannot be // active. @@ -3312,11 +3271,10 @@ bool mlir::enzyme::ActivityAnalyzer::isValueInactiveFromUsers( if (Operation *I = a) { if (notForAnalysis.count(I->getBlock())) { // TODO(PR #904): replace the "EnzymePrintActivity" flag with LLVM_DEBUG - // if (EnzymePrintActivity) { - // llvm::errs() << "Value found constant unreachable inst use:" << - // *val - // << " user " << *I << "\n"; - // } + if (EnzymePrintActivity) { + llvm::errs() << "Value found constant unreachable inst use:" << val + << " user " << *I << "\n"; + } continue; } if (UA != UseActivity::AllStores && ConstantOperations.count(I)) { @@ -3325,11 +3283,10 @@ bool mlir::enzyme::ActivityAnalyzer::isValueInactiveFromUsers( .isa() || ConstantValues.count(val); })) { - // if (EnzymePrintActivity) { - // llvm::errs() << "Value found constant inst use:" << *val << " - // user " - // << *I << "\n"; - // } + if (EnzymePrintActivity) { + llvm::errs() << "Value found constant inst use:" << val << " user " + << *I << "\n"; + } continue; } UseActivity NU = UA; @@ -3367,10 +3324,6 @@ bool mlir::enzyme::ActivityAnalyzer::isValueInactiveFromUsers( LLVM::SExtOp, LLVM::ZExtOp, LLVM::TruncOp, - LLVM::SIToFPOp, - LLVM::UIToFPOp, - LLVM::FPToSIOp, - LLVM::FPToUIOp, LLVM::FPExtOp, LLVM::FPTruncOp // clang-format on @@ -3408,17 +3361,16 @@ bool mlir::enzyme::ActivityAnalyzer::isValueInactiveFromUsers( *FoundInst = I; } - // if (EnzymePrintActivity) - // llvm::errs() << "Value nonconstant inst (uses):" << *val << " user " << - // *a - // << "\n"; + if (EnzymePrintActivity) + llvm::errs() << "Value nonconstant inst (uses):" << val << " user " << *a + << "\n"; seenuse = true; break; } - // if (EnzymePrintActivity) - // llvm::errs() << " " << *val << "\n"; + if (EnzymePrintActivity) + llvm::errs() << " " << val << "\n"; return !seenuse; } @@ -3435,10 +3387,10 @@ bool mlir::enzyme::ActivityAnalyzer::isValueActivelyStoredOrReturned( return StoredOrReturnedCache[key]; } - // if (EnzymePrintActivity) - // llvm::errs() << " " << *val - // << "\n"; + if (EnzymePrintActivity) + llvm::errs() << " " << val + << "\n"; StoredOrReturnedCache[key] = false; @@ -3451,14 +3403,34 @@ bool mlir::enzyme::ActivityAnalyzer::isValueActivelyStoredOrReturned( continue; } - if (isa(a)) { + if (auto ifaceOp = dyn_cast(a)) { + bool allInactive = true; + for (OpOperand &operand : a->getOpOperands()) { + if (operand.get() == val && + !ifaceOp.isArgInactive(operand.getOperandNumber())) { + allInactive = false; + break; + } + } + if (allInactive) + continue; + } + + if (auto termUsers = getPotentialTerminatorUsers(a, val)) { + for (auto post : *termUsers) + if (isValueActivelyStoredOrReturned(TR, post, outside)) { + return StoredOrReturnedCache[key] = true; + } + return false; + } + if (isFunctionReturn(a)) { if (ActiveReturns == DIFFE_TYPE::CONSTANT) continue; - // if (EnzymePrintActivity) - // llvm::errs() << " " - // << " active from-ret>" << *val << "\n"; + if (EnzymePrintActivity) + llvm::errs() << " " + << " active from-ret>" << val << "\n"; StoredOrReturnedCache[key] = true; return true; } @@ -3481,11 +3453,11 @@ bool mlir::enzyme::ActivityAnalyzer::isValueActivelyStoredOrReturned( // Storing into active value, return true if (!isConstantValue(TR, SI.getValue())) { StoredOrReturnedCache[key] = true; - // if (EnzymePrintActivity) - // llvm::errs() << " " << *val - // << " store into=" << *SI << "\n"; + if (EnzymePrintActivity) + llvm::errs() << " " << val + << " store into=" << *SI << "\n"; return true; } } @@ -3494,11 +3466,11 @@ bool mlir::enzyme::ActivityAnalyzer::isValueActivelyStoredOrReturned( // Storing into active memory, return true if (!isConstantValue(TR, SI.getAddr())) { StoredOrReturnedCache[key] = true; - // if (EnzymePrintActivity) - // llvm::errs() << " " << *val << " store=" << *SI - // << "\n"; + if (EnzymePrintActivity) + llvm::errs() << " " << val << " store=" << *SI + << "\n"; return true; } continue; @@ -3509,19 +3481,6 @@ bool mlir::enzyme::ActivityAnalyzer::isValueActivelyStoredOrReturned( // TODO: in MLIR, users are always operations // if (Operation *inst = a) { - auto mayWriteToMemory = [](Operation *op) { - auto iface = dyn_cast(op); - if (!iface) - return true; - - SmallVector effects; - iface.getEffects(effects); - return llvm::any_of( - effects, [](const MemoryEffects::EffectInstance &effect) { - return isa(effect.getEffect()); - }); - }; - if (!mayWriteToMemory(inst) /*|| (isa(inst) && AA.onlyReadsMemory(cast(inst)))*/) { // // if not written to memory and returning a known constant, this @@ -3562,18 +3521,17 @@ bool mlir::enzyme::ActivityAnalyzer::isValueActivelyStoredOrReturned( // it is written to active memory // TODO handle more memory instructions above to be less conservative - // if (EnzymePrintActivity) - // llvm::errs() << " " << *val << " - use=" << *a - // << "\n"; + if (EnzymePrintActivity) + llvm::errs() << " " << val << " - use=" << *a << "\n"; return StoredOrReturnedCache[key] = true; } - // if (EnzymePrintActivity) - // llvm::errs() << " " - // << *val << "\n"; + if (EnzymePrintActivity) + llvm::errs() << " " + << val << "\n"; return false; } @@ -3589,9 +3547,9 @@ void mlir::enzyme::ActivityAnalyzer::InsertConstantOperation( if (!ActiveValues.count(toeval)) continue; ActiveValues.erase(toeval); - // if (EnzymePrintActivity) - // llvm::errs() << " re-evaluating activity of val " << *toeval - // << " due to inst " << *I << "\n"; + if (EnzymePrintActivity) + llvm::errs() << " re-evaluating activity of val " << toeval + << " due to inst " << *I << "\n"; isConstantValue(TR, toeval); } } @@ -3607,9 +3565,9 @@ void mlir::enzyme::ActivityAnalyzer::InsertConstantValue(MTypeResults const &TR, if (!ActiveValues.count(toeval)) continue; ActiveValues.erase(toeval); - // if (EnzymePrintActivity) - // llvm::errs() << " re-evaluating activity of val " << *toeval - // << " due to value " << *V << "\n"; + if (EnzymePrintActivity) + llvm::errs() << " re-evaluating activity of val " << toeval + << " due to value " << V << "\n"; isConstantValue(TR, toeval); } } @@ -3621,9 +3579,9 @@ void mlir::enzyme::ActivityAnalyzer::InsertConstantValue(MTypeResults const &TR, if (!ActiveOperations.count(toeval)) continue; ActiveOperations.erase(toeval); - // if (EnzymePrintActivity) - // llvm::errs() << " re-evaluating activity of inst " << *toeval - // << " due to value " << *V << "\n"; + if (EnzymePrintActivity) + llvm::errs() << " re-evaluating activity of inst " << *toeval + << " due to value " << V << "\n"; isConstantOperation(TR, toeval); } } diff --git a/enzyme/Enzyme/MLIR/Analysis/ActivityAnalysis.h b/enzyme/Enzyme/MLIR/Analysis/ActivityAnalysis.h index bff41155fc3b..c73df2025883 100644 --- a/enzyme/Enzyme/MLIR/Analysis/ActivityAnalysis.h +++ b/enzyme/Enzyme/MLIR/Analysis/ActivityAnalysis.h @@ -143,12 +143,18 @@ class ActivityAnalyzer { bool isFunctionArgumentConstant(mlir::CallOpInterface CI, Value val); /// Is the value guaranteed to be inactive because of how it's produced. - bool isValueInactiveFromOrigin(MTypeResults const &TR, Value val); + /// If active and inactArg is non-null, store any values which may allow this + /// to succeed in the future + bool isValueInactiveFromOrigin( + MTypeResults const &TR, Value val, + llvm::SmallPtrSetImpl *inactArg = nullptr); + /// Is the operation guaranteed to be inactive because of how its operands are /// produced. bool isOperationInactiveFromOrigin( MTypeResults const &TR, Operation *op, - std::optional resultNo = std::nullopt); + std::optional resultNo = std::nullopt, + llvm::SmallPtrSetImpl *inactArg = nullptr); public: enum class UseActivity { diff --git a/enzyme/Enzyme/MLIR/Analysis/AliasAnalysis.cpp b/enzyme/Enzyme/MLIR/Analysis/AliasAnalysis.cpp index e0a2fa1c0c38..6fbfbacda7d9 100644 --- a/enzyme/Enzyme/MLIR/Analysis/AliasAnalysis.cpp +++ b/enzyme/Enzyme/MLIR/Analysis/AliasAnalysis.cpp @@ -211,7 +211,7 @@ ChangeResult enzyme::PointsToSets::update(const AliasClassSet &keysToUpdate, // TODO: consider a stricter check that we only replace unknown // values or a value with itself, currently blocked by memalign. AliasClassSet valuesCopy(values); - valuesCopy.join(it->getSecond()); + (void)valuesCopy.join(it->getSecond()); values.print(llvm::errs()); llvm::errs() << "\n"; it->getSecond().print(llvm::errs()); @@ -598,7 +598,7 @@ void enzyme::PointsToPointerAnalysis::visitCallControlFlowTransfer( if (funcMayReadOther) { // If a function may read from other, it may be storing pointers from // unknown alias sets into any writable pointer. - functionMayCapture.markUnknown(); + (void)functionMayCapture.markUnknown(); } else { for (int pointerAsData : pointerLikeOperands) { // If not captured, it cannot be stored in anything. @@ -609,7 +609,7 @@ void enzyme::PointsToPointerAnalysis::visitCallControlFlowTransfer( const auto *srcClasses = getOrCreateFor( call, call.getArgOperands()[pointerAsData]); - functionMayCapture.join(srcClasses->getAliasClassesObject()); + (void)functionMayCapture.join(srcClasses->getAliasClassesObject()); } } @@ -624,16 +624,17 @@ void enzyme::PointsToPointerAnalysis::visitCallControlFlowTransfer( // If the argument cannot be stored into, just preserve it as is. if (!mayWriteArg(callee, pointerOperand, argModRef)) { - nonWritableOperandClasses.join(destClasses->getAliasClassesObject()); + (void)nonWritableOperandClasses.join( + destClasses->getAliasClassesObject()); continue; } - writableClasses.join(destClasses->getAliasClassesObject()); + (void)writableClasses.join(destClasses->getAliasClassesObject()); // If the destination class is unknown, mark all known classes // pessimistic (alias classes that have not beed analyzed and thus are // absent from pointsTo are treated as "undefined" at this point). if (destClasses->isUnknown()) { - writableClasses.markUnknown(); + (void)writableClasses.markUnknown(); changed |= after->markAllPointToUnknown(); break; } @@ -701,15 +702,15 @@ void enzyme::PointsToPointerAnalysis::visitCallControlFlowTransfer( AliasClassSet resultWithoutNonWritableOperands = AliasClassSet::getUndefined(); if (destClasses->isUnknown() || nonWritableOperandClasses.isUnknown()) { - resultWithoutNonWritableOperands.markUnknown(); + (void)resultWithoutNonWritableOperands.markUnknown(); } else if (!destClasses->isUndefined() && !nonWritableOperandClasses.isUndefined()) { DenseSet nonOperandClasses = llvm::set_difference(destClasses->getAliasClasses(), nonWritableOperandClasses.getAliasClasses()); - resultWithoutNonWritableOperands.insert(nonOperandClasses); + (void)resultWithoutNonWritableOperands.insert(nonOperandClasses); } else { - resultWithoutNonWritableOperands.join( + (void)resultWithoutNonWritableOperands.join( destClasses->getAliasClassesObject()); } @@ -999,7 +1000,7 @@ void enzyme::AliasAnalysis::visitOperation( if (!isPointerLike(result.getType())) continue; - results[result.getResultNumber()]->markUnknown(); + (void)results[result.getResultNumber()]->markUnknown(); } return; } @@ -1053,7 +1054,7 @@ void enzyme::AliasAnalysis::visitExternalCall( continue; const AliasClassLattice *srcClasses = operands[operandNo]; - operandAliasClasses.join(srcClasses->getAliasClassesObject()); + (void)operandAliasClasses.join(srcClasses->getAliasClassesObject()); if (!mayReadArg(callee, operandNo, argModRef)) continue; @@ -1061,13 +1062,14 @@ void enzyme::AliasAnalysis::visitExternalCall( // If can read from argument, collect the alias classes that can this // argument may be pointing to. const auto *pointsToLattice = getOrCreateFor(call, call); - srcClasses->getAliasClassesObject().foreachClass( + (void)srcClasses->getAliasClassesObject().foreachClass( [&](DistinctAttr srcClass, AliasClassSet::State state) { // Nothing to do in top/bottom case. In the top case, we have already // set `operandAliasClasses` to top above. if (srcClass == nullptr) return ChangeResult::NoChange; - operandAliasClasses.join(pointsToLattice->getPointsTo(srcClass)); + (void)operandAliasClasses.join( + pointsToLattice->getPointsTo(srcClass)); return ChangeResult::NoChange; }); } diff --git a/enzyme/Enzyme/MLIR/Analysis/DataFlowActivityAnalysis.cpp b/enzyme/Enzyme/MLIR/Analysis/DataFlowActivityAnalysis.cpp index 2896b428a23d..00b56c9a7bd1 100644 --- a/enzyme/Enzyme/MLIR/Analysis/DataFlowActivityAnalysis.cpp +++ b/enzyme/Enzyme/MLIR/Analysis/DataFlowActivityAnalysis.cpp @@ -44,6 +44,8 @@ #include "mlir/Analysis/AliasAnalysis/LocalAliasAnalysis.h" +#include "Interfaces/AutoDiffOpInterface.h" + using namespace mlir; using namespace mlir::dataflow; using enzyme::AliasClassLattice; @@ -329,7 +331,7 @@ class MemoryActivity : public AbstractDenseLattice { const MemoryActivityState *rhsActivity = isKnownInRHS ? &rhsIt->getSecond() : &rhs.otherMemoryActivity; MemoryActivityState updatedActivity(*lhsActivity); - updatedActivity.merge(*rhsActivity); + (void)updatedActivity.merge(*rhsActivity); if ((lhsIt != activityStates.end() && updatedActivity != lhsIt->getSecond()) || (lhsIt == activityStates.end() && @@ -491,7 +493,7 @@ std::optional getCopySource(Operation *op) { /// If the classes are undefined, the callback will not be called at all. void forEachAliasedAlloc(const AliasClassLattice *ptrAliasClass, function_ref forEachFn) { - ptrAliasClass->getAliasClassesObject().foreachClass( + (void)ptrAliasClass->getAliasClassesObject().foreachClass( [&](DistinctAttr alloc, enzyme::AliasClassSet::State state) { if (state != enzyme::AliasClassSet::State::Undefined) forEachFn(alloc); @@ -511,6 +513,15 @@ class DenseForwardActivityAnalysis ForwardMemoryActivity *after) override { join(after, before); ChangeResult result = ChangeResult::NoChange; + + // TODO If we know this is inactive by definition + // if (auto ifaceOp = dyn_cast(op)) { + // if (ifaceOp.isInactive()) { + // propagateIfChanged(after, result); + // return; + // } + // } + auto memory = dyn_cast(op); // If we can't reason about the memory effects, then conservatively assume // we can't deduce anything about activity via side-effects. @@ -660,6 +671,14 @@ class DenseBackwardActivityAnalysis void visitOperation(Operation *op, const BackwardMemoryActivity &after, BackwardMemoryActivity *before) override { + + // TODO: If we know this is inactive by definition + // if (auto ifaceOp = dyn_cast(op)) { + // if (ifaceOp.isInactive()) { + // return; + // } + // } + // Initialize the return activity of arguments. if (op->hasTrait() && op->getParentOp() == parentOp) { for (const auto &[arg, argActivity] : @@ -846,15 +865,16 @@ void printActivityAnalysisResults(const DataFlowSolver &solver, // } // }; auto scheduleVisit = [&](const enzyme::AliasClassSet &aliasClasses) { - aliasClasses.foreachClass([&](DistinctAttr neighbor, - enzyme::AliasClassSet::State state) { - assert(neighbor && "unhandled undefined/unknown case before visit"); - if (!visited.contains(neighbor)) { - visited.insert(neighbor); - frontier.push_back(neighbor); - } - return ChangeResult::NoChange; - }); + (void)aliasClasses.foreachClass( + [&](DistinctAttr neighbor, enzyme::AliasClassSet::State state) { + assert(neighbor && + "unhandled undefined/unknown case before visit"); + if (!visited.contains(neighbor)) { + visited.insert(neighbor); + frontier.push_back(neighbor); + } + return ChangeResult::NoChange; + }); }; if (isa_and_present( @@ -1081,7 +1101,7 @@ void enzyme::runDataFlowActivityAnalysis( // analyses, enzyme_const is the default. if (activity == enzyme::Activity::enzyme_out) { auto *argLattice = solver.getOrCreateState(arg); - argLattice->join(ValueActivity::getActiveVal()); + (void)argLattice->join(ValueActivity::getActiveVal()); } } @@ -1096,7 +1116,7 @@ void enzyme::runDataFlowActivityAnalysis( solver.getOrCreateState(operand); // Very basic type inference of the type if (isa(operand.getType())) { - returnLattice->meet(ValueActivity::getActiveVal()); + (void)returnLattice->meet(ValueActivity::getActiveVal()); } } } diff --git a/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td b/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td index de51913400a6..7d74fdafbdbf 100644 --- a/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td +++ b/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td @@ -83,12 +83,6 @@ def PopOp : Enzyme_Op<"pop"> { let results = (outs AnyType:$output); } -def ClearOp : Enzyme_Op<"clear"> { - let summary = "Remove top element from ShadowedGradient"; - let arguments = (ins AnyType : $cache); - let results = (outs ); -} - def InitOp : Enzyme_Op<"init"> { let summary = "Creat enzyme.gradient and enzyme.cache"; let arguments = (ins ); @@ -105,36 +99,28 @@ def Cache : Enzyme_Type<"Cache"> { let assemblyFormat = "`<` $type `>`"; } -def SetOp : Enzyme_Op<"set"> { - let summary = "Write to gradient"; - let arguments = (ins AnyType : $gradient, AnyType : $value); - let results = (outs ); -} - -def GetOp : Enzyme_Op<"get"> { - let summary = "Load value of gradient"; - let arguments = (ins AnyType : $gradient); - let results = (outs AnyType); -} - def Gradient : Enzyme_Type<"Gradient"> { - let summary = "Stores gradient if it cant be stroed in a value."; + let summary = "Mutable storage for accumulating gradients"; let description = [{ - "Cache for reverse pass" + Mutable storage for accumulating derivatives of immutable types (e.g. adding all the partial derivatives from users of a float64) }]; let parameters = (ins "Type":$basetype); let mnemonic = "Gradient"; let assemblyFormat = "`<` $basetype `>`"; } -def ShadowedGradient : Enzyme_Type<"ShadowedGradient"> { - let summary = "Stores gradients which need to be initialized with shadow values from the forward pass."; - let description = [{ - "Cache for reverse pass" - }]; - let parameters = (ins "Type":$basetype); - let mnemonic = "ShadowedGradient"; - let assemblyFormat = "`<` $basetype `>`"; +def SetOp : Enzyme_Op<"set"> { + let summary = "Store the current value of the gradient"; + let arguments = (ins Arg:$gradient, AnyType : $value); + let results = (outs ); +} + +def GetOp : Enzyme_Op<"get"> { + let summary = "Load current value of gradient"; + let arguments = (ins Arg:$gradient); + let results = (outs AnyType); } def AddToOp : Enzyme_Op<"addTo", [Pure, Terminator, ReturnLike]>, diff --git a/enzyme/Enzyme/MLIR/Implementations/AffineAutoDiffOpInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/AffineAutoDiffOpInterfaceImpl.cpp new file mode 100644 index 000000000000..c27f0d60d129 --- /dev/null +++ b/enzyme/Enzyme/MLIR/Implementations/AffineAutoDiffOpInterfaceImpl.cpp @@ -0,0 +1,64 @@ +//===- AffineAutoDiffOpInterfaceImpl.cpp - Interface external model -------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains the external model implementation of the automatic +// differentiation op interfaces for the upstream MLIR Affine dialect. +// +//===----------------------------------------------------------------------===// + +#include "Implementations/CoreDialectsAutoDiffImplementations.h" +#include "Interfaces/AutoDiffOpInterface.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/IR/IntegerSet.h" + +using namespace mlir; +using namespace mlir::enzyme; + +namespace { +affine::AffineForOp +createAffineForWithShadows(Operation *op, OpBuilder &builder, + MGradientUtils *gutils, Operation *original, + ValueRange remappedOperands, TypeRange rettys) { + affine::AffineForOpAdaptor adaptor(remappedOperands, + cast(original)); + auto repFor = builder.create( + original->getLoc(), adaptor.getLowerBoundOperands(), + adaptor.getLowerBoundMap(), adaptor.getUpperBoundOperands(), + adaptor.getUpperBoundMap(), adaptor.getStep().getZExtValue(), + // This dance is necessary because the adaptor accessors are based on the + // internal attribute containing the number of operands associated with + // each named operand group. This attribute is carried over from the + // original operation and does not account for the shadow-related iter + // args. Instead, assume lower/upper bound operands must not have shadows + // since they are integer-typed and take the result of operands as iter + // args. + remappedOperands.drop_front(adaptor.getLowerBoundOperands().size() + + adaptor.getUpperBoundOperands().size())); + return repFor; +} + +affine::AffineIfOp createAffineIfWithShadows(Operation *op, OpBuilder &builder, + MGradientUtils *gutils, + affine::AffineIfOp original, + ValueRange remappedOperands, + TypeRange rettys) { + affine::AffineIfOpAdaptor adaptor(remappedOperands, original); + return builder.create( + original->getLoc(), rettys, original.getIntegerSet(), + adaptor.getOperands(), !original.getElseRegion().empty()); +} + +#include "Implementations/AffineDerivatives.inc" +} // namespace + +void mlir::enzyme::registerAffineDialectAutoDiffInterface( + DialectRegistry ®istry) { + registry.addExtension(+[](MLIRContext *context, affine::AffineDialect *) { + registerInterfaces(context); + }); +} diff --git a/enzyme/Enzyme/MLIR/Implementations/AffineDerivatives.td b/enzyme/Enzyme/MLIR/Implementations/AffineDerivatives.td new file mode 100644 index 000000000000..9f22d00cecdb --- /dev/null +++ b/enzyme/Enzyme/MLIR/Implementations/AffineDerivatives.td @@ -0,0 +1,28 @@ +include "Common.td" + +def : ControlFlowOp<"affine", "AffineForOp", [{ + Operation *createWithShadows(Operation *op, OpBuilder &builder, + MGradientUtils *gutils, Operation *original, + ValueRange remappedOperands, + TypeRange rettys) const { + return createAffineForWithShadows(op, builder, gutils, original, + remappedOperands, rettys); + } +}]>; + +def : ControlFlowOp<"affine", "AffineIfOp", [{ + Operation *createWithShadows(Operation *op, OpBuilder &builder, + MGradientUtils *gutils, Operation *original, + ValueRange remappedOperands, + TypeRange rettys) const { + return createAffineIfWithShadows(op, builder, gutils, + cast(original), + remappedOperands, rettys); + } +}]>; + +def : RegionTerminatorOp<"affine", "AffineYieldOp">; +def : ReadOnlyIdentityOp<"affine", "AffineLoadOp", [0]>; +def : ReadOnlyIdentityOp<"affine", "AffineVectorLoadOp", [0]>; +def : MemoryIdentityOp<"affine", "AffineStoreOp", [1], [0]>; +def : MemoryIdentityOp<"affine", "AffineVectorStoreOp", [1], [0]>; diff --git a/enzyme/Enzyme/MLIR/Implementations/ArithDerivatives.td b/enzyme/Enzyme/MLIR/Implementations/ArithDerivatives.td index bb713ef61799..eb0294b4d24d 100644 --- a/enzyme/Enzyme/MLIR/Implementations/ArithDerivatives.td +++ b/enzyme/Enzyme/MLIR/Implementations/ArithDerivatives.td @@ -1,39 +1,4 @@ -class MLIRDerivative resultOps> { - string dialect = dialect_; - string opName = opName_; - dag PatternToMatch = patternToMatch; - list ArgDerivatives = resultOps; -} - -class Operation { - bit usesPrimal = usesPrimal_; - bit usesShadow = usesShadow_; - bit usesCustom = usesCustom_; -} - -class DiffeRetIndex indices_> { - list indices = indices_; -} -def DiffeRet : DiffeRetIndex<[-1]>; - -class Inst : Operation { - string name = mnemonic; - string dialect = dialect_; -} -class ArithInst : Inst; - -def AddF : ArithInst<"arith::AddFOp">; -def SubF : ArithInst<"arith::SubFOp">; -def NegF : ArithInst<"arith::NegFOp">; -def MulF : ArithInst<"arith::MulFOp">; -def DivF : ArithInst<"arith::DivFOp">; -def RemF : ArithInst<"arith::RemFOp">; - -def CheckedMulF : ArithInst<"arith::MulFOp">; -def CheckedDivF : ArithInst<"arith::DivFOp">; - -def Op { -} +include "Common.td" def : MLIRDerivative<"arith", "AddFOp", (Op $x, $y), [ @@ -63,6 +28,12 @@ def : MLIRDerivative<"arith", "DivFOp", (Op $x, $y), [ (CheckedDivF (DiffeRet), $y), (NegF (MulF (CheckedDivF (DiffeRet), $y), (DivF $x, $y))) - ] - // (CheckedDiv (FSub (SelectIfActive $x, (FMul (Shadow $x), $y), (Zero $x)), (SelectIfActive $y, (FMul (Shadow $y), $x), (Zero $y))), (FMul $y, $y)) + ], + (CheckedDivF (SubF (SelectIfActive $x, (MulF (Shadow $x), $y), (ConstantFP<"0","arith", "ConstantOp"> $x)), (SelectIfActive $y, (MulF (Shadow $y), $x), (ConstantFP<"0","arith","ConstantOp"> $y))), (MulF $y, $y)) >; + +def ExtF : ArithInst<"ExtFOp">; +def TruncF : ArithInst<"TruncFOp">; + +def : ReadOnlyIdentityOp<"arith", "TruncFOp", [0], (Op $x), [(ExtF (TypeOf $x), (DiffeRet))]>; +def : ReadOnlyIdentityOp<"arith", "ExtFOp", [0], (Op $x), [(TruncF (TypeOf $x), (DiffeRet))]>; diff --git a/enzyme/Enzyme/MLIR/Implementations/BuiltinAutoDiffTypeInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/BuiltinAutoDiffTypeInterfaceImpl.cpp index 30feca7ba82d..058f49e87ac9 100644 --- a/enzyme/Enzyme/MLIR/Implementations/BuiltinAutoDiffTypeInterfaceImpl.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/BuiltinAutoDiffTypeInterfaceImpl.cpp @@ -44,7 +44,44 @@ class FloatTypeInterface return self; } - bool requiresShadow(Type self) const { return false; } + bool isMutable(Type self) const { return false; } + LogicalResult zeroInPlace(Type self, OpBuilder &builder, Location loc, + Value val) const { + return failure(); + } +}; + +class TensorTypeInterface + : public AutoDiffTypeInterface::ExternalModel { +public: + Value createNullValue(Type self, OpBuilder &builder, Location loc) const { + auto tenType = self.cast(); + auto ET = tenType.getElementType(); + size_t num = 1; + for (auto sz : tenType.getShape()) + num *= sz; + APFloat apvalue(ET.cast().getFloatSemantics(), 0); + SmallVector supportedValues(num, apvalue); + auto attr = DenseElementsAttr::get(tenType, supportedValues); + return builder.create(loc, tenType, attr); + } + + Value createAddOp(Type self, OpBuilder &builder, Location loc, Value a, + Value b) const { + return builder.create(loc, a, b); + } + + Type getShadowType(Type self, unsigned width) const { + assert(width == 1 && "unsupported width != 1"); + return self; + } + + bool isMutable(Type self) const { return false; } + LogicalResult zeroInPlace(Type self, OpBuilder &builder, Location loc, + Value val) const { + return failure(); + } }; template @@ -68,7 +105,11 @@ class IntegerTypeInterface return self; } - bool requiresShadow(Type self) const { return false; } + bool isMutable(Type self) const { return false; } + LogicalResult zeroInPlace(Type self, OpBuilder &builder, Location loc, + Value val) const { + return failure(); + } }; } // namespace @@ -81,5 +122,7 @@ void mlir::enzyme::registerBuiltinDialectAutoDiffInterface( Float64Type::attachInterface(*context); IntegerType::attachInterface>(*context); IndexType::attachInterface>(*context); + UnrankedTensorType::attachInterface(*context); + RankedTensorType::attachInterface(*context); }); } diff --git a/enzyme/Enzyme/MLIR/Implementations/CFAutoDiffOpInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/CFAutoDiffOpInterfaceImpl.cpp new file mode 100644 index 000000000000..8f40db9d834d --- /dev/null +++ b/enzyme/Enzyme/MLIR/Implementations/CFAutoDiffOpInterfaceImpl.cpp @@ -0,0 +1,41 @@ +//===- SCFAutoDiffOpInterfaceImpl.cpp - Interface external model ----------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains the external model implementation of the automatic +// differentiation op interfaces for the upstream MLIR SCF dialect. +// +//===----------------------------------------------------------------------===// + +#include "Implementations/CoreDialectsAutoDiffImplementations.h" +#include "Interfaces/AutoDiffOpInterface.h" +#include "Interfaces/AutoDiffTypeInterface.h" +#include "Interfaces/EnzymeLogic.h" +#include "Interfaces/GradientUtils.h" +#include "Interfaces/GradientUtilsReverse.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" +#include "mlir/IR/DialectRegistry.h" +#include "mlir/IR/Types.h" +#include "mlir/Interfaces/ControlFlowInterfaces.h" +#include "mlir/Support/LogicalResult.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/STLExtras.h" +#include + +using namespace mlir; +using namespace mlir::enzyme; + +namespace { +#include "Implementations/CFDerivatives.inc" +} // namespace + +void mlir::enzyme::registerCFDialectAutoDiffInterface( + DialectRegistry ®istry) { + registry.addExtension(+[](MLIRContext *context, cf::ControlFlowDialect *) { + registerInterfaces(context); + }); +} diff --git a/enzyme/Enzyme/MLIR/Implementations/CFDerivatives.td b/enzyme/Enzyme/MLIR/Implementations/CFDerivatives.td new file mode 100644 index 000000000000..0b522e72ccf2 --- /dev/null +++ b/enzyme/Enzyme/MLIR/Implementations/CFDerivatives.td @@ -0,0 +1,5 @@ +include "Common.td" + +def : BranchOp<"cf", "CondBranchOp">; +def : BranchOp<"cf", "BranchOp">; +def : BranchOp<"cf", "SwitchOp">; diff --git a/enzyme/Enzyme/MLIR/Implementations/CMakeLists.txt b/enzyme/Enzyme/MLIR/Implementations/CMakeLists.txt index a41ee2133c68..3508e71d9adc 100644 --- a/enzyme/Enzyme/MLIR/Implementations/CMakeLists.txt +++ b/enzyme/Enzyme/MLIR/Implementations/CMakeLists.txt @@ -1,22 +1,68 @@ +set(LLVM_TARGET_DEFINITIONS AffineDerivatives.td) +enzyme_tablegen(AffineDerivatives.inc -gen-mlir-derivatives) +add_public_tablegen_target(AffineDerivativesIncGen) set(LLVM_TARGET_DEFINITIONS ArithDerivatives.td) enzyme_tablegen(ArithDerivatives.inc -gen-mlir-derivatives) add_public_tablegen_target(ArithDerivativesIncGen) +set(LLVM_TARGET_DEFINITIONS LLVMDerivatives.td) +enzyme_tablegen(LLVMDerivatives.inc -gen-mlir-derivatives) +add_public_tablegen_target(LLVMDerivativesIncGen) + +set(LLVM_TARGET_DEFINITIONS NVVMDerivatives.td) +enzyme_tablegen(NVVMDerivatives.inc -gen-mlir-derivatives) +add_public_tablegen_target(NVVMDerivativesIncGen) + +set(LLVM_TARGET_DEFINITIONS SCFDerivatives.td) +enzyme_tablegen(SCFDerivatives.inc -gen-mlir-derivatives) +add_public_tablegen_target(SCFDerivativesIncGen) + +set(LLVM_TARGET_DEFINITIONS CFDerivatives.td) +enzyme_tablegen(CFDerivatives.inc -gen-mlir-derivatives) +add_public_tablegen_target(CFDerivativesIncGen) + +set(LLVM_TARGET_DEFINITIONS MemRefDerivatives.td) +enzyme_tablegen(MemRefDerivatives.inc -gen-mlir-derivatives) +add_public_tablegen_target(MemRefDerivativesIncGen) + +set(LLVM_TARGET_DEFINITIONS MathDerivatives.td) +enzyme_tablegen(MathDerivatives.inc -gen-mlir-derivatives) +add_public_tablegen_target(MathDerivativesIncGen) + +set(LLVM_TARGET_DEFINITIONS FuncDerivatives.td) +enzyme_tablegen(FuncDerivatives.inc -gen-mlir-derivatives) +add_public_tablegen_target(FuncDerivativesIncGen) + add_mlir_library(MLIREnzymeImplementations + AffineAutoDiffOpInterfaceImpl.cpp ArithAutoDiffOpInterfaceImpl.cpp + CoreDialectsAutoDiffImplementations.cpp LLVMAutoDiffOpInterfaceImpl.cpp + NVVMAutoDiffOpInterfaceImpl.cpp MemRefAutoDiffOpInterfaceImpl.cpp + FuncAutoDiffOpInterfaceImpl.cpp LinalgAutoDiffOpInterfaceImpl.cpp BuiltinAutoDiffTypeInterfaceImpl.cpp SCFAutoDiffOpInterfaceImpl.cpp + CFAutoDiffOpInterfaceImpl.cpp + MathAutoDiffOpInterfaceImpl.cpp DEPENDS MLIRAutoDiffOpInterfaceIncGen + AffineDerivativesIncGen ArithDerivativesIncGen + LLVMDerivativesIncGen + FuncDerivativesIncGen + NVVMDerivativesIncGen + SCFDerivativesIncGen + CFDerivativesIncGen + MemRefDerivativesIncGen + MathDerivativesIncGen LINK_LIBS PUBLIC MLIRArithDialect + MLIRFuncDialect MLIRLLVMDialect MLIRMemRefDialect MLIREnzymeAutoDiffInterface diff --git a/enzyme/Enzyme/MLIR/Implementations/Common.td b/enzyme/Enzyme/MLIR/Implementations/Common.td new file mode 100644 index 000000000000..099e614b8bcd --- /dev/null +++ b/enzyme/Enzyme/MLIR/Implementations/Common.td @@ -0,0 +1,125 @@ +#ifndef ENZYME_MLIR_IMPLEMENTATIONS_COMMON +#define ENZYME_MLIR_IMPLEMENTATIONS_COMMON + +class InactiveOp { + string dialect = dialect_; + string opName = opName_; +} + +class AllocationOp { + string dialect = dialect_; + string opName = opName_; +} + +class ControlFlowOp { + string dialect = dialect_; + string opName = opName_; + string impl = impl_; +} + + +def Unimplemented { + +} + +class MemoryIdentityOp ptrargs_, list storedargs_ = [], dag patternToMatch=(Unimplemented), list reverse_ = []> { + string dialect = dialect_; + string opName = opName_; + dag PatternToMatch = patternToMatch; + list ptrargs = ptrargs_; + list storedargs = storedargs_; + list reverse = reverse_; +} + +class ReadOnlyIdentityOp ptrargs_, dag patternToMatch=(Unimplemented), list reverse_ = []> : MemoryIdentityOp; + +class ReturnOp { + string dialect = dialect_; + string opName = opName_; +} + +class BranchOp { + string dialect = dialect_; + string opName = opName_; +} + +class RegionTerminatorOp { + string dialect = dialect_; + string opName = opName_; +} + +class ForwardFromSummedReverseInternal { + int unused = unused_; +} +def ForwardFromSummedReverse : ForwardFromSummedReverseInternal<0>; + + +class MLIRDerivative resultOps, dag forwardOps=(ForwardFromSummedReverse)> { + string dialect = dialect_; + string opName = opName_; + dag PatternToMatch = patternToMatch; + list ArgDerivatives = resultOps; + dag ArgDuals = forwardOps; +} + +class Operation { + bit usesPrimal = usesPrimal_; + bit usesShadow = usesShadow_; + bit usesCustom = usesCustom_; +} + +class DiffeRetIndex indices_> { + list indices = indices_; +} +def DiffeRet : DiffeRetIndex<[-1]>; + +def Shadow : Operation { +} + +class GlobalExpr : Operation{ + string value = val; +} + +class Inst : Operation { + string name = mnemonic; + string dialect = dialect_; +} + +def Op { +} + +def SelectIfActive : Operation { + +} + +class ConstantFP : Operation { + string value = val; + string dialect = dialect_; + string opName = op_; + string type = type_; +} + +def ResultTypes : GlobalExprgetResultTypes()">; + +def TypeOf : Operation { +} + +class ArithInst : Inst; +class MathInst : Inst; + +def AddF : ArithInst<"AddFOp">; +def SubF : ArithInst<"SubFOp">; +def NegF : ArithInst<"NegFOp">; +def MulF : ArithInst<"MulFOp">; +def DivF : ArithInst<"DivFOp">; +def RemF : ArithInst<"RemFOp">; + +def CheckedMulF : ArithInst<"MulFOp">; +def CheckedDivF : ArithInst<"DivFOp">; + +def CosF : MathInst<"CosOp">; +def SinF : MathInst<"SinOp">; +def ExpF : MathInst<"ExpOp">; +def SqrtF : MathInst<"SqrtOp">; + +#endif // ENZYME_MLIR_IMPLEMENTATIONS_COMMON diff --git a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp new file mode 100644 index 000000000000..ade1be7e6406 --- /dev/null +++ b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp @@ -0,0 +1,428 @@ +//===- CoreDialectsAutoDiffImplementations.cpp ----------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains common utilities for the external model implementation of +// the automatic differentiation op interfaces for upstream MLIR dialects. +// +//===----------------------------------------------------------------------===// + +#include "Implementations/CoreDialectsAutoDiffImplementations.h" +#include "Interfaces/AutoDiffOpInterface.h" +#include "Interfaces/AutoDiffTypeInterface.h" +#include "Interfaces/GradientUtils.h" +#include "Interfaces/GradientUtilsReverse.h" + +using namespace mlir; +using namespace mlir::enzyme; + +mlir::TypedAttr mlir::enzyme::getConstantAttr(mlir::Type type, + llvm::StringRef value) { + using namespace mlir; + if (auto T = dyn_cast(type)) { + size_t num = 1; + for (auto sz : T.getShape()) + num *= sz; + APFloat apvalue(T.getElementType().cast().getFloatSemantics(), + value); + SmallVector supportedValues(num, apvalue); + return DenseFPElementsAttr::get(type.cast(), supportedValues); + } + auto T = cast(type); + APFloat apvalue(T.getFloatSemantics(), value); + return FloatAttr::get(T, apvalue); +} + +void mlir::enzyme::detail::branchingForwardHandler(Operation *inst, + OpBuilder &builder, + MGradientUtils *gutils) { + auto newInst = gutils->getNewFromOriginal(inst); + + auto binst = cast(inst); + + // TODO generalize to cloneWithNewBlockArgs interface + SmallVector newVals; + + SmallVector segSizes; + // Keep non-differentiated, non-forwarded operands + size_t non_forwarded = 0; + for (size_t i = 0; i < newInst->getNumSuccessors(); i++) { + auto ops = binst.getSuccessorOperands(i).getForwardedOperands(); + if (ops.empty()) + continue; + non_forwarded = ops.getBeginOperandIndex(); + break; + } + + for (size_t i = 0; i < non_forwarded; i++) + newVals.push_back(gutils->getNewFromOriginal(binst->getOperand(i))); + + segSizes.push_back(newVals.size()); + for (size_t i = 0; i < newInst->getNumSuccessors(); i++) { + size_t cur = newVals.size(); + auto ops = binst.getSuccessorOperands(i).getForwardedOperands(); + for (auto &&[idx, op] : llvm::enumerate(ops)) { + auto arg = + *binst.getSuccessorBlockArgument(ops.getBeginOperandIndex() + idx); + newVals.push_back(gutils->getNewFromOriginal(op)); + if (!gutils->isConstantValue(arg)) { + if (!gutils->isConstantValue(op)) { + newVals.push_back(gutils->invertPointerM(op, builder)); + } else { + Type retTy = + arg.getType().cast().getShadowType(); + auto toret = retTy.cast().createNullValue( + builder, op.getLoc()); + newVals.push_back(toret); + } + } + } + cur = newVals.size() - cur; + segSizes.push_back(cur); + } + + SmallVector attrs(newInst->getAttrs()); + bool has_cases = false; + for (auto &attr : attrs) { + if (attr.getName() == "case_operand_segments") { + has_cases = true; + } + } + for (auto &attr : attrs) { + if (attr.getName() == "operandSegmentSizes") { + if (!has_cases) { + attr.setValue(builder.getDenseI32ArrayAttr(segSizes)); + } else { + SmallVector segSlices2(segSizes.begin(), segSizes.begin() + 2); + segSlices2.push_back(0); + for (size_t i = 2; i < segSizes.size(); i++) + segSlices2[2] += segSizes[i]; + attr.setValue(builder.getDenseI32ArrayAttr(segSlices2)); + } + } + if (attr.getName() == "case_operand_segments") { + SmallVector segSlices2(segSizes.begin() + 2, segSizes.end()); + attr.setValue(builder.getDenseI32ArrayAttr(segSlices2)); + } + } + + gutils->getNewFromOriginal(inst->getBlock()) + ->push_back( + newInst->create(newInst->getLoc(), newInst->getName(), TypeRange(), + newVals, attrs, OpaqueProperties(nullptr), + newInst->getSuccessors(), newInst->getNumRegions())); + gutils->erase(newInst); + return; +} + +static bool contains(ArrayRef ar, int v) { + for (auto a : ar) { + if (a == v) { + return true; + } + } + return false; +} + +LogicalResult mlir::enzyme::detail::memoryIdentityForwardHandler( + Operation *orig, OpBuilder &builder, MGradientUtils *gutils, + ArrayRef storedVals) { + auto iface = cast(orig); + + SmallVector newOperands; + newOperands.reserve(orig->getNumOperands()); + for (OpOperand &operand : orig->getOpOperands()) { + if (iface.isArgInactive(operand.getOperandNumber())) { + newOperands.push_back(gutils->getNewFromOriginal(operand.get())); + } else { + if (gutils->isConstantValue(operand.get())) { + + if (contains(storedVals, operand.getOperandNumber())) { + if (auto iface = + dyn_cast(operand.get().getType())) { + if (!iface.isMutable()) { + Type retTy = iface.getShadowType(); + auto toret = retTy.cast().createNullValue( + builder, operand.get().getLoc()); + newOperands.push_back(toret); + continue; + } + } + } + orig->emitError() + << "Unsupported constant arg to memory identity forward " + "handler(opidx=" + << operand.getOperandNumber() << ", op=" << operand.get() << ")\n"; + return failure(); + } + newOperands.push_back(gutils->invertPointerM(operand.get(), builder)); + } + } + + // Assuming shadows following the originals are fine. + // TODO: consider extending to have a ShadowableTerminatorOpInterface + Operation *primal = gutils->getNewFromOriginal(orig); + Operation *shadow = builder.clone(*primal); + shadow->setOperands(newOperands); + for (auto &&[oval, sval] : + llvm::zip(orig->getResults(), shadow->getResults())) { + gutils->setDiffe(oval, sval, builder); + } + + return success(); +} + +LogicalResult mlir::enzyme::detail::allocationForwardHandler( + Operation *orig, OpBuilder &builder, MGradientUtils *gutils, bool zero) { + + Operation *primal = gutils->getNewFromOriginal(orig); + Operation *shadow = builder.clone(*primal); + + Value shadowRes = shadow->getResult(0); + + gutils->setDiffe(orig->getResult(0), shadowRes, builder); + gutils->eraseIfUnused(orig); + + if (zero) { + // Fill with zeros + if (auto iface = dyn_cast(shadowRes.getType())) { + return iface.zeroInPlace(builder, orig->getLoc(), shadowRes); + } else { + orig->emitError() << "Type " << shadowRes.getType() + << " does not implement " + "AutoDiffTypeInterface"; + return failure(); + } + } + return success(); +} + +void mlir::enzyme::detail::returnReverseHandler(Operation *op, + OpBuilder &builder, + MGradientUtilsReverse *gutils) { + size_t num_out = 0; + for (auto act : gutils->RetDiffeTypes) { + if (act == DIFFE_TYPE::OUT_DIFF) + num_out++; + } + + size_t idx = 0; + auto args = gutils->newFunc->getRegions().begin()->begin()->getArguments(); + + for (auto &&[op, act] : llvm::zip(op->getOperands(), gutils->RetDiffeTypes)) { + if (act == DIFFE_TYPE::OUT_DIFF) { + if (!gutils->isConstantValue(op)) { + auto d_out = args[args.size() - num_out + idx]; + gutils->addToDiffe(op, d_out, builder); + } + idx++; + } + } +} + +void mlir::enzyme::detail::regionTerminatorForwardHandler( + Operation *origTerminator, OpBuilder &builder, MGradientUtils *gutils) { + auto parentOp = origTerminator->getParentOp(); + + llvm::SmallDenseSet operandsToShadow; + if (auto termIface = + dyn_cast(origTerminator)) { + SmallVector successors; + termIface.getSuccessorRegions( + SmallVector(origTerminator->getNumOperands(), Attribute()), + successors); + + for (auto &successor : successors) { + OperandRange operandRange = termIface.getSuccessorOperands(successor); + ValueRange targetValues = successor.isParent() + ? parentOp->getResults() + : successor.getSuccessorInputs(); + assert(operandRange.size() == targetValues.size()); + for (auto &&[i, target] : llvm::enumerate(targetValues)) { + if (!gutils->isConstantValue(target)) + operandsToShadow.insert(operandRange.getBeginOperandIndex() + i); + } + } + } else { + assert(parentOp->getNumResults() == origTerminator->getNumOperands()); + for (auto res : parentOp->getResults()) { + if (!gutils->isConstantValue(res)) + operandsToShadow.insert(res.getResultNumber()); + } + } + + SmallVector newOperands; + newOperands.reserve(origTerminator->getNumOperands() + + operandsToShadow.size()); + for (OpOperand &operand : origTerminator->getOpOperands()) { + newOperands.push_back(gutils->getNewFromOriginal(operand.get())); + if (operandsToShadow.contains(operand.getOperandNumber())) + newOperands.push_back(gutils->invertPointerM(operand.get(), builder)); + } + + // Assuming shadows following the originals are fine. + // TODO: consider extending to have a ShadowableTerminatorOpInterface + Operation *replTerminator = gutils->getNewFromOriginal(origTerminator); + replTerminator->setOperands(newOperands); +} + +LogicalResult mlir::enzyme::detail::controlFlowForwardHandler( + Operation *op, OpBuilder &builder, MGradientUtils *gutils) { + + // For all operands that are forwarded to the body, if they are active, also + // add the shadow as operand. + auto regionBranchOp = dyn_cast(op); + if (!regionBranchOp) { + op->emitError() << " RegionBranchOpInterface not implemented for " << *op + << "\n"; + return failure(); + } + + // TODO: we may need to record, for every successor, which of its inputs + // need a shadow to recreate the body correctly. + llvm::SmallDenseSet operandPositionsToShadow; + llvm::SmallDenseSet resultPositionsToShadow; + + SmallVector entrySuccessors; + regionBranchOp.getEntrySuccessorRegions( + SmallVector(op->getNumOperands(), Attribute()), + entrySuccessors); + + for (const RegionSuccessor &successor : entrySuccessors) { + + OperandRange operandRange = + regionBranchOp.getEntrySuccessorOperands(successor); + + ValueRange targetValues = successor.isParent() + ? op->getResults() + : successor.getSuccessorInputs(); + + // Need to know which of the arguments are being forwarded to from + // operands. + for (auto &&[i, regionValue, operand] : + llvm::enumerate(targetValues, operandRange)) { + if (gutils->isConstantValue(regionValue)) + continue; + operandPositionsToShadow.insert(operandRange.getBeginOperandIndex() + i); + if (successor.isParent()) + resultPositionsToShadow.insert(i); + } + } + + for (auto res : op->getResults()) + if (!gutils->isConstantValue(res)) + resultPositionsToShadow.insert(res.getResultNumber()); + + return controlFlowForwardHandler( + op, builder, gutils, operandPositionsToShadow, resultPositionsToShadow); +} + +LogicalResult mlir::enzyme::detail::controlFlowForwardHandler( + Operation *op, OpBuilder &builder, MGradientUtils *gutils, + const llvm::SmallDenseSet &operandPositionsToShadow, + const llvm::SmallDenseSet &resultPositionsToShadow) { + // For all active results, add shadow types. + // For now, assuming all results are relevant. + Operation *newOp = gutils->getNewFromOriginal(op); + SmallVector newOpResultTypes; + newOpResultTypes.reserve(op->getNumResults() * 2); + for (auto result : op->getResults()) { + // TODO only if used (can we DCE the primal after having done the + // derivative). + newOpResultTypes.push_back(result.getType()); + if (!gutils->isConstantValue(result)) { + assert(resultPositionsToShadow.count(result.getResultNumber())); + } + if (!resultPositionsToShadow.count(result.getResultNumber())) + continue; + auto typeIface = dyn_cast(result.getType()); + if (!typeIface) { + op->emitError() << " AutoDiffTypeInterface not implemented for " + << result.getType() << "\n"; + return failure(); + } + newOpResultTypes.push_back(typeIface.getShadowType()); + } + + SmallVector newOperands; + newOperands.reserve(op->getNumOperands() + operandPositionsToShadow.size()); + for (OpOperand &operand : op->getOpOperands()) { + newOperands.push_back(gutils->getNewFromOriginal(operand.get())); + if (operandPositionsToShadow.contains(operand.getOperandNumber())) + newOperands.push_back(gutils->invertPointerM(operand.get(), builder)); + } + // We are assuming the op can forward additional operands, listed + // immediately after the original operands, to the same regions. + // ^^ + // Our interface guarantees this. + // We also assume that the region-holding op returns all of the values + // yielded by terminators, and only those values. + + auto iface = dyn_cast(op); + if (!iface) { + op->emitError() << " ControlFlowAutoDiffOpInterface not implemented for " + << *op << "\n"; + return failure(); + } + Operation *replacement = iface.createWithShadows( + builder, gutils, op, newOperands, newOpResultTypes); + assert(replacement->getNumResults() == newOpResultTypes.size()); + for (auto &&[region, replacementRegion] : + llvm::zip(newOp->getRegions(), replacement->getRegions())) { + replacementRegion.takeBody(region); + } + + // Inject the mapping for the new results into GradientUtil's shadow + // table. + SmallVector reps; + size_t idx = 0; + for (Value r : op->getResults()) { + // TODO only if used + reps.push_back(replacement->getResult(idx)); + idx++; + if (!gutils->isConstantValue(r)) { + auto inverted = gutils->invertedPointers.lookupOrNull(r); + assert(inverted); + gutils->invertedPointers.map(r, replacement->getResult(idx)); + inverted.replaceAllUsesWith(replacement->getResult(idx)); + gutils->erase(inverted.getDefiningOp()); + idx++; + } + } + + // Differentiate body. + for (auto &origRegion : op->getRegions()) { + for (auto &origBlock : origRegion) { + for (Operation &o : origBlock) { + if (failed(gutils->visitChild(&o))) { + return failure(); + } + } + } + } + + // Replace all uses of original results + gutils->replaceOrigOpWith(op, reps); + gutils->erase(newOp); + + return success(); +} + +void mlir::enzyme::registerCoreDialectAutodiffInterfaces( + DialectRegistry ®istry) { + enzyme::registerAffineDialectAutoDiffInterface(registry); + enzyme::registerArithDialectAutoDiffInterface(registry); + enzyme::registerBuiltinDialectAutoDiffInterface(registry); + enzyme::registerLLVMDialectAutoDiffInterface(registry); + enzyme::registerNVVMDialectAutoDiffInterface(registry); + enzyme::registerMathDialectAutoDiffInterface(registry); + enzyme::registerMemRefDialectAutoDiffInterface(registry); + enzyme::registerSCFDialectAutoDiffInterface(registry); + enzyme::registerCFDialectAutoDiffInterface(registry); + enzyme::registerLinalgDialectAutoDiffInterface(registry); + enzyme::registerFuncDialectAutoDiffInterface(registry); +} diff --git a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h index 669b028998c6..cbad734656b1 100644 --- a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h +++ b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h @@ -12,15 +12,253 @@ // //===----------------------------------------------------------------------===// +#ifndef ENZYMEMLIR_CORE_IMPL_H_ +#define ENZYMEMLIR_CORE_IMPL_H_ + +#include "Interfaces/AutoDiffOpInterface.h" +#include "mlir/Support/LogicalResult.h" + +#include "llvm/ADT/DenseSet.h" + namespace mlir { class DialectRegistry; +class Operation; +class OpBuilder; +class RegionSuccessor; namespace enzyme { +class MGradientUtils; +class MGradientUtilsReverse; + +namespace detail { +// Non-template implementation of +// AutoDiffUsingControlFlow::createForwardModeTangent. + +LogicalResult controlFlowForwardHandler(Operation *op, OpBuilder &builder, + MGradientUtils *gutils); + +LogicalResult controlFlowForwardHandler( + Operation *op, OpBuilder &builder, MGradientUtils *gutils, + const llvm::SmallDenseSet &operandPositionsToShadow, + const llvm::SmallDenseSet &resultPositionsToShadow); + +// Implements forward-mode differentiation of branching operations. +// Assumes that successive shadows are legal +void branchingForwardHandler(Operation *op, OpBuilder &builder, + MGradientUtils *gutils); + +// Implements forward-mode differentiation of region-terminator operations. +// Assumes that successive shadows are legal +void regionTerminatorForwardHandler(Operation *op, OpBuilder &builder, + MGradientUtils *gutils); + +// Implements reverse-mode differentiation of return operations. +void returnReverseHandler(Operation *op, OpBuilder &builder, + MGradientUtilsReverse *gutils); + +// Implements forward-mode differentiation of read-only (including read-none) +// operations which do not perform computation +LogicalResult memoryIdentityForwardHandler(Operation *op, OpBuilder &builder, + MGradientUtils *gutils, + ArrayRef storedVals); + +// Implements shadow initialization differentiation of allocation +LogicalResult allocationForwardHandler(Operation *op, OpBuilder &builder, + MGradientUtils *gutils, bool zero); + +// Implements the forward autodiff interface for operations whose derivatives +// are can be inferred by analyzing their control flow and differentiating the +// nested operations. +template +class AutoDiffUsingControlFlow + : public AutoDiffOpInterface::ExternalModel, + OpTy> { +public: + LogicalResult createForwardModeTangent(Operation *op, OpBuilder &builder, + MGradientUtils *gutils) const { + return controlFlowForwardHandler(op, builder, gutils); + } +}; + +// Implements the forward autodiff interface for operations whose derivatives +// are can be inferred by analyzing their branching properties. +template +class AutoDiffUsingBranch + : public AutoDiffOpInterface::ExternalModel, + OpTy> { +public: + LogicalResult createForwardModeTangent(Operation *op, OpBuilder &builder, + MGradientUtils *gutils) const { + branchingForwardHandler(op, builder, gutils); + return success(); + } +}; + +// Implements the forward autodiff interface for operations whose derivatives +// are can be inferred by analyzing their region terminator properties. +template +class AutoDiffUsingRegionTerminator + : public AutoDiffOpInterface::ExternalModel< + AutoDiffUsingRegionTerminator, OpTy> { +public: + LogicalResult createForwardModeTangent(Operation *op, OpBuilder &builder, + MGradientUtils *gutils) const { + regionTerminatorForwardHandler(op, builder, gutils); + return success(); + } +}; + +template +class NoopRevAutoDiffInterface + : public ReverseAutoDiffOpInterface::ExternalModel< + NoopRevAutoDiffInterface, OpTy> { +public: + void createReverseModeAdjoint(Operation *op, OpBuilder &builder, + MGradientUtilsReverse *gutils, + SmallVector caches) const {} + + SmallVector cacheValues(Operation *op, + MGradientUtilsReverse *gutils) const { + return SmallVector(); + } + + void createShadowValues(Operation *op, OpBuilder &builder, + MGradientUtilsReverse *gutils) const {} +}; + +template +class ReturnRevAutoDiffInterface + : public ReverseAutoDiffOpInterface::ExternalModel< + ReturnRevAutoDiffInterface, OpTy> { +public: + void createReverseModeAdjoint(Operation *op, OpBuilder &builder, + MGradientUtilsReverse *gutils, + SmallVector caches) const { + returnReverseHandler(op, builder, gutils); + } + + SmallVector cacheValues(Operation *op, + MGradientUtilsReverse *gutils) const { + return SmallVector(); + } + + void createShadowValues(Operation *op, OpBuilder &builder, + MGradientUtilsReverse *gutils) const {} +}; + +// Implements the forward autodiff interface for operations which are +// read only and identity like (aka not computing sin of mem read). +template +class AutoDiffUsingMemoryIdentity + : public AutoDiffOpInterface::ExternalModel< + AutoDiffUsingMemoryIdentity, OpTy> { +public: + LogicalResult createForwardModeTangent(Operation *op, OpBuilder &builder, + MGradientUtils *gutils) const { + + return memoryIdentityForwardHandler( + op, builder, gutils, std::initializer_list{storedvals...}); + } +}; + +// Implements the forward autodiff interface for operations which are +// allocation like +template +class AutoDiffUsingAllocationFwd : public AutoDiffOpInterface::ExternalModel< + AutoDiffUsingAllocationFwd, OpTy> { +public: + LogicalResult createForwardModeTangent(Operation *op, OpBuilder &builder, + MGradientUtils *gutils) const { + + return allocationForwardHandler(op, builder, gutils, /*zero*/ false); + } +}; + +// Implements the reverse autodiff interface for operations which are +// allocation like +template +class AutoDiffUsingAllocationRev + : public ReverseAutoDiffOpInterface::ExternalModel< + AutoDiffUsingAllocationRev, OpTy> { +public: + void createReverseModeAdjoint(Operation *op, OpBuilder &builder, + MGradientUtilsReverse *gutils, + SmallVector caches) const {} + + SmallVector cacheValues(Operation *op, + MGradientUtilsReverse *gutils) const { + return SmallVector(); + } + + void createShadowValues(Operation *op, OpBuilder &builder, + MGradientUtilsReverse *gutils) const { + (void)allocationForwardHandler(op, builder, (MGradientUtils *)gutils, + /*zero*/ true); + } +}; +} // namespace detail + +// Registers AutoDiffUsingControlFlow for the given op. +template +void registerAutoDiffUsingControlFlowInterface(MLIRContext &context) { + OpTy::template attachInterface>( + context); +} +// Registers AutoDiffUsingBranch for the given op. +template +void registerAutoDiffUsingBranchInterface(MLIRContext &context) { + OpTy::template attachInterface>(context); + OpTy::template attachInterface>( + context); +} +// Registers AutoDiffUsingRegionTerminator for the given op. +template +void registerAutoDiffUsingRegionTerminatorInterface(MLIRContext &context) { + OpTy::template attachInterface>( + context); + OpTy::template attachInterface>( + context); +} +// Registers registerAutoDiffUsingReturnInterface for the given op. +template +void registerAutoDiffUsingReturnInterface(MLIRContext &context) { + OpTy::template attachInterface>( + context); + OpTy::template attachInterface>( + context); +} +// Registers AutoDiffUsingMemoryIdentity for the given op. +template +void registerAutoDiffUsingMemoryIdentityInterface(MLIRContext &context) { + OpTy::template attachInterface< + detail::AutoDiffUsingMemoryIdentity>(context); +} +// Registers AutoDiffUsingAllocation for the given op. +template +void registerAutoDiffUsingAllocationInterface(MLIRContext &context) { + OpTy::template attachInterface>( + context); + OpTy::template attachInterface>( + context); +} + +// Interface registration hooks for individual upstream dialects. +void registerAffineDialectAutoDiffInterface(DialectRegistry ®istry); void registerArithDialectAutoDiffInterface(DialectRegistry ®istry); void registerBuiltinDialectAutoDiffInterface(DialectRegistry ®istry); void registerLLVMDialectAutoDiffInterface(DialectRegistry ®istry); +void registerNVVMDialectAutoDiffInterface(DialectRegistry ®istry); void registerMemRefDialectAutoDiffInterface(DialectRegistry ®istry); void registerSCFDialectAutoDiffInterface(DialectRegistry ®istry); +void registerCFDialectAutoDiffInterface(DialectRegistry ®istry); void registerLinalgDialectAutoDiffInterface(DialectRegistry ®istry); +void registerMathDialectAutoDiffInterface(DialectRegistry ®istry); +void registerFuncDialectAutoDiffInterface(DialectRegistry ®istry); + +void registerCoreDialectAutodiffInterfaces(DialectRegistry ®istry); + +mlir::TypedAttr getConstantAttr(mlir::Type type, llvm::StringRef value); } // namespace enzyme } // namespace mlir + +#endif diff --git a/enzyme/Enzyme/MLIR/Implementations/FuncAutoDiffOpInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/FuncAutoDiffOpInterfaceImpl.cpp new file mode 100644 index 000000000000..dddce795adfd --- /dev/null +++ b/enzyme/Enzyme/MLIR/Implementations/FuncAutoDiffOpInterfaceImpl.cpp @@ -0,0 +1,37 @@ +//===- FuncAutoDiffOpInterfaceImpl.cpp - Interface external model --------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains the external model implementation of the automatic +// differentiation op interfaces for the upstream MLIR arithmetic dialect. +// +//===----------------------------------------------------------------------===// + +#include "Implementations/CoreDialectsAutoDiffImplementations.h" +#include "Interfaces/AutoDiffOpInterface.h" +#include "Interfaces/GradientUtils.h" +#include "Interfaces/GradientUtilsReverse.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/DialectRegistry.h" +#include "mlir/Support/LogicalResult.h" + +#include "Dialect/Ops.h" +#include "mlir/IR/TypeSupport.h" + +using namespace mlir; +using namespace mlir::enzyme; + +namespace { +#include "Implementations/FuncDerivatives.inc" +} // namespace + +void mlir::enzyme::registerFuncDialectAutoDiffInterface( + DialectRegistry ®istry) { + registry.addExtension(+[](MLIRContext *context, func::FuncDialect *) { + registerInterfaces(context); + }); +} diff --git a/enzyme/Enzyme/MLIR/Implementations/FuncDerivatives.td b/enzyme/Enzyme/MLIR/Implementations/FuncDerivatives.td new file mode 100644 index 000000000000..005246887fdf --- /dev/null +++ b/enzyme/Enzyme/MLIR/Implementations/FuncDerivatives.td @@ -0,0 +1,3 @@ +include "Common.td" + +def : ReturnOp<"func", "ReturnOp">; \ No newline at end of file diff --git a/enzyme/Enzyme/MLIR/Implementations/LLVMAutoDiffOpInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/LLVMAutoDiffOpInterfaceImpl.cpp index 3c60fd421c7a..264278e97ef9 100644 --- a/enzyme/Enzyme/MLIR/Implementations/LLVMAutoDiffOpInterfaceImpl.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/LLVMAutoDiffOpInterfaceImpl.cpp @@ -23,53 +23,17 @@ using namespace mlir; using namespace mlir::enzyme; namespace { -struct LoadOpInterface - : public AutoDiffOpInterface::ExternalModel { - LogicalResult createForwardModeTangent(Operation *op, OpBuilder &builder, - MGradientUtils *gutils) const { - auto loadOp = cast(op); - if (!gutils->isConstantValue(loadOp)) { - Type shadowType = - cast(loadOp.getType()).getShadowType(); - mlir::Value res = builder.create( - loadOp.getLoc(), shadowType, - gutils->invertPointerM(loadOp.getAddr(), builder)); - gutils->setDiffe(loadOp, res, builder); - } - gutils->eraseIfUnused(op); - return success(); - } -}; +#include "Implementations/LLVMDerivatives.inc" -struct StoreOpInterface - : public AutoDiffOpInterface::ExternalModel { - LogicalResult createForwardModeTangent(Operation *op, OpBuilder &builder, - MGradientUtils *gutils) const { - auto storeOp = cast(op); - if (!gutils->isConstantValue(storeOp.getAddr())) { - builder.create( - storeOp.getLoc(), gutils->invertPointerM(storeOp.getValue(), builder), - gutils->invertPointerM(storeOp.getAddr(), builder)); - } - gutils->eraseIfUnused(op); - return success(); - } -}; - -struct AllocaOpInterface - : public AutoDiffOpInterface::ExternalModel { - LogicalResult createForwardModeTangent(Operation *op, OpBuilder &builder, - MGradientUtils *gutils) const { - auto allocOp = cast(op); - if (!gutils->isConstantValue(allocOp)) { - Operation *nop = gutils->cloneWithNewOperands(builder, op); - gutils->setDiffe(allocOp, nop->getResult(0), builder); - } - gutils->eraseIfUnused(op); - return success(); +struct InlineAsmActivityInterface + : public ActivityOpInterface::ExternalModel { + bool isInactive(Operation *op) const { + auto asmOp = cast(op); + auto str = asmOp.getAsmString(); + return str.contains("cpuid") || str.contains("exit"); } + bool isArgInactive(Operation *op, size_t) const { return isInactive(op); } }; class PointerTypeInterface @@ -91,16 +55,20 @@ class PointerTypeInterface return self; } - bool requiresShadow(Type self) const { return true; } + bool isMutable(Type self) const { return true; } + + LogicalResult zeroInPlace(Type self, OpBuilder &builder, Location loc, + Value val) const { + // TODO inspect val and memset corresponding size + return failure(); + } }; } // namespace void mlir::enzyme::registerLLVMDialectAutoDiffInterface( DialectRegistry ®istry) { registry.addExtension(+[](MLIRContext *context, LLVM::LLVMDialect *) { - LLVM::LoadOp::attachInterface(*context); - LLVM::StoreOp::attachInterface(*context); - LLVM::AllocaOp::attachInterface(*context); LLVM::LLVMPointerType::attachInterface(*context); + registerInterfaces(context); }); } diff --git a/enzyme/Enzyme/MLIR/Implementations/LLVMDerivatives.td b/enzyme/Enzyme/MLIR/Implementations/LLVMDerivatives.td new file mode 100644 index 000000000000..e77e88aea47f --- /dev/null +++ b/enzyme/Enzyme/MLIR/Implementations/LLVMDerivatives.td @@ -0,0 +1,28 @@ +include "Common.td" + +def : MemoryIdentityOp<"LLVM", "StoreOp", [1], [0]>; +def : InactiveOp<"LLVM", "SIToFPOp">; +def : InactiveOp<"LLVM", "UIToFPOp">; +def : InactiveOp<"LLVM", "FPToSIOp">; +def : InactiveOp<"LLVM", "FPToUIOp">; +def : InactiveOp<"LLVM", "AssumeOp">; +def : InactiveOp<"LLVM", "StackSaveOp">; +def : InactiveOp<"LLVM", "StackRestoreOp">; +def : InactiveOp<"LLVM", "LifetimeStartOp">; +def : InactiveOp<"LLVM", "LifetimeEndOp">; +def : InactiveOp<"LLVM", "Prefetch">; +def : InactiveOp<"LLVM", "MemsetOp">; + +def : InactiveOp<"LLVM", "UndefOp">; +def : InactiveOp<"LLVM", "ConstantOp">; +def : InactiveOp<"LLVM", "UnreachableOp">; + + +def : ReadOnlyIdentityOp<"LLVM", "LoadOp", [0]>; +def : ReadOnlyIdentityOp<"LLVM", "AddrSpaceCastOp", [0]>; +def : ReadOnlyIdentityOp<"LLVM", "BitcastOp", [0]>; +def : ReadOnlyIdentityOp<"LLVM", "GEPOp", [0]>; +def : ReadOnlyIdentityOp<"LLVM", "PtrToIntOp", [0]>; +def : ReadOnlyIdentityOp<"LLVM", "IntToPtrOp", [0]>; + +def : AllocationOp<"LLVM", "AllocaOp">; diff --git a/enzyme/Enzyme/MLIR/Implementations/LinalgAutoDiffOpInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/LinalgAutoDiffOpInterfaceImpl.cpp index 93488a07fd6e..db9f4b08d6e1 100644 --- a/enzyme/Enzyme/MLIR/Implementations/LinalgAutoDiffOpInterfaceImpl.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/LinalgAutoDiffOpInterfaceImpl.cpp @@ -28,7 +28,6 @@ #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" -#include "mlir/Dialect/Shape/IR/ShapeOpsTypes.h.inc" #include "mlir/IR/PatternMatch.h" #include "mlir/Transforms/DialectConversion.h" @@ -80,9 +79,7 @@ struct GenericOpInterfaceReverse cast(gutils->getNewFromOriginal(linalgOp)); // Replace the op by a linalg.generic op if necessary - // TODO : IRRewriter rewriter(builder.getContext()/*, - // builder.getListener()*/); - ConversionPatternRewriter rewriter(builder.getContext()); + IRRewriter rewriter(builder.getContext(), builder.getListener()); auto failiureOrLinalgOp = generalizeNamedOp(rewriter, newOp); if (!failed(failiureOrLinalgOp)) { linalg::GenericOp replacement = failiureOrLinalgOp.value(); @@ -136,7 +133,7 @@ struct GenericOpInterfaceReverse linalgOp.getNumLoops(), utils::IteratorType::parallel}; for (OpOperand &output : linalgOp.getDpsInitsMutable()) { - if (!gutils->hasInvertPointer(output.get())) { + if (gutils->isConstantValue(output.get())) { continue; } indexingMaps.push_back(linalgOp.getMatchingIndexingMap(&output)); @@ -146,7 +143,7 @@ struct GenericOpInterfaceReverse } for (OpOperand *input : linalgOp.getDpsInputOperands()) { - if (!gutils->hasInvertPointer(input->get())) { + if (gutils->isConstantValue(input->get())) { continue; } indexingMaps.push_back(linalgOp.getMatchingIndexingMap(input)); @@ -168,8 +165,13 @@ struct GenericOpInterfaceReverse StringAttr()); int numInputs = inputs.size(); - auto buildFuncReturnOp = [numInputs](OpBuilder &builder, Location loc, - SmallVector retargs) { + auto buildFuncReturnOp = [&gutils, numInputs](OpBuilder &builder, + Block *oBB) { + auto loc = oBB->rbegin()->getLoc(); + SmallVector retargs; + for (auto arg : oBB->getArguments()) { + retargs.push_back(gutils->invertPointerM(arg, builder)); + } builder.create( loc, ValueRange{retargs}.take_front(numInputs)); return; @@ -195,9 +197,8 @@ struct GenericOpInterfaceReverse return std::make_pair(pushCache, popCache); }; - gutils->Logic.differentiate( - gutils, *linalgOp.getBlock()->getParent(), adjoint.getRegion(), - /*parentRegion=*/false, buildFuncReturnOp, hook); + gutils->Logic.differentiate(gutils, *linalgOp.getBlock()->getParent(), + adjoint.getRegion(), buildFuncReturnOp, hook); auto newOpYield = cast( cast(newOp).getBodyRegion().front().getTerminator()); diff --git a/enzyme/Enzyme/MLIR/Implementations/MathAutoDiffOpInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/MathAutoDiffOpInterfaceImpl.cpp new file mode 100644 index 000000000000..2833eeb44726 --- /dev/null +++ b/enzyme/Enzyme/MLIR/Implementations/MathAutoDiffOpInterfaceImpl.cpp @@ -0,0 +1,38 @@ +//===- ArithAutoDiffOpInterfaceImpl.cpp - Interface external model --------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains the external model implementation of the automatic +// differentiation op interfaces for the upstream MLIR arithmetic dialect. +// +//===----------------------------------------------------------------------===// + +#include "Implementations/CoreDialectsAutoDiffImplementations.h" +#include "Interfaces/AutoDiffOpInterface.h" +#include "Interfaces/GradientUtils.h" +#include "Interfaces/GradientUtilsReverse.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/IR/DialectRegistry.h" +#include "mlir/Support/LogicalResult.h" + +#include "Dialect/Ops.h" +#include "mlir/IR/TypeSupport.h" + +using namespace mlir; +using namespace mlir::enzyme; + +namespace { +#include "Implementations/MathDerivatives.inc" +} // namespace + +void mlir::enzyme::registerMathDialectAutoDiffInterface( + DialectRegistry ®istry) { + registry.addExtension(+[](MLIRContext *context, math::MathDialect *) { + registerInterfaces(context); + }); +} diff --git a/enzyme/Enzyme/MLIR/Implementations/MathDerivatives.td b/enzyme/Enzyme/MLIR/Implementations/MathDerivatives.td new file mode 100644 index 000000000000..71db1a8574ac --- /dev/null +++ b/enzyme/Enzyme/MLIR/Implementations/MathDerivatives.td @@ -0,0 +1,17 @@ +include "Common.td" + +def : MLIRDerivative<"math", "CosOp", (Op $x), + [ + (CheckedMulF (DiffeRet), (NegF (SinF $x))) + ] + >; +def : MLIRDerivative<"math", "ExpOp", (Op $x), + [ + (CheckedMulF (DiffeRet), (ExpF $x)) + ] + >; +def : MLIRDerivative<"math", "SinOp", (Op $x), + [ + (CheckedMulF (DiffeRet), (CosF $x)) + ] + >; diff --git a/enzyme/Enzyme/MLIR/Implementations/MemRefAutoDiffOpInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/MemRefAutoDiffOpInterfaceImpl.cpp index 010f9d997005..cd21c5c548b9 100644 --- a/enzyme/Enzyme/MLIR/Implementations/MemRefAutoDiffOpInterfaceImpl.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/MemRefAutoDiffOpInterfaceImpl.cpp @@ -17,55 +17,19 @@ #include "Interfaces/GradientUtils.h" #include "Interfaces/GradientUtilsReverse.h" -// TODO: We need a way to zero out a memref (which linalg.fill does), but -// ideally we wouldn't depend on the linalg dialect. -#include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/IR/DialectRegistry.h" #include "mlir/Support/LogicalResult.h" +// TODO: We need a way to zero out a memref (which linalg.fill does), but +// ideally we wouldn't depend on the linalg dialect. +#include "mlir/Dialect/Linalg/IR/Linalg.h" + using namespace mlir; using namespace mlir::enzyme; namespace { -struct LoadOpInterface - : public AutoDiffOpInterface::ExternalModel { - LogicalResult createForwardModeTangent(Operation *op, OpBuilder &builder, - MGradientUtils *gutils) const { - auto loadOp = cast(op); - if (!gutils->isConstantValue(loadOp)) { - SmallVector inds; - for (auto ind : loadOp.getIndices()) - inds.push_back(gutils->getNewFromOriginal(ind)); - mlir::Value res = builder.create( - loadOp.getLoc(), gutils->invertPointerM(loadOp.getMemref(), builder), - inds); - gutils->setDiffe(loadOp, res, builder); - } - gutils->eraseIfUnused(op); - return success(); - } -}; - -struct StoreOpInterface - : public AutoDiffOpInterface::ExternalModel { - LogicalResult createForwardModeTangent(Operation *op, OpBuilder &builder, - MGradientUtils *gutils) const { - auto storeOp = cast(op); - if (!gutils->isConstantValue(storeOp.getMemref())) { - SmallVector inds; - for (auto ind : storeOp.getIndices()) - inds.push_back(gutils->getNewFromOriginal(ind)); - builder.create( - storeOp.getLoc(), gutils->invertPointerM(storeOp.getValue(), builder), - gutils->invertPointerM(storeOp.getMemref(), builder), inds); - } - gutils->eraseIfUnused(op); - return success(); - } -}; +#include "Implementations/MemRefDerivatives.inc" struct LoadOpInterfaceReverse : public ReverseAutoDiffOpInterface::ExternalModel(loadOp.getType())) { - if (gutils->hasInvertPointer(loadOp) && - gutils->hasInvertPointer(memref)) { + if (!gutils->isConstantValue(loadOp) && + !gutils->isConstantValue(memref)) { Value gradient = gutils->invertPointerM(loadOp, builder); Value memrefGradient = gutils->invertPointerM(memref, builder); @@ -106,8 +70,8 @@ struct LoadOpInterfaceReverse Value memref = loadOp.getMemref(); ValueRange indices = loadOp.getIndices(); if (auto iface = dyn_cast(loadOp.getType())) { - if (gutils->hasInvertPointer(loadOp) && - gutils->hasInvertPointer(memref)) { + if (!gutils->isConstantValue(loadOp) && + !gutils->isConstantValue(memref)) { OpBuilder cacheBuilder(gutils->getNewFromOriginal(op)); SmallVector caches; for (Value v : indices) { @@ -140,28 +104,33 @@ struct StoreOpInterfaceReverse Value memref = storeOp.getMemref(); // ValueRange indices = storeOp.getIndices(); - if (auto iface = dyn_cast(val.getType())) { - if (gutils->hasInvertPointer(memref)) { - OpBuilder cacheBuilder(gutils->getNewFromOriginal(op)); + auto iface = cast(val.getType()); - Value memrefGradient = gutils->invertPointerM(memref, builder); + if (!gutils->isConstantValue(memref)) { + OpBuilder cacheBuilder(gutils->getNewFromOriginal(op)); - SmallVector retrievedArguments; - for (Value cache : caches) { - Value retrievedValue = gutils->popCache(cache, builder); - retrievedArguments.push_back(retrievedValue); - } + Value memrefGradient = gutils->invertPointerM(memref, builder); - Value loadedGradient = - builder.create(storeOp.getLoc(), memrefGradient, - ArrayRef(retrievedArguments)); - Value addedGradient = loadedGradient; - if (gutils->hasInvertPointer(val)) { - Value gradient = gutils->invertPointerM(val, builder); - addedGradient = iface.createAddOp(builder, storeOp.getLoc(), gradient, - loadedGradient); + SmallVector retrievedArguments; + for (Value cache : caches) { + Value retrievedValue = gutils->popCache(cache, builder); + retrievedArguments.push_back(retrievedValue); + } + + if (!iface.isMutable()) { + if (!gutils->isConstantValue(val)) { + Value loadedGradient = builder.create( + storeOp.getLoc(), memrefGradient, + ArrayRef(retrievedArguments)); + gutils->addToDiffe(val, loadedGradient, builder); } - gutils->mapInvertPointer(val, addedGradient, builder); + + auto zero = + cast(gutils->getShadowType(val.getType())) + .createNullValue(builder, op->getLoc()); + + builder.create(storeOp.getLoc(), zero, memrefGradient, + ArrayRef(retrievedArguments)); } } } @@ -173,7 +142,7 @@ struct StoreOpInterfaceReverse ValueRange indices = storeOp.getIndices(); Value val = storeOp.getValue(); if (auto iface = dyn_cast(val.getType())) { - if (gutils->hasInvertPointer(memref)) { + if (!gutils->isConstantValue(memref)) { OpBuilder cacheBuilder(gutils->getNewFromOriginal(op)); SmallVector caches; for (Value v : indices) { @@ -195,38 +164,6 @@ struct StoreOpInterfaceReverse } }; -struct AllocOpInterfaceReverse - : public ReverseAutoDiffOpInterface::ExternalModel { - void createReverseModeAdjoint(Operation *op, OpBuilder &builder, - MGradientUtilsReverse *gutils, - SmallVector caches) const {} - - SmallVector cacheValues(Operation *op, - MGradientUtilsReverse *gutils) const { - return SmallVector(); - } - - void createShadowValues(Operation *op, OpBuilder &builder, - MGradientUtilsReverse *gutils) const { - auto allocOp = cast(op); - auto newAllocOp = cast(gutils->getNewFromOriginal(op)); - - Value shadow = builder.create( - op->getLoc(), newAllocOp.getType(), newAllocOp.getDynamicSizes()); - // Fill with zeros - if (auto iface = dyn_cast( - allocOp.getType().getElementType())) { - Value zero = iface.createNullValue(builder, op->getLoc()); - builder.create(op->getLoc(), zero, shadow); - } else { - op->emitWarning() << "memref.alloc element type does not implement " - "AutoDiffTypeInterface"; - } - gutils->mapShadowValue(allocOp, shadow, builder); - } -}; - struct SubViewOpInterfaceReverse : public ReverseAutoDiffOpInterface::ExternalModel< SubViewOpInterfaceReverse, memref::SubViewOp> { @@ -243,32 +180,17 @@ struct SubViewOpInterfaceReverse MGradientUtilsReverse *gutils) const { auto subviewOp = cast(op); auto newSubviewOp = cast(gutils->getNewFromOriginal(op)); - if (gutils->hasInvertPointer(subviewOp.getSource())) { + if (!gutils->isConstantValue(subviewOp.getSource())) { Value shadow = builder.create( op->getLoc(), newSubviewOp.getType(), gutils->invertPointerM(subviewOp.getSource(), builder), newSubviewOp.getMixedOffsets(), newSubviewOp.getMixedSizes(), newSubviewOp.getMixedStrides()); - gutils->mapShadowValue(subviewOp, shadow, builder); + gutils->setDiffe(subviewOp, shadow, builder); } } }; -struct AllocOpInterface - : public AutoDiffOpInterface::ExternalModel { - LogicalResult createForwardModeTangent(Operation *op, OpBuilder &builder, - MGradientUtils *gutils) const { - auto allocOp = cast(op); - if (!gutils->isConstantValue(allocOp)) { - Operation *nop = gutils->cloneWithNewOperands(builder, op); - gutils->setDiffe(allocOp, nop->getResult(0), builder); - } - gutils->eraseIfUnused(op); - return success(); - } -}; - class MemRefTypeInterface : public AutoDiffTypeInterface::ExternalModel { @@ -288,21 +210,32 @@ class MemRefTypeInterface return self; } - bool requiresShadow(Type self) const { return true; } + bool isMutable(Type self) const { return true; } + + LogicalResult zeroInPlace(Type self, OpBuilder &builder, Location loc, + Value val) const { + auto MT = cast(self); + if (auto iface = dyn_cast(MT.getElementType())) { + if (!iface.isMutable()) { + Value zero = iface.createNullValue(builder, loc); + builder.create(loc, zero, val); + } + } else { + return failure(); + } + return success(); + } }; } // namespace void mlir::enzyme::registerMemRefDialectAutoDiffInterface( DialectRegistry ®istry) { registry.addExtension(+[](MLIRContext *context, memref::MemRefDialect *) { - memref::LoadOp::attachInterface(*context); - memref::StoreOp::attachInterface(*context); - memref::AllocOp::attachInterface(*context); + registerInterfaces(context); MemRefType::attachInterface(*context); memref::LoadOp::attachInterface(*context); memref::StoreOp::attachInterface(*context); - memref::AllocOp::attachInterface(*context); memref::SubViewOp::attachInterface(*context); }); } diff --git a/enzyme/Enzyme/MLIR/Implementations/MemRefDerivatives.td b/enzyme/Enzyme/MLIR/Implementations/MemRefDerivatives.td new file mode 100644 index 000000000000..173a22a6e2f1 --- /dev/null +++ b/enzyme/Enzyme/MLIR/Implementations/MemRefDerivatives.td @@ -0,0 +1,16 @@ +include "Common.td" + +def : MemoryIdentityOp<"memref", "StoreOp", [1], [0]>; +def : ReadOnlyIdentityOp<"memref", "LoadOp", [0]>; +def : ReadOnlyIdentityOp<"memref", "CastOp", [0]>; +def : ReadOnlyIdentityOp<"memref", "CollapseShapeOp", [0]>; +def : ReadOnlyIdentityOp<"memref", "ExpandShapeOp", [0]>; +def : ReadOnlyIdentityOp<"memref", "ReinterpretCastOp", [0]>; +def : ReadOnlyIdentityOp<"memref", "ReshapeOp", [0]>; +def : ReadOnlyIdentityOp<"memref", "TransposeOp", [0]>; +def : ReadOnlyIdentityOp<"memref", "ViewOp", [0]>; +def : ReadOnlyIdentityOp<"memref", "SubViewOp", [0]>; + +def : InactiveOp<"memref", "DimOp">; +def : AllocationOp<"memref", "AllocOp">; +def : AllocationOp<"memref", "AllocaOp">; diff --git a/enzyme/Enzyme/MLIR/Implementations/NVVMAutoDiffOpInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/NVVMAutoDiffOpInterfaceImpl.cpp new file mode 100644 index 000000000000..4d8116ce011b --- /dev/null +++ b/enzyme/Enzyme/MLIR/Implementations/NVVMAutoDiffOpInterfaceImpl.cpp @@ -0,0 +1,34 @@ +//===- LLVMAutoDiffOpInterfaceImpl.cpp - Interface external model --------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains the external model implementation of the automatic +// differentiation op interfaces for the upstream LLVM dialect. +// +//===----------------------------------------------------------------------===// + +#include "Implementations/CoreDialectsAutoDiffImplementations.h" +#include "Interfaces/AutoDiffOpInterface.h" +#include "Interfaces/AutoDiffTypeInterface.h" +#include "Interfaces/GradientUtils.h" +#include "mlir/Dialect/LLVMIR/NVVMDialect.h" +#include "mlir/IR/DialectRegistry.h" +#include "mlir/Support/LogicalResult.h" + +using namespace mlir; +using namespace mlir::enzyme; + +namespace { +#include "Implementations/NVVMDerivatives.inc" +} // namespace + +void mlir::enzyme::registerNVVMDialectAutoDiffInterface( + DialectRegistry ®istry) { + registry.addExtension(+[](MLIRContext *context, NVVM::NVVMDialect *) { + registerInterfaces(context); + }); +} diff --git a/enzyme/Enzyme/MLIR/Implementations/NVVMDerivatives.td b/enzyme/Enzyme/MLIR/Implementations/NVVMDerivatives.td new file mode 100644 index 000000000000..f34dfb564cbc --- /dev/null +++ b/enzyme/Enzyme/MLIR/Implementations/NVVMDerivatives.td @@ -0,0 +1,4 @@ +include "Common.td" + +// TODO in reverse replicate in reverse pass +def : InactiveOp<"NVVM", "Barrier0Op">; diff --git a/enzyme/Enzyme/MLIR/Implementations/SCFAutoDiffOpInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/SCFAutoDiffOpInterfaceImpl.cpp index 52f48ce7e2d1..72fbe2106a53 100644 --- a/enzyme/Enzyme/MLIR/Implementations/SCFAutoDiffOpInterfaceImpl.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/SCFAutoDiffOpInterfaceImpl.cpp @@ -19,83 +19,18 @@ #include "Interfaces/GradientUtilsReverse.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/IR/DialectRegistry.h" +#include "mlir/IR/Types.h" +#include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Support/LogicalResult.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/STLExtras.h" #include using namespace mlir; using namespace mlir::enzyme; namespace { -struct ForOpInterface - : public AutoDiffOpInterface::ExternalModel { - LogicalResult createForwardModeTangent(Operation *op, OpBuilder &builder, - MGradientUtils *gutils) const { - auto forOp = cast(op); - auto nFor = cast(gutils->getNewFromOriginal(op)); - SmallVector nTypes; - for (auto r : forOp->getResults()) { - // TODO only if used - nTypes.push_back(r.getType()); - if (!gutils->isConstantValue(r)) { - auto adTypeIface = r.getType().dyn_cast(); - if (!adTypeIface) - return failure(); - nTypes.push_back(adTypeIface.getShadowType()); - } - } - SmallVector nArgs; - for (const auto &[initVal, iterArg] : - llvm::zip(forOp.getInitArgs(), forOp.getRegionIterArgs())) { - // TODO only if used - nArgs.push_back(gutils->getNewFromOriginal(initVal)); - if (!gutils->isConstantValue(iterArg)) - nArgs.push_back(gutils->invertPointerM(initVal, builder)); - } - auto repFor = builder.create( - forOp.getLoc(), gutils->getNewFromOriginal(forOp.getLowerBound()), - gutils->getNewFromOriginal(forOp.getUpperBound()), - gutils->getNewFromOriginal(forOp.getStep()), nArgs); - repFor.getRegion().takeBody(nFor.getRegion()); - - SmallVector reps; - size_t idx = 0; - for (Value r : forOp.getResults()) { - // TODO only if used - reps.push_back(repFor.getResult(idx)); - idx++; - if (!gutils->isConstantValue(r)) { - auto inverted = gutils->invertedPointers.lookupOrNull(r); - assert(inverted); - gutils->invertedPointers.map(r, repFor.getResult(idx)); - inverted.replaceAllUsesWith(repFor.getResult(idx)); - gutils->erase(inverted.getDefiningOp()); - idx++; - } - } - nFor.replaceAllUsesWith(reps); - gutils->erase(nFor); - for (Operation &o : - llvm::make_early_inc_range(forOp.getBody()->without_terminator())) { - if (failed(gutils->visitChild(&o))) - return failure(); - } - Operation *oldYield = repFor.getBody()->getTerminator(); - builder.setInsertionPointToEnd(repFor.getBody()); - SmallVector nYields; - for (const auto &[result, yieldOperand] : - llvm::zip(forOp.getResults(), - forOp.getBody()->getTerminator()->getOperands())) { - // TODO only if used - nYields.push_back(gutils->getNewFromOriginal(yieldOperand)); - if (!gutils->isConstantValue(result)) - nYields.push_back(gutils->invertPointerM(yieldOperand, builder)); - } - Operation *newYield = builder.clone(*oldYield); - newYield->setOperands(nYields); - gutils->erase(oldYield); - return success(); - } -}; +#include "Implementations/SCFDerivatives.inc" struct ForOpInterfaceReverse : public ReverseAutoDiffOpInterface::ExternalModel caches) const { auto forOp = cast(op); - SmallVector nArgs; - for (Value v : forOp.getResults()) { - if (auto iface = dyn_cast(v.getType())) { - if (gutils->hasInvertPointer(v)) { - nArgs.push_back(gutils->invertPointerM(v, builder)); - } else { - nArgs.push_back(iface.createNullValue(builder, v.getLoc())); + // Begin Perform d(yielded value[i]) += d(result[i]); d(result[i]) = 0 + SmallVector resDiffes; + for (OpResult v : forOp.getResults()) { + if (!gutils->isConstantValue(v)) { + auto autoDiffType = cast(v.getType()); + if (!autoDiffType.isMutable()) { + auto prev = gutils->diffe(v, builder); + gutils->zeroDiffe(v, builder); + resDiffes.push_back(prev); + continue; + } + } + resDiffes.push_back(nullptr); + } + + for (auto ® : op->getRegions()) { + auto termIface = + cast(reg.begin()->getTerminator()); + + SmallVector successors; + termIface.getSuccessorRegions( + SmallVector(termIface->getNumOperands(), Attribute()), + successors); + + for (auto &successor : successors) { + if (!successor.isParent()) + continue; + OperandRange operandRange = termIface.getSuccessorOperands(successor); + assert(operandRange.size() == resDiffes.size()); + + // There is an assumption here that there is only regions that branch to + // the successor. Specifically, otherwise we would need to + // gutils->addToDiffe select (if came from that result) + for (auto &&[prev, post] : llvm::zip(operandRange, resDiffes)) { + if (!post) + continue; + if (!gutils->isConstantValue(prev)) + gutils->addToDiffe(prev, post, builder); } } } + // End Perform d(yielded value[i]) += d(result[i]); d(result[i]) = 0 + + auto start = gutils->popCache(caches[0], builder); + auto end = gutils->popCache(caches[1], builder); + auto step = gutils->popCache(caches[2], builder); + + auto repFor = builder.create(forOp.getLoc(), start, end, step, + ArrayRef()); + // erase scf yield + repFor.getBody()->begin()->erase(); + + for (auto &&[oldReg, newReg] : + llvm::zip(op->getRegions(), repFor->getRegions())) { + + // This code assumes at most one terminating block for each region (lest + // the append happen multiple times) + auto buildFuncReturnOp = [&](OpBuilder &builder, Block *oBB) { + auto loc = oBB->rbegin()->getLoc(); + + auto idx = repFor.getInductionVar(); + + auto lhs = builder.create(loc, idx, step); + + // This needs to know a condition describing which predecessor this will + // return to, to select the right value Here we use the condition i + + // step >= end to determine the last iteration + + auto condition = builder.create( + loc, arith::CmpIPredicate::sge, lhs, end); + + for (auto [arg, init_arg] : + llvm::zip(oBB->getArguments().slice(1), forOp.getInitArgs())) { + if (!gutils->isConstantValue(arg) && + !cast(arg.getType()).isMutable()) { + auto diffe = gutils->diffe(arg, builder); + gutils->zeroDiffe(arg, builder); + + auto zero = cast(diffe.getType()) + .createNullValue(builder, loc); + auto outside = + builder.create(loc, condition, diffe, zero); + auto inside = + builder.create(loc, condition, zero, diffe); + + // For each predecessor, if we came from that predecessor += the + // shadow of the arg [after zero'ing] + if (!gutils->isConstantValue(init_arg)) { + gutils->addToDiffe(init_arg, outside, builder); + } + + if (!gutils->isConstantValue(arg)) { + gutils->addToDiffe(arg, inside, builder); + } + } + } + builder.create(loc); + }; - auto repFor = builder.create( - forOp.getLoc(), gutils->popCache(caches[0], builder), - gutils->popCache(caches[1], builder), - gutils->popCache(caches[2], builder), nArgs); // TODO - repFor.getRegion().begin()->erase(); - - auto buildFuncReturnOp = [](OpBuilder &builder, Location loc, - SmallVector retargs) { - builder.create(loc, retargs); - return; - }; - - gutils->Logic.differentiate(gutils, forOp.getRegion(), repFor.getRegion(), - /*parentRegion=*/false, buildFuncReturnOp, - nullptr); - - // Insert the index which is carried by the scf for op. - Type indexType = IndexType::get(builder.getContext()); - repFor.getRegion().insertArgument((unsigned)0, indexType, forOp.getLoc()); - - for (const auto &[iterOperand, adjResult] : - llvm::zip(forOp.getInitArgs(), repFor.getResults())) { - if (gutils->hasInvertPointer(iterOperand)) { - auto autoDiffType = cast(iterOperand.getType()); - Value before = gutils->invertPointerM(iterOperand, builder); - Value after = autoDiffType.createAddOp(builder, forOp.getLoc(), before, - adjResult); - gutils->mapInvertPointer(iterOperand, after, builder); + for (auto &&[oBB, revBB] : llvm::zip(oldReg, newReg)) { + gutils->mapReverseModeBlocks.map(&oBB, &revBB); + } + for (auto &&[oBB, revBB] : llvm::zip(oldReg, newReg)) { + gutils->Logic.visitChildren(&oBB, &revBB, gutils); + Block *newBB = gutils->getNewFromOriginal(&oBB); + gutils->Logic.handlePredecessors(&oBB, newBB, &revBB, gutils, + buildFuncReturnOp); } } } @@ -182,8 +185,7 @@ struct ForOpInterfaceReverse void mlir::enzyme::registerSCFDialectAutoDiffInterface( DialectRegistry ®istry) { registry.addExtension(+[](MLIRContext *context, scf::SCFDialect *) { - scf::ForOp::attachInterface(*context); - + registerInterfaces(context); scf::ForOp::attachInterface(*context); }); } diff --git a/enzyme/Enzyme/MLIR/Implementations/SCFDerivatives.td b/enzyme/Enzyme/MLIR/Implementations/SCFDerivatives.td new file mode 100644 index 000000000000..4c9ee09abcd2 --- /dev/null +++ b/enzyme/Enzyme/MLIR/Implementations/SCFDerivatives.td @@ -0,0 +1,50 @@ +include "Common.td" + +def : ControlFlowOp<"scf", "ForOp", [{ + Operation *createWithShadows(Operation *op, OpBuilder &builder, + MGradientUtils *gutils, Operation *original, + ValueRange remappedOperands, + TypeRange rettys) const { + scf::ForOpAdaptor adaptor(remappedOperands); + auto repFor = builder.create( + op->getLoc(), adaptor.getLowerBound(), adaptor.getUpperBound(), + adaptor.getStep(), adaptor.getInitArgs()); + return repFor; + } +}]>; + +def : ControlFlowOp<"scf", "IfOp", [{ + Operation *createWithShadows(Operation *op, OpBuilder &builder, + MGradientUtils *gutils, Operation *original, + ValueRange remappedOperands, + TypeRange rettys) const { + scf::IfOpAdaptor adaptor(remappedOperands); + auto repIf = builder.create( + op->getLoc(), rettys, adaptor.getCondition()); + return repIf; + } +}]>; + +def : ControlFlowOp<"scf", "WhileOp", [{ + Operation *createWithShadows(Operation *op, OpBuilder &builder, + MGradientUtils *gutils, Operation *original, + ValueRange remappedOperands, + TypeRange rettys) const { + return builder.create(original->getLoc(), rettys, + remappedOperands, original->getAttrs()); + } +}]>; + +def : ControlFlowOp<"scf", "ExecuteRegionOp", [{ + Operation *createWithShadows(Operation *op, OpBuilder &builder, + MGradientUtils *gutils, Operation *original, + ValueRange remappedOperands, + TypeRange rettys) const { + auto repIf = builder.create( + op->getLoc(), rettys); + return repIf; + } +}]>; + +def : RegionTerminatorOp<"scf", "YieldOp">; +def : RegionTerminatorOp<"scf", "ConditionOp">; diff --git a/enzyme/Enzyme/MLIR/Interfaces/AutoDiffOpInterface.td b/enzyme/Enzyme/MLIR/Interfaces/AutoDiffOpInterface.td index 61099c316cd2..9123dfcaa22e 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/AutoDiffOpInterface.td +++ b/enzyme/Enzyme/MLIR/Interfaces/AutoDiffOpInterface.td @@ -40,6 +40,42 @@ def AutoDiffOpInterface : OpInterface<"AutoDiffOpInterface"> { ]; } +def ControlFlowAutoDiffOpInterface + : OpInterface<"ControlFlowAutoDiffOpInterface"> { + let description = [{ + A differentiable MLIR operation whose forward differentiation rules are + driven by how control flows through the operation. + + There are two general assumptions: + - the operation can communicate additional values along the control flow + edges, which is used to put shadow values immediately after the primal + values; + - all values returned by the operation are yielded by all region-exiting + terminators. + }]; + let cppNamespace = "::mlir::enzyme"; + let methods = [ + InterfaceMethod< + /*desc=*/[{ + Creates a copy of this operation additionally carrying required shadow + values along control flow edges using the given builder. The `original` + is the operation in the original primal code prior to differentiation, + and this method is supposed to be called on the operation in the cloned + function being constructed. Remapped operands contains a flat list of + operands usable in the cloned function and can be fed to the Adaptor + constructor. + }], + /*retTy=*/"::mlir::Operation *", + /*methodName=*/"createWithShadows", + /*args=*/(ins "::mlir::OpBuilder &":$builder, + "::mlir::enzyme::MGradientUtils *":$gutils, + "::mlir::Operation *":$original, + "::mlir::ValueRange":$remappedOperands, + "::mlir::TypeRange":$returnTypes) + > + ]; +} + def ReverseAutoDiffOpInterface : OpInterface<"ReverseAutoDiffOpInterface"> { let description = [{ A differentiable MLIR operation that is able to emit reverse mode adjoints for itself. @@ -76,4 +112,25 @@ def ReverseAutoDiffOpInterface : OpInterface<"ReverseAutoDiffOpInterface"> { ]; } +def ActivityOpInterface + : OpInterface<"ActivityOpInterface"> { + let cppNamespace = "::mlir::enzyme"; + + let methods = [ + InterfaceMethod< + /*desc=*/[{ + }], + /*retTy=*/"bool", + /*methodName=*/"isInactive" + >, + InterfaceMethod< + /*desc=*/[{ + }], + /*retTy=*/"bool", + /*methodName=*/"isArgInactive", + /*args=*/(ins "size_t":$opidx) + > + ]; +} + #endif // ENZYME_MLIR_INTERFACES_AUTODIFFOPINTERFACES diff --git a/enzyme/Enzyme/MLIR/Interfaces/AutoDiffTypeInterface.h b/enzyme/Enzyme/MLIR/Interfaces/AutoDiffTypeInterface.h index 7c405bde0a99..04d0186c4b1e 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/AutoDiffTypeInterface.h +++ b/enzyme/Enzyme/MLIR/Interfaces/AutoDiffTypeInterface.h @@ -16,6 +16,7 @@ #include "mlir/IR/Types.h" #include "mlir/IR/Value.h" +#include "mlir/Support/LogicalResult.h" namespace mlir { class OpBuilder; diff --git a/enzyme/Enzyme/MLIR/Interfaces/AutoDiffTypeInterface.td b/enzyme/Enzyme/MLIR/Interfaces/AutoDiffTypeInterface.td index 5fec2643a98c..0f6d38b4ffa4 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/AutoDiffTypeInterface.td +++ b/enzyme/Enzyme/MLIR/Interfaces/AutoDiffTypeInterface.td @@ -41,6 +41,14 @@ def AutoDiffTypeInterface : TypeInterface<"AutoDiffTypeInterface"> { /*methodName=*/"createAddOp", /*args=*/(ins "::mlir::OpBuilder &":$builder, "::mlir::Location":$loc, "::mlir::Value":$a, "::mlir::Value":$b) >, + InterfaceMethod< + /*desc=*/[{ + Zero the operation in place + }], + /*retTy=*/"::mlir::LogicalResult", + /*methodName=*/"zeroInPlace", + /*args=*/(ins "::mlir::OpBuilder &":$builder, "::mlir::Location":$loc, "::mlir::Value":$val) + >, InterfaceMethod< /*desc=*/[{ Returns the type that can contain the adjoint value for this type. If @@ -53,10 +61,10 @@ def AutoDiffTypeInterface : TypeInterface<"AutoDiffTypeInterface"> { >, InterfaceMethod< /*desc=*/[{ - Returns if the Type needs to be cleared. + Returns whether the type is mutable in place or not. }], /*retTy=*/"bool", - /*methodName=*/"requiresShadow", + /*methodName=*/"isMutable", /*args=*/(ins ) > ]; diff --git a/enzyme/Enzyme/MLIR/Interfaces/CloneFunction.cpp b/enzyme/Enzyme/MLIR/Interfaces/CloneFunction.cpp index 84b8b8b5d3b3..9b5c007c62ee 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/CloneFunction.cpp +++ b/enzyme/Enzyme/MLIR/Interfaces/CloneFunction.cpp @@ -1,10 +1,15 @@ +#include "llvm/ADT/APSInt.h" + #include "CloneFunction.h" using namespace mlir; using namespace mlir::enzyme; Type getShadowType(Type type, unsigned width) { - return type.cast().getShadowType(width); + if (auto iface = type.dyn_cast()) + return iface.getShadowType(width); + llvm::errs() << " type does not have autodifftypeinterface: " << type << "\n"; + exit(1); } mlir::FunctionType getFunctionTypeForClone( @@ -14,22 +19,26 @@ mlir::FunctionType getFunctionTypeForClone( SmallVector RetTypes; if (returnValue == ReturnType::ArgsWithReturn || returnValue == ReturnType::Return) { - assert(FTy.getNumResults() == 1); - if (ReturnType != DIFFE_TYPE::CONSTANT && - ReturnType != DIFFE_TYPE::OUT_DIFF) { - RetTypes.push_back(getShadowType(FTy.getResult(0), width)); - } else { - RetTypes.push_back(FTy.getResult(0)); + assert(FTy.getNumResults() >= 1); + for (size_t i = 0; i < FTy.getNumResults(); i++) { + if (ReturnType != DIFFE_TYPE::CONSTANT && + ReturnType != DIFFE_TYPE::OUT_DIFF) { + RetTypes.push_back(getShadowType(FTy.getResult(i), width)); + } else { + RetTypes.push_back(FTy.getResult(i)); + } } } else if (returnValue == ReturnType::ArgsWithTwoReturns || returnValue == ReturnType::TwoReturns) { - assert(FTy.getNumResults() == 1); - RetTypes.push_back(FTy.getResult(0)); - if (ReturnType != DIFFE_TYPE::CONSTANT && - ReturnType != DIFFE_TYPE::OUT_DIFF) { - RetTypes.push_back(getShadowType(FTy.getResult(0), width)); - } else { - RetTypes.push_back(FTy.getResult(0)); + assert(FTy.getNumResults() >= 1); + for (size_t i = 0; i < FTy.getNumResults(); i++) { + RetTypes.push_back(FTy.getResult(i)); + if (ReturnType != DIFFE_TYPE::CONSTANT && + ReturnType != DIFFE_TYPE::OUT_DIFF) { + RetTypes.push_back(getShadowType(FTy.getResult(i), width)); + } else { + RetTypes.push_back(FTy.getResult(i)); + } } } @@ -200,7 +209,7 @@ FunctionOpInterface CloneFunctionWithReturns( SmallPtrSetImpl &constants, SmallPtrSetImpl &nonconstants, SmallPtrSetImpl &returnvals, ReturnType returnValue, - DIFFE_TYPE ReturnType, Twine name, IRMapping &VMap, + DIFFE_TYPE DReturnType, Twine name, IRMapping &VMap, std::map &OpMap, bool diffeReturnArg, mlir::Type additionalArg) { assert(!F.getFunctionBody().empty()); @@ -208,7 +217,7 @@ FunctionOpInterface CloneFunctionWithReturns( // llvm::ValueToValueMapTy VMap; auto FTy = getFunctionTypeForClone( F.getFunctionType().cast(), mode, width, - additionalArg, constant_args, diffeReturnArg, returnValue, ReturnType); + additionalArg, constant_args, diffeReturnArg, returnValue, DReturnType); /* for (Block &BB : F.getFunctionBody().getBlocks()) { @@ -235,6 +244,8 @@ FunctionOpInterface CloneFunctionWithReturns( { auto &blk = NewF.getFunctionBody().front(); + assert(F.getFunctionBody().front().getNumArguments() == + constant_args.size()); for (ssize_t i = constant_args.size() - 1; i >= 0; i--) { mlir::Value oval = F.getFunctionBody().front().getArgument(i); if (constant_args[i] == DIFFE_TYPE::CONSTANT) @@ -262,5 +273,104 @@ FunctionOpInterface CloneFunctionWithReturns( } } + std::string ToClone[] = { + "bufferization.writable", + "mhlo.sharding", + "mhlo.layout_mode", + "xla_framework.input_mapping", + "xla_framework.result_mapping", + }; + size_t newxlacnt = 0; + { + size_t oldi = 0; + size_t newi = 0; + while (oldi < F.getNumResults()) { + bool primalReturn = returnValue == ReturnType::ArgsWithReturn || + returnValue == ReturnType::ArgsWithTwoReturns || + (returnValue == ReturnType::TapeAndReturn && + DReturnType == DIFFE_TYPE::CONSTANT) || + returnValue == ReturnType::TapeAndTwoReturns || + returnValue == ReturnType::TwoReturns || + (returnValue == ReturnType::Return && + DReturnType == DIFFE_TYPE::CONSTANT); + if (primalReturn) { + for (auto attrName : ToClone) { + auto attrNameS = StringAttr::get(F->getContext(), attrName); + NewF.removeResultAttr(newi, attrNameS); + if (auto attr = F.getResultAttr(oldi, attrName)) { + if (attrName == "xla_framework.result_mapping") { + auto iattr = cast(attr); + APSInt nc(iattr.getValue()); + nc = newxlacnt; + attr = IntegerAttr::get(F->getContext(), nc); + newxlacnt++; + } + NewF.setResultAttr(newi, attrNameS, attr); + } + } + newi++; + } + if (DReturnType == DIFFE_TYPE::DUP_ARG || + DReturnType == DIFFE_TYPE::DUP_NONEED) { + for (auto attrName : ToClone) { + auto attrNameS = StringAttr::get(F->getContext(), attrName); + NewF.removeResultAttr(newi, attrNameS); + if (auto attr = F.getResultAttr(oldi, attrName)) { + if (attrName == "xla_framework.result_mapping") { + auto iattr = cast(attr); + APSInt nc(iattr.getValue()); + nc = newxlacnt; + attr = IntegerAttr::get(F->getContext(), nc); + newxlacnt++; + } + NewF.setResultAttr(newi, attrNameS, attr); + } + } + newi++; + } + oldi++; + } + } + { + size_t oldi = 0; + size_t newi = 0; + while (oldi < F.getNumArguments()) { + for (auto attrName : ToClone) { + NewF.removeArgAttr(newi, attrName); + if (auto attr = F.getArgAttr(oldi, attrName)) { + if (attrName == "xla_framework.input_mapping") { + auto iattr = cast(attr); + APSInt nc(iattr.getValue()); + nc = newxlacnt; + attr = IntegerAttr::get(F->getContext(), nc); + newxlacnt++; + } + NewF.setArgAttr(newi, attrName, attr); + } + } + + newi++; + if (constant_args[oldi] == DIFFE_TYPE::DUP_ARG || + constant_args[oldi] == DIFFE_TYPE::DUP_NONEED) { + + for (auto attrName : ToClone) { + NewF.removeArgAttr(newi, attrName); + if (auto attr = F.getArgAttr(oldi, attrName)) { + if (attrName == "xla_framework.input_mapping") { + auto iattr = cast(attr); + APSInt nc(iattr.getValue()); + nc = newxlacnt; + attr = IntegerAttr::get(F->getContext(), nc); + newxlacnt++; + } + NewF.setArgAttr(newi, attrName, attr); + } + } + newi++; + } + oldi++; + } + } + return NewF; -} \ No newline at end of file +} diff --git a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.cpp b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.cpp index a92a5f3386ff..ead8baad9261 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.cpp +++ b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.cpp @@ -1,10 +1,10 @@ #include "Dialect/Ops.h" +#include "Implementations/CoreDialectsAutoDiffImplementations.h" #include "Interfaces/AutoDiffOpInterface.h" #include "Interfaces/AutoDiffTypeInterface.h" #include "Interfaces/GradientUtils.h" #include "Interfaces/GradientUtilsReverse.h" #include "mlir/IR/Matchers.h" -#include "mlir/IR/SymbolTable.h" #include "mlir/Interfaces/FunctionInterfaces.h" // TODO: this shouldn't depend on specific dialects except Enzyme. @@ -21,7 +21,7 @@ using namespace mlir; using namespace mlir::enzyme; -void createTerminator(MDiffeGradientUtils *gutils, mlir::Block *oBB, +void createTerminator(MGradientUtils *gutils, mlir::Block *oBB, DIFFE_TYPE retType, ReturnType retVal) { auto inst = oBB->getTerminator(); @@ -33,39 +33,7 @@ void createTerminator(MDiffeGradientUtils *gutils, mlir::Block *oBB, nBuilder.setInsertionPointToEnd(nBB); if (auto binst = dyn_cast(inst)) { - // TODO generalize to cloneWithNewBlockArgs interface - SmallVector newVals; - - SmallVector segSizes; - for (size_t i = 0, len = binst.getSuccessorOperands(0) - .getForwardedOperands() - .getBeginOperandIndex(); - i < len; i++) - newVals.push_back(gutils->getNewFromOriginal(binst->getOperand(i))); - segSizes.push_back(newVals.size()); - for (size_t i = 0; i < newInst->getNumSuccessors(); i++) { - size_t cur = newVals.size(); - for (auto op : binst.getSuccessorOperands(i).getForwardedOperands()) { - newVals.push_back(gutils->getNewFromOriginal(op)); - if (!gutils->isConstantValue(op)) { - newVals.push_back(gutils->invertPointerM(op, nBuilder)); - } - } - cur = newVals.size() - cur; - segSizes.push_back(cur); - } - - SmallVector attrs(newInst->getAttrs()); - for (auto &attr : attrs) { - if (attr.getName() == "operandSegmentSizes") - attr.setValue(nBuilder.getDenseI32ArrayAttr(segSizes)); - } - - nBB->push_back( - newInst->create(newInst->getLoc(), newInst->getName(), TypeRange(), - newVals, attrs, OpaqueProperties(nullptr), - newInst->getSuccessors(), newInst->getNumRegions())); - gutils->erase(newInst); + mlir::enzyme::detail::branchingForwardHandler(inst, nBuilder, gutils); return; } @@ -77,44 +45,52 @@ void createTerminator(MDiffeGradientUtils *gutils, mlir::Block *oBB, switch (retVal) { case ReturnType::Return: { - auto ret = inst->getOperand(0); - - mlir::Value toret; - if (retType == DIFFE_TYPE::CONSTANT) { - toret = gutils->getNewFromOriginal(ret); - } else if (!isa(ret.getType()) && true /*type analysis*/) { - toret = gutils->invertPointerM(ret, nBuilder); - } else if (!gutils->isConstantValue(ret)) { - toret = gutils->invertPointerM(ret, nBuilder); - } else { - Type retTy = ret.getType().cast().getShadowType(); - toret = retTy.cast().createNullValue(nBuilder, - ret.getLoc()); + for (size_t i = 0; i < inst->getNumOperands(); i++) { + auto ret = inst->getOperand(i); + + mlir::Value toret; + if (retType == DIFFE_TYPE::CONSTANT) { + toret = gutils->getNewFromOriginal(ret); + } else if (!isa(ret.getType()) && + true /*type analysis*/) { + toret = gutils->invertPointerM(ret, nBuilder); + } else if (!gutils->isConstantValue(ret)) { + toret = gutils->invertPointerM(ret, nBuilder); + } else { + Type retTy = + ret.getType().cast().getShadowType(); + toret = retTy.cast().createNullValue( + nBuilder, ret.getLoc()); + } + retargs.push_back(toret); } - retargs.push_back(toret); break; } case ReturnType::TwoReturns: { if (retType == DIFFE_TYPE::CONSTANT) assert(false && "Invalid return type"); - auto ret = inst->getOperand(0); - - retargs.push_back(gutils->getNewFromOriginal(ret)); - - mlir::Value toret; - if (retType == DIFFE_TYPE::CONSTANT) { - toret = gutils->getNewFromOriginal(ret); - } else if (!isa(ret.getType()) && true /*type analysis*/) { - toret = gutils->invertPointerM(ret, nBuilder); - } else if (!gutils->isConstantValue(ret)) { - toret = gutils->invertPointerM(ret, nBuilder); - } else { - Type retTy = ret.getType().cast().getShadowType(); - toret = retTy.cast().createNullValue(nBuilder, - ret.getLoc()); + for (size_t i = 0; i < inst->getNumOperands(); i++) { + auto ret = inst->getOperand(i); + + retargs.push_back(gutils->getNewFromOriginal(ret)); + + mlir::Value toret; + if (retType == DIFFE_TYPE::CONSTANT) { + toret = gutils->getNewFromOriginal(ret); + } else if (!isa(ret.getType()) && + true /*type analysis*/) { + toret = gutils->invertPointerM(ret, nBuilder); + } else if (!gutils->isConstantValue(ret)) { + toret = gutils->invertPointerM(ret, nBuilder); + } else { + Type retTy = + ret.getType().cast().getShadowType(); + toret = retTy.cast().createNullValue( + nBuilder, ret.getLoc()); + } + retargs.push_back(toret); } - retargs.push_back(toret); break; } case ReturnType::Void: { @@ -149,6 +125,9 @@ FunctionOpInterface mlir::enzyme::MEnzymeLogic::CreateForwardDiff( llvm::errs() << fn << "\n"; llvm_unreachable("Differentiating empty function"); } + assert(fn.getFunctionBody().front().getNumArguments() == constants.size()); + assert(fn.getFunctionBody().front().getNumArguments() == + volatile_args.size()); MForwardCacheKey tup = { fn, retType, constants, @@ -203,6 +182,7 @@ FunctionOpInterface mlir::enzyme::MEnzymeLogic::CreateForwardDiff( unnecessaryInstructions, gutils, TLI); */ + bool valid = true; for (Block &oBB : gutils->oldFunc.getFunctionBody().getBlocks()) { // Don't create derivatives for code that results in termination if (guaranteedUnreachable.find(&oBB) != guaranteedUnreachable.end()) { @@ -218,14 +198,14 @@ FunctionOpInterface mlir::enzyme::MEnzymeLogic::CreateForwardDiff( continue; } - auto term = oBB.getTerminator(); - assert(term); + assert(oBB.getTerminator()); auto first = oBB.begin(); auto last = oBB.empty() ? oBB.end() : std::prev(oBB.end()); for (auto it = first; it != last; ++it) { // TODO: propagate errors. - (void)gutils->visitChild(&*it); + auto res = gutils->visitChild(&*it); + valid &= res.succeeded(); } createTerminator(gutils, &oBB, retType, returnValue); @@ -252,6 +232,9 @@ FunctionOpInterface mlir::enzyme::MEnzymeLogic::CreateForwardDiff( auto nf = gutils->newFunc; delete gutils; + if (!valid) + return nullptr; + // if (PostOpt) // PPC.optimizeIntermediate(nf); // if (EnzymePrint) { diff --git a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.h b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.h index 223f604eaabb..56d49bf79b09 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.h +++ b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.h @@ -10,8 +10,7 @@ namespace mlir { namespace enzyme { -typedef void(buildReturnFunction)(OpBuilder &, Location, - SmallVector); +typedef void(buildReturnFunction)(OpBuilder &, mlir::Block *); class MGradientUtilsReverse; @@ -120,32 +119,27 @@ class MEnzymeLogic { std::vector constants, MTypeAnalysis &TA, bool returnUsed, DerivativeMode mode, bool freeMemory, size_t width, mlir::Type addedType, MFnTypeInfo type_args, - std::vector volatile_args, void *augmented, - SymbolTableCollection &symbolTable); + std::vector volatile_args, void *augmented); void initializeShadowValues(SmallVector &dominatorToposortBlocks, MGradientUtilsReverse *gutils); - void handlePredecessors(Block *oBB, Block *newBB, Block *reverseBB, - MGradientUtilsReverse *gutils, - llvm::function_ref buildReturnOp, - bool parentRegion); + void + handlePredecessors(Block *oBB, Block *newBB, Block *reverseBB, + MGradientUtilsReverse *gutils, + llvm::function_ref buildReturnOp); void visitChildren(Block *oBB, Block *reverseBB, MGradientUtilsReverse *gutils); void visitChild(Operation *op, OpBuilder &builder, MGradientUtilsReverse *gutils); - bool visitChildCustom(Operation *op, OpBuilder &builder, - MGradientUtilsReverse *gutils); - void handleReturns(Block *oBB, Block *newBB, Block *reverseBB, - MGradientUtilsReverse *gutils, bool parentRegion); void mapInvertArguments(Block *oBB, Block *reverseBB, MGradientUtilsReverse *gutils); SmallVector getDominatorToposort(MGradientUtilsReverse *gutils, Region ®ion); void differentiate(MGradientUtilsReverse *gutils, Region &oldRegion, - Region &newRegion, bool parentRegion, + Region &newRegion, llvm::function_ref buildFuncRetrunOp, std::function(Type)> cacheCreator); }; } // Namespace enzyme -} // Namespace mlir \ No newline at end of file +} // Namespace mlir diff --git a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp index b9cbfe3e6913..25e8f1818cd2 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp +++ b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp @@ -11,8 +11,6 @@ #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/IR/Dominance.h" -#include "llvm/ADT/BreadthFirstIterator.h" #include "EnzymeLogic.h" #include "Interfaces/GradientUtils.h" @@ -22,142 +20,19 @@ using namespace mlir; using namespace mlir::enzyme; -SmallVector -MEnzymeLogic::getDominatorToposort(MGradientUtilsReverse *gutils, - Region ®ion) { - SmallVector dominatorToposortBlocks; - if (region.hasOneBlock()) { - dominatorToposortBlocks.push_back(&*(region.begin())); - } else { - auto dInfo = mlir::detail::DominanceInfoBase(nullptr); - llvm::DominatorTreeBase &dt = - dInfo.getDomTree(&(gutils->oldFunc.getFunctionBody())); - auto root = dt.getNode(&*(region.begin())); - - for (llvm::DomTreeNodeBase *node : llvm::breadth_first(root)) { - dominatorToposortBlocks.push_back(node->getBlock()); - } - } - return dominatorToposortBlocks; -} - -void MEnzymeLogic::mapInvertArguments(Block *oBB, Block *reverseBB, - MGradientUtilsReverse *gutils) { - OpBuilder builder(reverseBB, reverseBB->begin()); - for (int i = 0; i < (int)gutils->mapBlockArguments[oBB].size(); i++) { - auto x = gutils->mapBlockArguments[oBB][i]; - if (auto iface = x.second.getType().dyn_cast()) { - Value added = reverseBB->getArgument(i); - if (gutils->hasInvertPointer(x.second)) { - added = iface.createAddOp(builder, x.second.getLoc(), added, - gutils->invertPointerM(x.second, builder)); - } - gutils->mapInvertPointer(x.second, added, builder); - } - } -} - -void MEnzymeLogic::handleReturns(Block *oBB, Block *newBB, Block *reverseBB, - MGradientUtilsReverse *gutils, - bool parentRegion) { +void handleReturns(Block *oBB, Block *newBB, Block *reverseBB, + MGradientUtilsReverse *gutils) { if (oBB->getNumSuccessors() == 0) { - if (parentRegion) { - Operation *returnStatement = newBB->getTerminator(); - gutils->erase(returnStatement); - - OpBuilder forwardToBackwardBuilder(newBB, newBB->end()); - gutils->mapInvertPointer( - oBB->getTerminator()->getOperand(0), - gutils->newFunc.getArgument(gutils->newFunc.getNumArguments() - 1), - forwardToBackwardBuilder); // TODO handle multiple return values - Operation *newBranchOp = forwardToBackwardBuilder.create( - oBB->getTerminator()->getLoc(), reverseBB); - - gutils->originalToNewFnOps[oBB->getTerminator()] = newBranchOp; - } else { - Operation *terminator = oBB->getTerminator(); - OpBuilder builder(reverseBB, reverseBB->begin()); - - int i = 0; - for (OpOperand &operand : terminator->getOpOperands()) { - Value val = operand.get(); - if (auto iface = val.getType().dyn_cast()) { - gutils->mapInvertPointer(val, reverseBB->getArgument(i), builder); - i++; - } - } - } - } -} - -bool MEnzymeLogic::visitChildCustom(Operation *op, OpBuilder &builder, - MGradientUtilsReverse *gutils) { - std::string nameDiffe = "diffe_" + op->getName().getDialectNamespace().str() + - "_" + op->getName().stripDialect().str(); - std::string nameStore = "store_" + op->getName().getDialectNamespace().str() + - "_" + op->getName().stripDialect().str(); - - StringRef srDiffe(nameDiffe); - StringRef srStore(nameStore); - - OperationName opNameDiffe(srDiffe, op->getContext()); - OperationName opNameStore(srStore, op->getContext()); - - Operation *symbolDiffe = gutils->symbolTable.lookupNearestSymbolFrom( - op, opNameDiffe.getIdentifier()); - Operation *symbolStore = gutils->symbolTable.lookupNearestSymbolFrom( - op, opNameStore.getIdentifier()); + Operation *returnStatement = newBB->getTerminator(); + gutils->erase(returnStatement); - if (symbolDiffe != nullptr) { - SmallVector caches; - if (symbolStore != nullptr) { - Operation *newOp = gutils->getNewFromOriginal(op); + OpBuilder forwardToBackwardBuilder(newBB, newBB->end()); - func::FuncOp funcStore = cast(symbolStore); + Operation *newBranchOp = forwardToBackwardBuilder.create( + oBB->getTerminator()->getLoc(), reverseBB); - SmallVector storeResultTypes; - for (auto x : funcStore.getFunctionType().getResults()) { - storeResultTypes.push_back(x); - } - - SmallVector storeArgs; - for (auto x : newOp->getOperands()) { - storeArgs.push_back(x); - } - - OpBuilder storeBuilder(newOp); - func::CallOp storeCI = storeBuilder.create( - op->getLoc(), srStore, storeResultTypes, storeArgs); - for (auto x : storeCI.getResults()) { - caches.push_back(gutils->initAndPushCache(x, storeBuilder)); - } - } - - SmallVector args; - for (Value opResult : op->getResults()) { - if (gutils->hasInvertPointer(opResult)) { - Value invertValue = gutils->invertPointerM(opResult, builder); - args.push_back(invertValue); - } - } - for (Value cache : caches) { - args.push_back(gutils->popCache(cache, builder)); - } - - SmallVector resultTypes; - for (auto x : op->getOperands()) { - resultTypes.push_back(x.getType()); - } - - func::CallOp dCI = - builder.create(op->getLoc(), srDiffe, resultTypes, args); - for (int i = 0; i < (int)op->getNumOperands(); i++) { - gutils->mapInvertPointer(op->getOperand(i), dCI.getResult(i), builder); - } - - return true; + gutils->originalToNewFnOps[oBB->getTerminator()] = newBranchOp; } - return false; } /* @@ -165,16 +40,25 @@ Create reverse mode adjoint for an operation. */ void MEnzymeLogic::visitChild(Operation *op, OpBuilder &builder, MGradientUtilsReverse *gutils) { + if ((op->getBlock()->getTerminator() != op) && + llvm::all_of(op->getResults(), + [gutils](Value v) { return gutils->isConstantValue(v); }) && + gutils->isConstantInstruction(op)) { + return; + } if (auto ifaceOp = dyn_cast(op)) { SmallVector caches = ifaceOp.cacheValues(gutils); ifaceOp.createReverseModeAdjoint(builder, gutils, caches); - + return; + /* for (int indexResult = 0; indexResult < (int)op->getNumResults(); indexResult++) { Value result = op->getResult(indexResult); gutils->clearValue(result, builder); } + */ } + op->emitError() << "could not compute the adjoint for this operation " << *op; } void MEnzymeLogic::visitChildren(Block *oBB, Block *reverseBB, @@ -185,186 +69,114 @@ void MEnzymeLogic::visitChildren(Block *oBB, Block *reverseBB, auto last = oBB->rend(); for (auto it = first; it != last; ++it) { Operation *op = &*it; - bool customFound = visitChildCustom(op, revBuilder, gutils); - if (!customFound) { - visitChild(op, revBuilder, gutils); - } + visitChild(op, revBuilder, gutils); } } } void MEnzymeLogic::handlePredecessors( Block *oBB, Block *newBB, Block *reverseBB, MGradientUtilsReverse *gutils, - llvm::function_ref buildReturnOp, bool parentRegion) { + llvm::function_ref buildReturnOp) { OpBuilder revBuilder(reverseBB, reverseBB->end()); if (oBB->hasNoPredecessors()) { - SmallVector retargs; - // We need different handling on the top level due to - // the presence of duplicated args since we don't yet have activity analysis - if (parentRegion) { - assert(gutils->ArgDiffeTypes.size() == - gutils->oldFunc.getNumArguments() && - "Mismatch of activity array size vs # original function args"); - for (const auto &[diffeType, oldArg] : - llvm::zip(gutils->ArgDiffeTypes, oBB->getArguments())) { - if (diffeType == DIFFE_TYPE::OUT_DIFF) { - retargs.push_back(gutils->invertPointerM(oldArg, revBuilder)); - } - } - } else { - for (auto arg : oBB->getArguments()) { - if (gutils->hasInvertPointer(arg)) { - retargs.push_back(gutils->invertPointerM(arg, revBuilder)); - } - } - } - buildReturnOp(revBuilder, oBB->rbegin()->getLoc(), retargs); + buildReturnOp(revBuilder, oBB); } else { + Location loc = oBB->rbegin()->getLoc(); + // TODO remove dependency on CF dialect + + Value cache = gutils->insertInit(gutils->getIndexCacheType()); + + Value flag = + revBuilder.create(loc, gutils->getIndexType(), cache); + + Block *defaultBlock = nullptr; + SmallVector blocks; SmallVector indices; - SmallVector> arguments; - SmallVector defaultArguments; - Block *defaultBlock = nullptr; - for (auto pair : llvm::enumerate(oBB->getPredecessors())) { - auto predecessor = pair.value(); - auto idx = pair.index(); - Block *predecessorRevMode = - gutils->mapReverseModeBlocks.lookupOrNull(predecessor); - - SmallVector operands; - auto argumentsIt = gutils->mapBlockArguments.find(predecessor); - if (argumentsIt != gutils->mapBlockArguments.end()) { - for (auto operandOld : argumentsIt->second) { - if (oBB == operandOld.first.getParentBlock() && - gutils->hasInvertPointer(operandOld.first)) { - operands.push_back( - gutils->invertPointerM(operandOld.first, revBuilder)); - } else { - if (auto iface = operandOld.first.getType() - .dyn_cast()) { - Value nullValue = - iface.createNullValue(revBuilder, oBB->rbegin()->getLoc()); - operands.push_back(nullValue); - } else { - llvm_unreachable("no canonial null value found"); - } - } - } - } - if (idx != 0) { - blocks.push_back(predecessorRevMode); - indices.push_back(APInt(32, idx - 1)); - arguments.emplace_back(std::move(operands)); - } else { - defaultBlock = predecessorRevMode; - defaultArguments = operands; + + OpBuilder newBuilder(newBB, newBB->begin()); + + SmallVector diffes; + for (auto arg : oBB->getArguments()) { + if (!gutils->isConstantValue(arg) && + !cast(arg.getType()).isMutable()) { + diffes.push_back(gutils->diffe(arg, revBuilder)); + gutils->zeroDiffe(arg, revBuilder); + continue; } + diffes.push_back(nullptr); } - // Clear invert pointers of all arguments with gradient - for (auto argument : oBB->getArguments()) { - if (gutils->hasInvertPointer(argument)) { - auto iface = argument.getType().cast(); - Value nullValue = iface.createNullValue(revBuilder, argument.getLoc()); - gutils->mapInvertPointer(argument, nullValue, revBuilder); + for (auto [idx, pred] : llvm::enumerate(oBB->getPredecessors())) { + auto reversePred = gutils->mapReverseModeBlocks.lookupOrNull(pred); + + Block *newPred = gutils->getNewFromOriginal(pred); + + OpBuilder predecessorBuilder(newPred->getTerminator()); + + Value pred_idx_c = + predecessorBuilder.create(loc, idx - 1, 32); + predecessorBuilder.create(loc, cache, pred_idx_c); + + if (idx == 0) { + defaultBlock = reversePred; + + } else { + indices.push_back(APInt(32, idx - 1)); + blocks.push_back(reversePred); } - } - Location loc = oBB->rbegin()->getLoc(); - // Remove Dependency to CF dialect - if (std::next(oBB->getPredecessors().begin()) == - oBB->getPredecessors().end()) { - // If there is only one block we can directly create a branch for - // simplicity sake - revBuilder.create(loc, defaultBlock, defaultArguments); - } else { - Value cache = gutils->insertInit(gutils->getIndexCacheType()); - Value flag = - revBuilder.create(loc, gutils->getIndexType(), cache); - - SmallVector argumentRanges; - for (const auto &a : arguments) - argumentRanges.emplace_back(a); - revBuilder.create( - loc, flag, defaultBlock, defaultArguments, ArrayRef(indices), - ArrayRef(blocks), argumentRanges); - - Value origin = newBB->addArgument(gutils->getIndexType(), loc); - - OpBuilder newBuilder(newBB, newBB->begin()); - newBuilder.create(loc, cache, origin); - - int j = 0; - for (Block *predecessor : oBB->getPredecessors()) { - Block *newPredecessor = gutils->getNewFromOriginal(predecessor); - - OpBuilder predecessorBuilder(newPredecessor, - std::prev(newPredecessor->end())); - Value indicator = - predecessorBuilder.create(loc, j++, 32); - - Operation *terminator = newPredecessor->getTerminator(); - if (auto binst = dyn_cast(terminator)) { - for (unsigned i = 0; i < terminator->getNumSuccessors(); i++) { - if (terminator->getSuccessor(i) == newBB) { - SuccessorOperands sOps = binst.getSuccessorOperands(i); - sOps.append(indicator); + auto term = pred->getTerminator(); + if (auto iface = dyn_cast(term)) { + for (auto &op : term->getOpOperands()) + if (auto blk_idx = + iface.getSuccessorBlockArgument(op.getOperandNumber())) + if ((*blk_idx).getOwner() == oBB) { + auto idx = (*blk_idx).getArgNumber(); + if (diffes[idx]) { + + Value rev_idx_c = + revBuilder.create(loc, idx - 1, 32); + + auto to_prop = revBuilder.create( + loc, + revBuilder.create( + loc, arith::CmpIPredicate::eq, flag, rev_idx_c), + diffes[idx], + cast(diffes[idx].getType()) + .createNullValue(revBuilder, loc)); + + gutils->addToDiffe(op.get(), to_prop, revBuilder); + } } - } - } else { - llvm_unreachable("invalid terminator"); - } + } else { + assert(0 && "predecessor did not implement branch op interface"); } } - } -} -void MEnzymeLogic::initializeShadowValues( - SmallVector &dominatorToposortBlocks, - MGradientUtilsReverse *gutils) { - for (auto it = dominatorToposortBlocks.begin(); - it != dominatorToposortBlocks.end(); ++it) { - Block *oBB = *it; - - if (!oBB->empty()) { - for (auto it = oBB->begin(); it != oBB->end(); ++it) { - Operation *op = &*it; - Operation *newOp = gutils->getNewFromOriginal(op); - - if (auto ifaceOp = dyn_cast(op)) { - OpBuilder builder(newOp); - ifaceOp.createShadowValues(builder, gutils); - } - } - } + revBuilder.create( + loc, flag, defaultBlock, ArrayRef(), ArrayRef(indices), + ArrayRef(blocks), + SmallVector(indices.size(), ValueRange())); } } void MEnzymeLogic::differentiate( MGradientUtilsReverse *gutils, Region &oldRegion, Region &newRegion, - bool parentRegion, llvm::function_ref buildFuncReturnOp, std::function(Type)> cacheCreator) { gutils->registerCacheCreatorHook(cacheCreator); auto scope = llvm::make_scope_exit( [&]() { gutils->deregisterCacheCreatorHook(cacheCreator); }); - gutils->createReverseModeBlocks(oldRegion, newRegion, parentRegion); - - SmallVector dominatorToposortBlocks = - getDominatorToposort(gutils, oldRegion); - initializeShadowValues(dominatorToposortBlocks, gutils); - - for (auto it = dominatorToposortBlocks.rbegin(); - it != dominatorToposortBlocks.rend(); ++it) { - Block *oBB = *it; - Block *newBB = gutils->getNewFromOriginal(oBB); - Block *reverseBB = gutils->mapReverseModeBlocks.lookupOrNull(oBB); - mapInvertArguments(oBB, reverseBB, gutils); - handleReturns(oBB, newBB, reverseBB, gutils, parentRegion); - visitChildren(oBB, reverseBB, gutils); - handlePredecessors(oBB, newBB, reverseBB, gutils, buildFuncReturnOp, - parentRegion); + gutils->createReverseModeBlocks(oldRegion, newRegion); + + for (auto &oBB : oldRegion) { + Block *newBB = gutils->getNewFromOriginal(&oBB); + Block *reverseBB = gutils->mapReverseModeBlocks.lookupOrNull(&oBB); + handleReturns(&oBB, newBB, reverseBB, gutils); + visitChildren(&oBB, reverseBB, gutils); + handlePredecessors(&oBB, newBB, reverseBB, gutils, buildFuncReturnOp); } } @@ -372,8 +184,7 @@ FunctionOpInterface MEnzymeLogic::CreateReverseDiff( FunctionOpInterface fn, DIFFE_TYPE retType, std::vector constants, MTypeAnalysis &TA, bool returnUsed, DerivativeMode mode, bool freeMemory, size_t width, mlir::Type addedType, - MFnTypeInfo type_args, std::vector volatile_args, void *augmented, - SymbolTableCollection &symbolTable) { + MFnTypeInfo type_args, std::vector volatile_args, void *augmented) { if (fn.getFunctionBody().empty()) { llvm::errs() << fn << "\n"; @@ -383,18 +194,25 @@ FunctionOpInterface MEnzymeLogic::CreateReverseDiff( ReturnType returnValue = ReturnType::Args; MGradientUtilsReverse *gutils = MGradientUtilsReverse::CreateFromClone( *this, mode, width, fn, TA, type_args, retType, /*diffeReturnArg*/ true, - constants, returnValue, addedType, symbolTable); + constants, returnValue, addedType); Region &oldRegion = gutils->oldFunc.getFunctionBody(); Region &newRegion = gutils->newFunc.getFunctionBody(); - auto buildFuncReturnOp = [](OpBuilder &builder, Location loc, - SmallVector retargs) { - builder.create(loc, retargs); + auto buildFuncReturnOp = [&](OpBuilder &builder, Block *oBB) { + SmallVector retargs; + for (auto [arg, cv] : llvm::zip(oBB->getArguments(), constants)) { + if (cv == DIFFE_TYPE::OUT_DIFF) { + retargs.push_back(gutils->diffe(arg, builder)); + } + } + builder.create(oBB->rbegin()->getLoc(), retargs); return; }; - differentiate(gutils, oldRegion, newRegion, true, buildFuncReturnOp, nullptr); + gutils->forceAugmentedReturns(); + + differentiate(gutils, oldRegion, newRegion, buildFuncReturnOp, nullptr); auto nf = gutils->newFunc; diff --git a/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp b/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp index 286456b2d039..ebae44c9efa1 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp +++ b/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp @@ -36,14 +36,13 @@ mlir::enzyme::MGradientUtils::MGradientUtils( ArrayRef ArgDiffeTypes_, IRMapping &originalToNewFn_, std::map &originalToNewFnOps_, DerivativeMode mode, unsigned width, bool omp) - : newFunc(newFunc_), Logic(Logic), mode(mode), oldFunc(oldFunc_), TA(TA_), - TR(TR_), omp(omp), blocksNotForAnalysis(), + : newFunc(newFunc_), Logic(Logic), mode(mode), oldFunc(oldFunc_), + invertedPointers(invertedPointers_), originalToNewFn(originalToNewFn_), + originalToNewFnOps(originalToNewFnOps_), blocksNotForAnalysis(), activityAnalyzer(std::make_unique( blocksNotForAnalysis, constantvalues_, activevals_, ReturnActivity)), - width(width), ArgDiffeTypes(ArgDiffeTypes_), - originalToNewFn(originalToNewFn_), - originalToNewFnOps(originalToNewFnOps_), - invertedPointers(invertedPointers_) { + TA(TA_), TR(TR_), omp(omp), width(width), ArgDiffeTypes(ArgDiffeTypes_), + RetDiffeTypes(1, ReturnActivity) { /* for (BasicBlock &BB : *oldFunc) { @@ -117,14 +116,14 @@ mlir::enzyme::MGradientUtils::getNewFromOriginal(mlir::Block *originst) const { Operation * mlir::enzyme::MGradientUtils::getNewFromOriginal(Operation *originst) const { + assert(originst); auto found = originalToNewFnOps.find(originst); if (found == originalToNewFnOps.end()) { llvm::errs() << oldFunc << "\n"; llvm::errs() << newFunc << "\n"; for (auto &pair : originalToNewFnOps) { llvm::errs() << " map[" << pair.first << "] = " << pair.second << "\n"; - // llvm::errs() << " map[" << pair.first << "] = " << pair.second << " - // -- " << *pair.first << " " << *pair.second << "\n"; + llvm::errs() << " map[" << *pair.first << "] = " << *pair.second << "\n"; } llvm::errs() << originst << " - " << *originst << "\n"; llvm_unreachable("Could not get new op from original"); @@ -156,7 +155,12 @@ mlir::Value mlir::enzyme::MGradientUtils::invertPointerM(mlir::Value v, if (isConstantValue(v)) { if (auto iface = v.getType().dyn_cast()) { OpBuilder::InsertionGuard guard(Builder2); - Builder2.setInsertionPoint(getNewFromOriginal(v.getDefiningOp())); + if (auto op = v.getDefiningOp()) + Builder2.setInsertionPoint(getNewFromOriginal(op)); + else { + auto ba = cast(v); + Builder2.setInsertionPointToStart(getNewFromOriginal(ba.getOwner())); + } Value dv = iface.createNullValue(Builder2, v.getLoc()); invertedPointers.map(v, dv); return dv; @@ -167,6 +171,55 @@ mlir::Value mlir::enzyme::MGradientUtils::invertPointerM(mlir::Value v, llvm_unreachable("could not invert pointer"); } +mlir::Value +mlir::enzyme::MDiffeGradientUtils::getDifferential(mlir::Value oval) { + auto found = differentials.lookupOrNull(oval); + if (found != nullptr) + return found; + + auto shadowty = getShadowType(oval.getType()); + OpBuilder builder(oval.getContext()); + builder.setInsertionPointToStart(initializationBlock); + + auto shadow = builder.create( + oval.getLoc(), enzyme::GradientType::get(oval.getContext(), shadowty)); + auto toset = cast(shadowty).createNullValue( + builder, oval.getLoc()); + builder.create(oval.getLoc(), shadow, toset); + + differentials.map(oval, shadow); + return shadow; +} + +void mlir::enzyme::MDiffeGradientUtils::setDiffe(mlir::Value oval, + mlir::Value toset, + OpBuilder &BuilderM) { + assert(!isConstantValue(oval)); + auto iface = oval.getType().cast(); + if (!iface.isMutable()) { + auto shadow = getDifferential(oval); + BuilderM.create(oval.getLoc(), shadow, toset); + } else { + MGradientUtils::setDiffe(oval, toset, BuilderM); + } +} + +void mlir::enzyme::MDiffeGradientUtils::zeroDiffe(mlir::Value oval, + OpBuilder &BuilderM) { + assert(!isConstantValue(oval)); + auto iface = getShadowType(oval.getType()).cast(); + assert(!iface.isMutable()); + setDiffe(oval, iface.createNullValue(BuilderM, oval.getLoc()), BuilderM); +} + +mlir::Value mlir::enzyme::MDiffeGradientUtils::diffe(mlir::Value oval, + OpBuilder &BuilderM) { + + auto shadow = getDifferential(oval); + return BuilderM.create(oval.getLoc(), + getShadowType(oval.getType()), shadow); +} + void mlir::enzyme::MGradientUtils::setDiffe(mlir::Value val, mlir::Value toset, OpBuilder &BuilderM) { /* @@ -223,90 +276,46 @@ void mlir::enzyme::MGradientUtils::forceAugmentedReturns() { if (isConstantValue(val)) continue; auto i = val.getArgNumber(); - mlir::Value dval; - if (i == blk->getArguments().size() - 1) - dval = nblk->addArgument(getShadowType(val.getType()), val.getLoc()); - else - dval = nblk->insertArgument(nblk->args_begin() + i + 1, - getShadowType(val.getType()), val.getLoc()); - - invertedPointers.map(val, dval); + if (mode == DerivativeMode::ForwardMode || + mode == DerivativeMode::ForwardModeSplit || + cast(val.getType()).isMutable()) { + mlir::Value dval; + if (i == blk->getArguments().size() - 1) + dval = nblk->addArgument(getShadowType(val.getType()), val.getLoc()); + else + dval = + nblk->insertArgument(nblk->args_begin() + i + 1, + getShadowType(val.getType()), val.getLoc()); + + invertedPointers.map(val, dval); + } } }); oldFunc.walk([&](Operation *inst) { if (inst == oldFunc) return; - if (mode == DerivativeMode::ForwardMode || - mode == DerivativeMode::ForwardModeSplit) { - OpBuilder BuilderZ(getNewFromOriginal(inst)); - for (auto res : inst->getResults()) { - if (!isConstantValue(res)) { - mlir::Type antiTy = getShadowType(res.getType()); - auto anti = - BuilderZ.create(res.getLoc(), antiTy); - invertedPointers.map(res, anti); - } - } - return; - } - /* - - if (inst->getType()->isFPOrFPVectorTy()) - continue; //! op->getType()->isPointerTy() && - //! !op->getType()->isIntegerTy()) { - - if (!TR.query(inst)[{-1}].isPossiblePointer()) - continue; - - if (isa(inst)) { - IRBuilder<> BuilderZ(inst); - getForwardBuilder(BuilderZ); - Type *antiTy = getShadowType(inst->getType()); - PHINode *anti = - BuilderZ.CreatePHI(antiTy, 1, inst->getName() + "'il_phi"); - invertedPointers.insert(std::make_pair( - (const Value *)inst, InvertedPointerVH(this, anti))); - continue; - } - - if (!isa(inst)) { - continue; - } - - if (isa(inst)) { - continue; - } - - if (isConstantValue(inst)) { - continue; - } - - CallInst *op = cast(inst); - Function *called = op->getCalledFunction(); - IRBuilder<> BuilderZ(inst); - getForwardBuilder(BuilderZ); - Type *antiTy = getShadowType(inst->getType()); - - PHINode *anti = - BuilderZ.CreatePHI(antiTy, 1, op->getName() + "'ip_phi"); - invertedPointers.insert( - std::make_pair((const Value *)inst, InvertedPointerVH(this, anti))); + OpBuilder BuilderZ(getNewFromOriginal(inst)); + for (auto res : inst->getResults()) { + if (isConstantValue(res)) + continue; - if (called && isAllocationFunction(called->getName(), TLI)) { - anti->setName(op->getName() + "'mi"); + if (!(mode == DerivativeMode::ForwardMode || + mode == DerivativeMode::ForwardModeSplit || + cast(res.getType()).isMutable())) + continue; + mlir::Type antiTy = getShadowType(res.getType()); + auto anti = BuilderZ.create(res.getLoc(), antiTy); + invertedPointers.map(res, anti); } - */ }); } LogicalResult MGradientUtils::visitChild(Operation *op) { if (mode == DerivativeMode::ForwardMode) { - // In absence of a proper activity analysis, approximate it by treating any - // side effect-free operation producing constants as inactive. - // if (auto iface = dyn_cast(op)) { - if (llvm::all_of(op->getResults(), + if ((op->getBlock()->getTerminator() != op) && + llvm::all_of(op->getResults(), [this](Value v) { return isConstantValue(v); }) && /*iface.hasNoEffect()*/ activityAnalyzer->isConstantOperation(TR, op)) { return success(); @@ -318,5 +327,6 @@ LogicalResult MGradientUtils::visitChild(Operation *op) { return iface.createForwardModeTangent(builder, this); } } - return op->emitError() << "could not compute the adjoint for this operation"; + return op->emitError() << "could not compute the adjoint for this operation " + << *op; } diff --git a/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.h b/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.h index 9b9509b3ec22..32a7fe068d07 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.h +++ b/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.h @@ -38,6 +38,7 @@ class MGradientUtils { bool omp; unsigned width; + SmallVector RetDiffeTypes; ArrayRef ArgDiffeTypes; mlir::Value getNewFromOriginal(const mlir::Value originst) const; @@ -54,36 +55,64 @@ class MGradientUtils { std::map &originalToNewFnOps_, DerivativeMode mode, unsigned width, bool omp); void erase(Operation *op) { op->erase(); } + void replaceOrigOpWith(Operation *op, ValueRange vals) { + for (auto &&[res, rep] : llvm::zip(op->getResults(), vals)) { + originalToNewFn.map(res, rep); + } + auto newOp = getNewFromOriginal(op); + newOp->replaceAllUsesWith(vals); + originalToNewFnOps.erase(op); + } void eraseIfUnused(Operation *op, bool erase = true, bool check = true) { // TODO } bool isConstantInstruction(mlir::Operation *v) const; bool isConstantValue(mlir::Value v) const; mlir::Value invertPointerM(mlir::Value v, OpBuilder &Builder2); - void setDiffe(mlir::Value val, mlir::Value toset, OpBuilder &BuilderM); void forceAugmentedReturns(); Operation *cloneWithNewOperands(OpBuilder &B, Operation *op); LogicalResult visitChild(Operation *op); + + void setDiffe(mlir::Value origv, mlir::Value newv, mlir::OpBuilder &builder); + + mlir::Type getShadowType(mlir::Type T) { + auto iface = cast(T); + return iface.getShadowType(width); + } }; class MDiffeGradientUtils : public MGradientUtils { +protected: + IRMapping differentials; + + Block *initializationBlock; + public: + mlir::Value getDifferential(mlir::Value origv); + + void setDiffe(mlir::Value origv, mlir::Value newv, mlir::OpBuilder &builder); + + void zeroDiffe(mlir::Value origv, mlir::OpBuilder &builder); + + mlir::Value diffe(mlir::Value origv, mlir::OpBuilder &builder); + MDiffeGradientUtils(MEnzymeLogic &Logic, FunctionOpInterface newFunc_, FunctionOpInterface oldFunc_, MTypeAnalysis &TA, MTypeResults TR, IRMapping &invertedPointers_, const SmallPtrSetImpl &constantvalues_, - const SmallPtrSetImpl &returnvals_, + const SmallPtrSetImpl &activevals_, DIFFE_TYPE ActiveReturn, ArrayRef constant_values, IRMapping &origToNew_, std::map &origToNewOps_, DerivativeMode mode, unsigned width, bool omp) : MGradientUtils(Logic, newFunc_, oldFunc_, TA, TR, invertedPointers_, - constantvalues_, returnvals_, ActiveReturn, + constantvalues_, activevals_, ActiveReturn, constant_values, origToNew_, origToNewOps_, mode, width, - omp) {} + omp), + initializationBlock(&*(newFunc.getFunctionBody().begin())) {} // Technically diffe constructor static MDiffeGradientUtils * diff --git a/enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.cpp b/enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.cpp index 02a3d3d956d6..b57fbe68b594 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.cpp +++ b/enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.cpp @@ -35,20 +35,11 @@ mlir::enzyme::MGradientUtilsReverse::MGradientUtilsReverse( const SmallPtrSetImpl &activevals_, DIFFE_TYPE ReturnActivity, ArrayRef ArgDiffeTypes_, IRMapping &originalToNewFn_, std::map &originalToNewFnOps_, - DerivativeMode mode_, unsigned width, SymbolTableCollection &symbolTable_) - : newFunc(newFunc_), oldFunc(oldFunc_), Logic(Logic), mode(mode_), - originalToNewFn(originalToNewFn_), - originalToNewFnOps(originalToNewFnOps_), TA(TA_), width(width), - ArgDiffeTypes(ArgDiffeTypes_), symbolTable(symbolTable_) { - - initInitializationBlock(invertedPointers_, ArgDiffeTypes_); -} - -// for(auto x : v.getUsers()){x->dump();} DEBUG - -bool MGradientUtilsReverse::onlyUsedInParentBlock(Value v) { - return !v.isUsedOutsideOfBlock(v.getParentBlock()); -} + DerivativeMode mode_, unsigned width) + : MDiffeGradientUtils(Logic, newFunc_, oldFunc_, TA_, /*MTypeResults*/ {}, + invertedPointers_, constantvalues_, activevals_, + ReturnActivity, ArgDiffeTypes_, originalToNewFn_, + originalToNewFnOps_, mode_, width, /*omp*/ false) {} Type mlir::enzyme::MGradientUtilsReverse::getIndexCacheType() { Type indexType = getIndexType(); @@ -108,71 +99,6 @@ Value MGradientUtilsReverse::popCache(Value cache, OpBuilder &builder) { cache); } -// Gradient -Type mlir::enzyme::MGradientUtilsReverse::getGradientType(Value v) { - Type valueType = v.getType(); - return GradientType::get(v.getContext(), valueType); -} - -Value mlir::enzyme::MGradientUtilsReverse::insertInitGradient( - mlir::Value v, OpBuilder &builder) { - Type gradientType = getGradientType(v); - OpBuilder initBuilder(initializationBlock, initializationBlock->begin()); - Value gradient = initBuilder.create(v.getLoc(), gradientType); - return gradient; -} - -// Shadow Gradient -Type mlir::enzyme::MGradientUtilsReverse::getShadowedGradientType(Value v) { - Type valueType = v.getType(); - return ShadowedGradientType::get(v.getContext(), valueType); -} - -Value mlir::enzyme::MGradientUtilsReverse::insertInitShadowedGradient( - mlir::Value v, OpBuilder &builder) { - Type gradientType = getShadowedGradientType(v); - OpBuilder initBuilder(initializationBlock, initializationBlock->begin()); - Value gradient = initBuilder.create(v.getLoc(), gradientType); - return gradient; -} - -Value mlir::enzyme::MGradientUtilsReverse::getNewFromOriginal( - const mlir::Value originst) const { - if (!originalToNewFn.contains(originst)) { - llvm::errs() << oldFunc << "\n"; - llvm::errs() << newFunc << "\n"; - llvm::errs() << originst << "\n"; - llvm_unreachable("Could not get new val from original"); - } - return originalToNewFn.lookupOrNull(originst); -} - -Block *mlir::enzyme::MGradientUtilsReverse::getNewFromOriginal( - mlir::Block *originst) const { - if (!originalToNewFn.contains(originst)) { - llvm::errs() << oldFunc << "\n"; - llvm::errs() << newFunc << "\n"; - llvm::errs() << originst << "\n"; - llvm_unreachable("Could not get new blk from original"); - } - return originalToNewFn.lookupOrNull(originst); -} - -Operation *mlir::enzyme::MGradientUtilsReverse::getNewFromOriginal( - Operation *originst) const { - auto found = originalToNewFnOps.find(originst); - if (found == originalToNewFnOps.end()) { - llvm::errs() << oldFunc << "\n"; - llvm::errs() << newFunc << "\n"; - for (auto &pair : originalToNewFnOps) { - llvm::errs() << " map[" << pair.first << "] = " << pair.second << "\n"; - } - llvm::errs() << originst << " - " << *originst << "\n"; - llvm_unreachable("Could not get new op from original"); - } - return found->second; -} - Operation * mlir::enzyme::MGradientUtilsReverse::cloneWithNewOperands(OpBuilder &B, Operation *op) { @@ -182,223 +108,25 @@ mlir::enzyme::MGradientUtilsReverse::cloneWithNewOperands(OpBuilder &B, return B.clone(*op, map); } -bool mlir::enzyme::MGradientUtilsReverse::isConstantInstruction( - Operation *op) const { - return false; -} - -bool mlir::enzyme::MGradientUtilsReverse::isConstantValue(Value v) const { - if (isa(v.getType())) - return true; - if (isa(v.getType())) - return true; - - if (matchPattern(v, m_Constant())) - return true; - - // TODO - return false; -} - -bool mlir::enzyme::MGradientUtilsReverse::requiresShadow(Type t) { - if (auto iface = dyn_cast(t)) { - return iface.requiresShadow(); - } - return false; -} - void mlir::enzyme::MGradientUtilsReverse::addToDiffe(Value oldGradient, Value addedGradient, OpBuilder &builder) { - // TODO - Value gradient = addedGradient; - if (hasInvertPointer(oldGradient)) { - Value operandGradient = invertPointerM(oldGradient, builder); - auto iface = cast(addedGradient.getType()); - gradient = iface.createAddOp(builder, oldGradient.getLoc(), operandGradient, + assert(!isConstantValue(oldGradient)); + Value operandGradient = diffe(oldGradient, builder); + auto iface = cast(addedGradient.getType()); + auto added = iface.createAddOp(builder, oldGradient.getLoc(), operandGradient, addedGradient); - } - mapInvertPointer(oldGradient, gradient, builder); -} - -Value mlir::enzyme::MGradientUtilsReverse::diffe(Value v, OpBuilder &builder) { - return invertPointerM(v, builder); -} - -/* -The value v must have an invert pointer -*/ -Value mlir::enzyme::MGradientUtilsReverse::invertPointerM(Value v, - OpBuilder &builder) { - if (invertedPointersGlobal.contains(v)) { - Value gradient = invertedPointersGlobal.lookupOrNull(v); - Type type = gradient.getType(); - - if (GradientType gType = dyn_cast(type)) { - Value ret = builder.create(v.getLoc(), gType.getBasetype(), - gradient); - return ret; - } else { - llvm_unreachable("found invalid type"); - } - } else if (invertedPointersShadow.contains(v)) { - Value gradient = invertedPointersShadow.lookupOrNull(v); - Type type = gradient.getType(); - - if (ShadowedGradientType gType = - dyn_cast(type)) { - Value ret = builder.create(v.getLoc(), gType.getBasetype(), - gradient); - return ret; - } else { - llvm_unreachable("found invalid type"); - } - } - - llvm::errs() << " could not invert pointer v " << v << "\n"; - llvm_unreachable("could not invert pointer"); -} - -void mlir::enzyme::MGradientUtilsReverse::mapInvertPointer( - mlir::Value v, mlir::Value invertValue, OpBuilder &builder) { - if (!invertedPointersGlobal.contains(v)) { - Value g = insertInitGradient(v, builder); - invertedPointersGlobal.map(v, g); - } - Value gradient = invertedPointersGlobal.lookupOrNull(v); - builder.create(v.getLoc(), gradient, invertValue); -} - -Value mlir::enzyme::MGradientUtilsReverse::getShadowValue(mlir::Value v) { - return shadowValues.lookupOrNull(v); -} - -void mlir::enzyme::MGradientUtilsReverse::mapShadowValue(mlir::Value v, - mlir::Value shadow, - OpBuilder &builder) { - assert(!invertedPointersShadow.contains( - v)); // Shadow Values must only be mapped exactly once - - Value cache = insertInitShadowedGradient(v, builder); - invertedPointersShadow.map(v, cache); - - builder.create(v.getLoc(), cache, shadow); - - shadowValues.map(v, shadow); -} - -void mlir::enzyme::MGradientUtilsReverse::clearValue(mlir::Value v, - OpBuilder &builder) { - if (invertedPointersGlobal.contains(v)) { - if (!onlyUsedInParentBlock(v)) { // TODO is this necessary? - Value gradient = invertedPointersGlobal.lookupOrNull(v); - Type type = cast(gradient.getType()).getBasetype(); - if (auto iface = dyn_cast(type)) { - Value zero = iface.createNullValue(builder, v.getLoc()); - builder.create(v.getLoc(), gradient, zero); - } else { - llvm_unreachable( - "Type does not have an associated AutoDiffTypeInterface"); - } - } - } else if (invertedPointersShadow.contains(v)) { - Value gradient = invertedPointersShadow.lookupOrNull(v); - builder.create(v.getLoc(), gradient); - } -} - -bool mlir::enzyme::MGradientUtilsReverse::hasInvertPointer(mlir::Value v) { - return (invertedPointersGlobal.contains(v)) || - (invertedPointersShadow.contains(v)); -} - -void MGradientUtilsReverse::initInitializationBlock( - IRMapping invertedPointers_, ArrayRef argDiffeTypes) { - initializationBlock = &*(this->newFunc.getFunctionBody().begin()); - - OpBuilder initializationBuilder( - &*(this->newFunc.getFunctionBody().begin()), - this->newFunc.getFunctionBody().begin()->begin()); - - for (const auto &[val, diffe_type] : llvm::zip( - this->oldFunc.getFunctionBody().getArguments(), argDiffeTypes)) { - if (diffe_type != DIFFE_TYPE::OUT_DIFF) { - continue; - } - auto iface = dyn_cast(val.getType()); - if (!iface) { - llvm_unreachable( - "Type does not have an associated AutoDiffTypeInterface"); - } - Value zero = iface.createNullValue(initializationBuilder, val.getLoc()); - mapInvertPointer(val, zero, initializationBuilder); - } - for (auto const &x : invertedPointers_.getValueMap()) { - if (auto iface = dyn_cast(x.first.getType())) { - if (iface.requiresShadow()) { - mapShadowValue(x.first, x.second, - initializationBuilder); // This may create an unnecessary - // ShadowedGradient which could - // be avoidable TODO - } else { - mapInvertPointer(x.first, x.second, initializationBuilder); - } - } else { - llvm_unreachable("TODO not implemented"); - } - } + setDiffe(oldGradient, added, builder); } void MGradientUtilsReverse::createReverseModeBlocks(Region &oldFunc, - Region &newFunc, - bool isParentRegion) { + Region &newFunc) { for (auto it = oldFunc.getBlocks().rbegin(); it != oldFunc.getBlocks().rend(); ++it) { Block *block = &*it; Block *reverseBlock = new Block(); - - SmallVector> - reverseModeArguments; // Argument, Assigned value (2. is technically not - // necessary but simplifies code a lot) - - // Add reverse mode Arguments to Block - Operation *term = block->getTerminator(); - mlir::BranchOpInterface brOp = dyn_cast(term); - bool returnLike = term->hasTrait(); - if (brOp) { - for (int i = 0; i < (int)term->getNumSuccessors(); i++) { - SuccessorOperands sOps = brOp.getSuccessorOperands(i); - Block *successorBlock = term->getSuccessor(i); - - assert(successorBlock->getNumArguments() == sOps.size()); - for (int j = 0; j < (int)sOps.size(); j++) { - // Check if the argument needs a gradient - if (auto iface = successorBlock->getArgument(j) - .getType() - .dyn_cast()) { - reverseModeArguments.push_back(std::pair( - successorBlock->getArgument(j), sOps[j])); - } - } - } - for (auto it : reverseModeArguments) { - reverseBlock->addArgument(it.second.getType(), it.second.getLoc()); - } - - mapBlockArguments[block] = reverseModeArguments; - } else if (returnLike) { - if (!isParentRegion) { - for (OpOperand &operand : term->getOpOperands()) { - Value val = operand.get(); - if (auto iface = val.getType().dyn_cast()) { - reverseBlock->addArgument(val.getType(), val.getLoc()); - } - } - } - } - - mapReverseModeBlocks.map(block, reverseBlock); newFunc.getBlocks().insert(newFunc.end(), reverseBlock); + mapReverseModeBlocks.map(block, reverseBlock); } } @@ -406,8 +134,7 @@ MGradientUtilsReverse *MGradientUtilsReverse::CreateFromClone( MEnzymeLogic &Logic, DerivativeMode mode_, unsigned width, FunctionOpInterface todiff, MTypeAnalysis &TA, MFnTypeInfo &oldTypeInfo, DIFFE_TYPE retType, bool diffeReturnArg, ArrayRef constant_args, - ReturnType returnValue, mlir::Type additionalArg, - SymbolTableCollection &symbolTable_) { + ReturnType returnValue, mlir::Type additionalArg) { std::string prefix; switch (mode_) { @@ -439,8 +166,8 @@ MGradientUtilsReverse *MGradientUtilsReverse::CreateFromClone( prefix + todiff.getName(), originalToNew, originalToNewOps, diffeReturnArg, additionalArg); - return new MGradientUtilsReverse( - Logic, newFunc, todiff, TA, invertedPointers, constant_values, - nonconstant_values, retType, constant_args, originalToNew, - originalToNewOps, mode_, width, symbolTable_); + return new MGradientUtilsReverse(Logic, newFunc, todiff, TA, invertedPointers, + constant_values, nonconstant_values, retType, + constant_args, originalToNew, + originalToNewOps, mode_, width); } diff --git a/enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.h b/enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.h index 474badb034c9..96e899939538 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.h +++ b/enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.h @@ -15,10 +15,12 @@ #include +#include "GradientUtils.h" + namespace mlir { namespace enzyme { -class MGradientUtilsReverse { +class MGradientUtilsReverse : public MDiffeGradientUtils { public: MGradientUtilsReverse(MEnzymeLogic &Logic, FunctionOpInterface newFunc_, FunctionOpInterface oldFunc_, MTypeAnalysis &TA_, @@ -29,60 +31,13 @@ class MGradientUtilsReverse { ArrayRef ArgDiffeTypes_, IRMapping &originalToNewFn_, std::map &originalToNewFnOps_, - DerivativeMode mode_, unsigned width, - SymbolTableCollection &symbolTable_); - - // From CacheUtility - FunctionOpInterface newFunc; - FunctionOpInterface oldFunc; - - MEnzymeLogic &Logic; - bool AtomicAdd; - DerivativeMode mode; - IRMapping invertedPointersGlobal; - IRMapping invertedPointersShadow; - IRMapping shadowValues; - Block *initializationBlock; + DerivativeMode mode_, unsigned width); IRMapping mapReverseModeBlocks; - DenseMap>> mapBlockArguments; - - IRMapping originalToNewFn; - std::map originalToNewFnOps; - - MTypeAnalysis &TA; - - unsigned width; - ArrayRef ArgDiffeTypes; - - SymbolTableCollection &symbolTable; - - mlir::Value getNewFromOriginal(const mlir::Value originst) const; - mlir::Block *getNewFromOriginal(mlir::Block *originst) const; - Operation *getNewFromOriginal(Operation *originst) const; - - void erase(Operation *op) { op->erase(); } - void eraseIfUnused(Operation *op, bool erase = true, bool check = true) { - // TODO - } - bool isConstantValue(mlir::Value v) const; - bool isConstantInstruction(mlir::Operation *v) const; - bool hasInvertPointer(mlir::Value v); - mlir::Value invertPointerM(mlir::Value v, OpBuilder &builder); - mlir::Value diffe(mlir::Value v, OpBuilder &builder); void addToDiffe(mlir::Value oldGradient, mlir::Value addedGradient, OpBuilder &builder); - void mapInvertPointer(mlir::Value v, mlir::Value invertValue, - OpBuilder &builder); - - mlir::Value getShadowValue(mlir::Value v); - void mapShadowValue(mlir::Value v, mlir::Value invertValue, - OpBuilder &builder); - void clearValue(mlir::Value v, OpBuilder &builder); - - void setDiffe(mlir::Value val, mlir::Value toset, OpBuilder &BuilderM); Type getIndexType(); Value insertInit(Type t); @@ -98,35 +53,18 @@ class MGradientUtilsReverse { Type getIndexCacheType(); Value initAndPushCache(Value v, OpBuilder &builder); - // Gradient - Type getGradientType(Value t); - Value insertInitGradient(mlir::Value v, OpBuilder &builder); - - // ShadowedGradient - Type getShadowedGradientType(Value t); - Value insertInitShadowedGradient(mlir::Value v, OpBuilder &builder); - - bool requiresShadow(Type t); - - void initInitializationBlock(IRMapping invertedPointers_, - ArrayRef argDiffeTypes); - - bool onlyUsedInParentBlock(Value v); - Operation *cloneWithNewOperands(OpBuilder &B, Operation *op); Value popCache(Value cache, OpBuilder &builder); - void createReverseModeBlocks(Region &oldFunc, Region &newFunc, - bool isParentRegion = false); + void createReverseModeBlocks(Region &oldFunc, Region &newFunc); static MGradientUtilsReverse * CreateFromClone(MEnzymeLogic &Logic, DerivativeMode mode_, unsigned width, FunctionOpInterface todiff, MTypeAnalysis &TA, MFnTypeInfo &oldTypeInfo, DIFFE_TYPE retType, bool diffeReturnArg, ArrayRef constant_args, - ReturnType returnValue, mlir::Type additionalArg, - SymbolTableCollection &symbolTable_); + ReturnType returnValue, mlir::Type additionalArg); }; } // namespace enzyme diff --git a/enzyme/Enzyme/MLIR/Passes/AddToOpToIndexAndLoad.cpp b/enzyme/Enzyme/MLIR/Passes/AddToOpToIndexAndLoad.cpp index 448dd67b8dd3..648509f58e52 100644 --- a/enzyme/Enzyme/MLIR/Passes/AddToOpToIndexAndLoad.cpp +++ b/enzyme/Enzyme/MLIR/Passes/AddToOpToIndexAndLoad.cpp @@ -11,18 +11,12 @@ // procedure to the MemRef dialect. //===----------------------------------------------------------------------===// -#include "Dialect/Dialect.h" #include "Dialect/Ops.h" #include "PassDetails.h" #include "Passes/Passes.h" -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Transforms/DialectConversion.h" -#include "mlir/Rewrite/PatternApplicator.h" #include "llvm/Support/raw_ostream.h" #include "Interfaces/AutoDiffTypeInterface.h" @@ -51,9 +45,6 @@ SmallVector applyAffineMap(AffineMap aMap, SmallVector indices, struct AddToOpToIndexAndLoadPass : public enzyme::AddToOpToIndexAndLoadPassBase { void runOnOperation() override { - MLIRContext *context = &getContext(); - ConversionPatternRewriter rewriter(context); - getOperation()->walk([&](Operation *op) { auto loc = op->getLoc(); auto enzymeAdjoint = dyn_cast(op); @@ -94,7 +85,7 @@ struct AddToOpToIndexAndLoadPass // auto load = cacheBuilder.create(loc, inputs[i], map[i], // indices); auto store = cacheBuilder.create(loc, load, // inputs[i], map[i], indices); - ValueRange mapAppliedIndices = + SmallVector mapAppliedIndices = applyAffineMap(map[num_ins + i], indices, cacheBuilder, loc); auto load = cacheBuilder.create(loc, outs[i], mapAppliedIndices); @@ -105,7 +96,7 @@ struct AddToOpToIndexAndLoadPass } for (int i = 0; i < retargs.size(); i++) { - ValueRange mapAppliedIndices = + SmallVector mapAppliedIndices = applyAffineMap(map[num_ins + i], indices, cacheBuilder, loc); auto load = cacheBuilder.create(loc, outs[i], mapAppliedIndices); diff --git a/enzyme/Enzyme/MLIR/Passes/AddToOpToSplit.cpp b/enzyme/Enzyme/MLIR/Passes/AddToOpToSplit.cpp index de2ebeba376d..01bcd6683a96 100644 --- a/enzyme/Enzyme/MLIR/Passes/AddToOpToSplit.cpp +++ b/enzyme/Enzyme/MLIR/Passes/AddToOpToSplit.cpp @@ -105,9 +105,6 @@ void processGenericDuplication(Operation *op, OpBuilder &builder, Location loc, struct AddToOpToSplitPass : public enzyme::AddToOpToSplitPassBase { void runOnOperation() override { - MLIRContext *context = &getContext(); - ConversionPatternRewriter rewriter(context); - getOperation()->walk([&](Operation *op) { auto enzymeAdjoint = dyn_cast(op); auto loc = op->getLoc(); diff --git a/enzyme/Enzyme/MLIR/Passes/CMakeLists.txt b/enzyme/Enzyme/MLIR/Passes/CMakeLists.txt index 03c91683adff..90e9506d2bb1 100644 --- a/enzyme/Enzyme/MLIR/Passes/CMakeLists.txt +++ b/enzyme/Enzyme/MLIR/Passes/CMakeLists.txt @@ -5,10 +5,11 @@ add_mlir_doc(Passes EnzymePasses ./ -gen-pass-doc) add_mlir_dialect_library(MLIREnzymeTransforms EnzymeMLIRPass.cpp + EnzymeWrapPass.cpp PrintActivityAnalysis.cpp PrintAliasAnalysis.cpp EnzymeToMemRef.cpp - ShadowedGradientToCache.cpp + SimplifyMath.cpp AddToOpToIndexAndLoad.cpp AddToOpToSplit.cpp RemoveUnusedEnzymeOps.cpp diff --git a/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp b/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp index efe4b7d53f76..b7d33b6faedc 100644 --- a/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp +++ b/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp @@ -11,16 +11,13 @@ //===----------------------------------------------------------------------===// #include "Dialect/Ops.h" -#include "Interfaces/GradientUtils.h" #include "Interfaces/GradientUtilsReverse.h" #include "PassDetails.h" #include "Passes/Passes.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/IRMapping.h" -#include "mlir/Interfaces/ControlFlowInterfaces.h" -#include "mlir/Interfaces/FunctionInterfaces.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/Builders.h" +#include "mlir/Interfaces/FunctionInterfaces.h" #define DEBUG_TYPE "enzyme" @@ -34,8 +31,19 @@ struct DifferentiatePass : public DifferentiatePassBase { void runOnOperation() override; + static DIFFE_TYPE mode_from_fn(FunctionOpInterface fn, DerivativeMode mode) { + DIFFE_TYPE retType = DIFFE_TYPE::CONSTANT; + if (fn.getNumResults() != 0) { + if (mode == DerivativeMode::ReverseModeCombined) + retType = DIFFE_TYPE::OUT_DIFF; + else + retType = DIFFE_TYPE::DUP_ARG; + } + return retType; + } + template - void HandleAutoDiff(SymbolTableCollection &symbolTable, T CI) { + LogicalResult HandleAutoDiff(SymbolTableCollection &symbolTable, T CI) { std::vector constants; SmallVector args; @@ -63,12 +71,11 @@ struct DifferentiatePass : public DifferentiatePassBase { auto *symbolOp = symbolTable.lookupNearestSymbolFrom(CI, CI.getFnAttr()); auto fn = cast(symbolOp); - DIFFE_TYPE retType = - fn.getNumResults() == 0 ? DIFFE_TYPE::CONSTANT : DIFFE_TYPE::DUP_ARG; + auto mode = DerivativeMode::ForwardMode; + DIFFE_TYPE retType = mode_from_fn(fn, mode); MTypeAnalysis TA; auto type_args = TA.getAnalyzedTypeInfo(fn); - auto mode = DerivativeMode::ForwardMode; bool freeMemory = true; size_t width = 1; @@ -83,16 +90,20 @@ struct DifferentiatePass : public DifferentiatePassBase { /*should return*/ false, mode, freeMemory, width, /*addedType*/ nullptr, type_args, volatile_args, /*augmented*/ nullptr); + if (!newFunc) + return failure(); OpBuilder builder(CI); auto dCI = builder.create(CI.getLoc(), newFunc.getName(), newFunc.getResultTypes(), args); CI.replaceAllUsesWith(dCI); CI->erase(); + return success(); } template - void HandleAutoDiffReverse(SymbolTableCollection &symbolTable, T CI) { + LogicalResult HandleAutoDiffReverse(SymbolTableCollection &symbolTable, + T CI) { std::vector constants; SmallVector args; @@ -117,19 +128,18 @@ struct DifferentiatePass : public DifferentiatePassBase { truei++; } - // Add the return gradient - mlir::Value res = CI.getInputs()[CI.getInputs().size() - 1]; - args.push_back(res); - auto *symbolOp = symbolTable.lookupNearestSymbolFrom(CI, CI.getFnAttr()); auto fn = cast(symbolOp); - DIFFE_TYPE retType = - fn.getNumResults() == 0 ? DIFFE_TYPE::CONSTANT : DIFFE_TYPE::DUP_ARG; + auto mode = DerivativeMode::ReverseModeCombined; + DIFFE_TYPE retType = mode_from_fn(fn, mode); + + // Add the return gradient + mlir::Value res = CI.getInputs()[CI.getInputs().size() - 1]; + args.push_back(res); MTypeAnalysis TA; auto type_args = TA.getAnalyzedTypeInfo(fn); - auto mode = DerivativeMode::ReverseModeGradient; bool freeMemory = true; size_t width = 1; @@ -143,13 +153,16 @@ struct DifferentiatePass : public DifferentiatePassBase { fn, retType, constants, TA, /*should return*/ false, mode, freeMemory, width, /*addedType*/ nullptr, type_args, volatile_args, - /*augmented*/ nullptr, symbolTable); + /*augmented*/ nullptr); + if (!newFunc) + return failure(); OpBuilder builder(CI); auto dCI = builder.create(CI.getLoc(), newFunc.getName(), newFunc.getResultTypes(), args); CI.replaceAllUsesWith(dCI); CI->erase(); + return success(); } void lowerEnzymeCalls(SymbolTableCollection &symbolTable, @@ -167,7 +180,11 @@ struct DifferentiatePass : public DifferentiatePassBase { for (auto T : toLower) { if (auto F = dyn_cast(T)) { - HandleAutoDiff(symbolTable, F); + auto res = HandleAutoDiff(symbolTable, F); + if (!res.succeeded()) { + signalPassFailure(); + return; + } } else { llvm_unreachable("Illegal type"); } @@ -187,7 +204,11 @@ struct DifferentiatePass : public DifferentiatePassBase { for (auto T : toLower) { if (auto F = dyn_cast(T)) { - HandleAutoDiffReverse(symbolTable, F); + auto res = HandleAutoDiffReverse(symbolTable, F); + if (!res.succeeded()) { + signalPassFailure(); + return; + } } else { llvm_unreachable("Illegal type"); } @@ -201,19 +222,14 @@ struct DifferentiatePass : public DifferentiatePassBase { namespace mlir { namespace enzyme { std::unique_ptr createDifferentiatePass() { - new DifferentiatePass(); return std::make_unique(); } } // namespace enzyme } // namespace mlir -#include "mlir/IR/PatternMatch.h" -#include "mlir/Transforms/DialectConversion.h" - void DifferentiatePass::runOnOperation() { SymbolTableCollection symbolTable; symbolTable.getSymbolTable(getOperation()); - ConversionPatternRewriter B(getOperation()->getContext()); getOperation()->walk( [&](FunctionOpInterface op) { lowerEnzymeCalls(symbolTable, op); }); } diff --git a/enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp b/enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp new file mode 100644 index 000000000000..b48705c220d1 --- /dev/null +++ b/enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp @@ -0,0 +1,134 @@ +//===- EnzymeWrapPass.cpp - Replace calls with their derivatives ------------ // +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a pass to create wrapper functions which differentiate +// ops. +//===----------------------------------------------------------------------===// + +#include "Dialect/Ops.h" +#include "Interfaces/GradientUtils.h" +#include "Interfaces/GradientUtilsReverse.h" +#include "PassDetails.h" +#include "Passes/Passes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/Interfaces/ControlFlowInterfaces.h" +#include "mlir/Interfaces/FunctionInterfaces.h" + +#include "mlir/Dialect/Func/IR/FuncOps.h" + +#define DEBUG_TYPE "enzyme" + +using namespace mlir; +using namespace mlir::enzyme; +using namespace enzyme; + +namespace { +struct DifferentiateWrapperPass + : public DifferentiateWrapperPassBase { + + void runOnOperation() override { + MEnzymeLogic Logic; + SymbolTableCollection symbolTable; + symbolTable.getSymbolTable(getOperation()); + + Operation *symbolOp = nullptr; + if (infn != "") + symbolOp = symbolTable.lookupSymbolIn( + getOperation(), StringAttr::get(getOperation()->getContext(), infn)); + else { + for (auto &op : getOperation()->getRegion(0).front()) { + auto fn = dyn_cast(symbolOp); + if (!fn) + continue; + assert(symbolOp == nullptr); + symbolOp = &op; + } + } + auto fn = cast(symbolOp); + SmallVector split; + StringRef(argTys.getValue().data(), argTys.getValue().size()) + .split(split, ','); + std::vector constants; + for (auto &str : split) { + if (str == "enzyme_dup") + constants.push_back(DIFFE_TYPE::DUP_ARG); + else if (str == "enzyme_const") + constants.push_back(DIFFE_TYPE::CONSTANT); + else if (str == "enzyme_dupnoneed") + constants.push_back(DIFFE_TYPE::DUP_NONEED); + else if (str == "enzyme_out") + constants.push_back(DIFFE_TYPE::OUT_DIFF); + else { + llvm::errs() << "unknown argument activity to parse, found: '" << str + << "'\n"; + assert(0 && " unknown constant"); + } + } + + if (constants.size() != fn.getFunctionBody().front().getNumArguments()) { + fn->emitError() + << "Incorrect number of arg activity states for function, found " + << split; + return; + } + + DIFFE_TYPE retType = retTy.getValue(); + MTypeAnalysis TA; + auto type_args = TA.getAnalyzedTypeInfo(fn); + + bool freeMemory = true; + size_t width = 1; + + std::vector volatile_args; + for (auto &a : fn.getFunctionBody().getArguments()) { + (void)a; + volatile_args.push_back(!(mode == DerivativeMode::ReverseModeCombined)); + } + + FunctionOpInterface newFunc; + if (mode == DerivativeMode::ForwardMode) { + newFunc = Logic.CreateForwardDiff( + fn, retType, constants, TA, + /*should return*/ (retType == DIFFE_TYPE::DUP_ARG), mode, freeMemory, + width, + /*addedType*/ nullptr, type_args, volatile_args, + /*augmented*/ nullptr); + } else { + newFunc = Logic.CreateReverseDiff( + fn, retType, constants, TA, + /*should return*/ false, mode, freeMemory, width, + /*addedType*/ nullptr, type_args, volatile_args, + /*augmented*/ nullptr); + } + if (!newFunc) { + signalPassFailure(); + return; + } + if (outfn == "") { + fn->erase(); + SymbolTable::setSymbolVisibility(newFunc, + SymbolTable::Visibility::Public); + SymbolTable::setSymbolName(cast(newFunc), + (std::string)infn); + } else { + SymbolTable::setSymbolName(cast(newFunc), + (std::string)outfn); + } + } +}; + +} // end anonymous namespace + +namespace mlir { +namespace enzyme { +std::unique_ptr createDifferentiateWrapperPass() { + return std::make_unique(); +} +} // namespace enzyme +} // namespace mlir diff --git a/enzyme/Enzyme/MLIR/Passes/Passes.h b/enzyme/Enzyme/MLIR/Passes/Passes.h index d5e821bac643..c4eff488c332 100644 --- a/enzyme/Enzyme/MLIR/Passes/Passes.h +++ b/enzyme/Enzyme/MLIR/Passes/Passes.h @@ -8,9 +8,13 @@ #ifndef ENZYME_PASSES_H #define ENZYME_PASSES_H +#include "../../Utils.h" #include "mlir/Conversion/LLVMCommon/LoweringOptions.h" #include "mlir/Pass/Pass.h" #include + +#include "Dialect/Dialect.h" + namespace mlir { class PatternRewriter; class RewritePatternSet; @@ -18,13 +22,15 @@ class DominanceInfo; namespace enzyme { std::unique_ptr createDifferentiatePass(); +std::unique_ptr createDifferentiateWrapperPass(); + std::unique_ptr createPrintActivityAnalysisPass(); std::unique_ptr createPrintAliasAnalysisPass(); std::unique_ptr createEnzymeToMemRefPass(); -std::unique_ptr createShadowedGradientToCachePass(); +std::unique_ptr createMathematicSimplificationPass(); std::unique_ptr createAddToOpToIndexAndLoadPass(); @@ -59,12 +65,15 @@ class MemRefDialect; namespace func { class FuncDialect; -} +} // end namespace func +namespace affine { class AffineDialect; +} // end namespace affine + namespace LLVM { class LLVMDialect; -} +} // end namespace LLVM #define GEN_PASS_REGISTRATION #include "Passes/Passes.h.inc" diff --git a/enzyme/Enzyme/MLIR/Passes/Passes.td b/enzyme/Enzyme/MLIR/Passes/Passes.td index 1289ddca9bf4..c1617eaf4af0 100644 --- a/enzyme/Enzyme/MLIR/Passes/Passes.td +++ b/enzyme/Enzyme/MLIR/Passes/Passes.td @@ -19,6 +19,64 @@ def DifferentiatePass : Pass<"enzyme"> { let constructor = "mlir::enzyme::createDifferentiatePass()"; } +def DifferentiateWrapperPass : Pass<"enzyme-wrap"> { + let summary = "Add wrapper function to be differentiated"; + let dependentDialects = [ + "cf::ControlFlowDialect", + "enzyme::EnzymeDialect" + ]; + let constructor = "mlir::enzyme::createDifferentiateWrapperPass()"; + let options = [ + Option< + /*C++ variable name=*/"infn", + /*CLI argument=*/"infn", + /*type=*/"std::string", + /*default=*/"", + /*description=*/"Input function to differentiate" + >, + Option< + /*C++ variable name=*/"outfn", + /*CLI argument=*/"outfn", + /*type=*/"std::string", + /*default=*/"", + /*description=*/"Output function to differentiate" + >, + Option< + /*C++ variable name=*/"mode", + /*CLI argument=*/"mode", + /*type=*/"DerivativeMode", + /*default=*/"DerivativeMode::ForwardMode", + /*description=*/"mode to differentiate", +[{::llvm::cl::values( + clEnumValN(DerivativeMode::ForwardMode, "ForwardMode", "ForwardMode (default)"), + clEnumValN(DerivativeMode::ReverseModeCombined, "ReverseModeCombined", "Combined ReverseMode"), + clEnumValN(DerivativeMode::ReverseModePrimal, "ReverseModePrimal", "Forward Pass of ReverseMode"), + clEnumValN(DerivativeMode::ReverseModeGradient, "ReverseModeGradient", "Backward Pass of ReverseMode") + )}] + >, + Option< + /*C++ variable name=*/"retTy", + /*CLI argument=*/"retTy", + /*type=*/"DIFFE_TYPE", + /*default=*/"DIFFE_TYPE::DUP_ARG", + /*description=*/"activity of the return", +[{::llvm::cl::values( + clEnumValN(DIFFE_TYPE::DUP_ARG, "enzyme_dup", "Duplicated (default)"), + clEnumValN(DIFFE_TYPE::OUT_DIFF, "enzyme_out", "Active"), + clEnumValN(DIFFE_TYPE::CONSTANT, "enzyme_const", "Constant"), + clEnumValN(DIFFE_TYPE::DUP_NONEED, "enzyme_dupnoneed", "Duplicated noneed") + )}] + >, + Option< + /*C++ variable name=*/"argTys", + /*CLI argument=*/"argTys", + /*type=*/"std::string", + /*default=*/"", + /*description=*/"The activity of the arguments" + >, + ]; +} + def PrintActivityAnalysisPass : Pass<"print-activity-analysis"> { let summary = "Print the results of activity analysis"; let constructor = "mlir::enzyme::createPrintActivityAnalysisPass()"; @@ -43,6 +101,13 @@ def PrintActivityAnalysisPass : Pass<"print-activity-analysis"> { /*default=*/"false", /*description=*/"Annotate every operation and value with its activity" >, + Option< + /*C++ variable name=*/"dataflow", + /*CLI argument=*/"dataflow", + /*type=*/"bool", + /*default=*/"true", + /*description=*/"Whether to use the new Dataflow activity analysis" + >, Option< /*C++ variable name=*/"inactiveArgs", /*CLI argument=*/"inactive-args", @@ -76,9 +141,9 @@ def EnzymeOpsToMemRefPass : Pass<"convert-enzyme-to-memref"> { let constructor = "mlir::enzyme::createEnzymeToMemRefPass()"; } -def ShadowedGradientToCachePass : Pass<"convert-enzyme-shadowed-gradient-to-cache"> { - let summary = "Convert Enzyme Shadowed Gradient to Cache Ops"; - let constructor = "mlir::enzyme::createShadowedGradientToCachePass()"; +def MathematicSimplificationPass : Pass<"enzyme-simplify-math"> { + let summary = "Simplify basic mathematical operations"; + let constructor = "mlir::enzyme::createMathematicSimplificationPass()"; } def AddToOpToIndexAndLoadPass : Pass<"add-to-op-to-index-and-load"> { diff --git a/enzyme/Enzyme/MLIR/Passes/PrintActivityAnalysis.cpp b/enzyme/Enzyme/MLIR/Passes/PrintActivityAnalysis.cpp index fb3ef000a0c1..ac88d0fc86fa 100644 --- a/enzyme/Enzyme/MLIR/Passes/PrintActivityAnalysis.cpp +++ b/enzyme/Enzyme/MLIR/Passes/PrintActivityAnalysis.cpp @@ -10,8 +10,10 @@ // analysis. // //===----------------------------------------------------------------------===// +#include "Analysis/ActivityAnalysis.h" #include "Analysis/DataFlowActivityAnalysis.h" #include "Dialect/Ops.h" +#include "Interfaces/EnzymeLogic.h" #include "Passes/PassDetails.h" #include "Passes/Passes.h" @@ -116,10 +118,79 @@ struct PrintActivityAnalysisPass } } + void runActivityAnalysis(bool dataflow, FunctionOpInterface callee, + ArrayRef argActivities, + ArrayRef resultActivities, + bool print, bool verbose, bool annotate) { + if (dataflow) { + enzyme::runDataFlowActivityAnalysis(callee, argActivities, + /*print=*/true, verbose, annotate); + } else { + + SmallPtrSet blocksNotForAnalysis; + + mlir::enzyme::MTypeResults TR; // TODO + SmallPtrSet constant_values; + SmallPtrSet activevals_; + for (auto &&[arg, act] : + llvm::zip(callee.getFunctionBody().getArguments(), argActivities)) { + if (act == enzyme::Activity::enzyme_const) + constant_values.insert(arg); + else + activevals_.insert(arg); + } + auto ReturnActivity = DIFFE_TYPE::CONSTANT; + for (auto act : resultActivities) + if (act != enzyme::Activity::enzyme_const) + ReturnActivity = DIFFE_TYPE::DUP_ARG; + + enzyme::ActivityAnalyzer activityAnalyzer( + blocksNotForAnalysis, constant_values, activevals_, ReturnActivity); + + callee.walk([&](Operation *op) { + + }); + MLIRContext *ctx = callee.getContext(); + callee.walk([&](Operation *op) { + if (print) + llvm::outs() << " Operation: " << *op << "\n"; + for (auto ® : op->getRegions()) { + for (auto &blk : reg.getBlocks()) { + for (auto &arg : blk.getArguments()) { + bool icv = activityAnalyzer.isConstantValue(TR, arg); + if (annotate) + op->setAttr("enzyme.arg_icv" + + std::to_string(arg.getArgNumber()), + BoolAttr::get(ctx, icv)); + if (print) + llvm::outs() << " arg: " << arg << " icv=" << icv << "\n"; + } + } + } + + bool ici = activityAnalyzer.isConstantOperation(TR, op); + if (annotate) + op->setAttr("enzyme.ici", BoolAttr::get(ctx, ici)); + if (print) + llvm::outs() << " op ici=" << ici << "\n"; + + for (auto res : op->getResults()) { + bool icv = activityAnalyzer.isConstantValue(TR, res); + if (annotate) + op->setAttr("enzyme.res_icv" + + std::to_string(res.getResultNumber()), + BoolAttr::get(ctx, icv)); + if (print) + llvm::outs() << " res: " << res << " icv=" << icv << "\n"; + } + }); + } + } + void runOnOperation() override { auto moduleOp = cast(getOperation()); - if (annotate) { + if (annotate && dataflow) { // Infer the activity attributes from the __enzyme_autodiff call Operation *autodiff_decl = moduleOp.lookupSymbol("__enzyme_autodiff"); if (!autodiff_decl) @@ -148,8 +219,8 @@ struct PrintActivityAnalysisPass // supplied annotation. First argument is the callee inferArgActivitiesFromEnzymeAutodiff(callee, autodiff_call, argActivities, resultActivities); - enzyme::runDataFlowActivityAnalysis(callee, argActivities, - /*print=*/true, verbose, annotate); + runActivityAnalysis(dataflow, callee, argActivities, resultActivities, + /*print=*/true, verbose, annotate); } return; } @@ -163,8 +234,8 @@ struct PrintActivityAnalysisPass resultActivities{callee.getNumResults()}; initializeArgAndResActivities(callee, argActivities, resultActivities); - enzyme::runDataFlowActivityAnalysis(callee, argActivities, - /*print=*/true, verbose, annotate); + runActivityAnalysis(dataflow, callee, argActivities, resultActivities, + /*print=*/true, verbose, annotate); }); return; } @@ -186,8 +257,8 @@ struct PrintActivityAnalysisPass resultActivities{callee.getNumResults()}; initializeArgAndResActivities(callee, argActivities, resultActivities); - enzyme::runDataFlowActivityAnalysis(callee, argActivities, - /*print=*/true, verbose, annotate); + runActivityAnalysis(dataflow, callee, argActivities, resultActivities, + /*print=*/true, verbose, annotate); } } }; diff --git a/enzyme/Enzyme/MLIR/Passes/PrintAliasAnalysis.cpp b/enzyme/Enzyme/MLIR/Passes/PrintAliasAnalysis.cpp index 995e0f600949..d342f84fec70 100644 --- a/enzyme/Enzyme/MLIR/Passes/PrintAliasAnalysis.cpp +++ b/enzyme/Enzyme/MLIR/Passes/PrintAliasAnalysis.cpp @@ -91,7 +91,7 @@ struct PrintAliasAnalysisPass continue; // TODO(zinenko): this has been overriding the argument... // Use an array attr instead (will break syntactic tests). - state->getAliasClassesObject().foreachClass( + (void)state->getAliasClassesObject().foreachClass( [&](DistinctAttr aliasClass, enzyme::AliasClassSet::State state) { if (state == enzyme::AliasClassSet::State::Undefined) funcOp.setArgAttr( diff --git a/enzyme/Enzyme/MLIR/Passes/RemoveUnusedEnzymeOps.cpp b/enzyme/Enzyme/MLIR/Passes/RemoveUnusedEnzymeOps.cpp index e93334164dd0..5ced93e86512 100644 --- a/enzyme/Enzyme/MLIR/Passes/RemoveUnusedEnzymeOps.cpp +++ b/enzyme/Enzyme/MLIR/Passes/RemoveUnusedEnzymeOps.cpp @@ -22,168 +22,284 @@ #include "mlir/Transforms/DialectConversion.h" #include "mlir/Rewrite/PatternApplicator.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/IR/Dominance.h" #include "llvm/Support/raw_ostream.h" using namespace mlir; using namespace enzyme; -using llvm::errs; namespace { -// TODO: Expand to region branches?? -bool reachable(Operation *a, Operation *b) { - Block *aBlock = a->getBlock(); - Block *bBlock = b->getBlock(); - if (aBlock == bBlock) { - if (a->isBeforeInBlock(b)) { - return true; - } - } +// Starting at the beginning of blk, is there a path that can execute +// check before end. +bool mayExecuteBefore(Block *blk, Operation *check, Operation *end) { + auto reg = blk->getParent(); + assert(reg->isAncestor(end->getParentRegion())); + DenseSet visitedBlocks; + SmallVector blocksToVisit; + for (auto succ : blk->getSuccessors()) { + blocksToVisit.push_back(succ); + } - blocksToVisit.push_back(aBlock); while (!blocksToVisit.empty()) { - Block *processedBlock = blocksToVisit[blocksToVisit.size() - 1]; - blocksToVisit.pop_back(); + Block *cur = blocksToVisit.pop_back_val(); + + if (visitedBlocks.contains(cur)) + continue; + + visitedBlocks.insert(cur); + + bool seenEnd = false; + for (auto &op : *cur) { + + // If we've seen the thing to check with, it may execute before + if (op.isAncestor(check)) { + // The sole exception to this is if they are in the same sub region, + // which is known to execute only once. TODO this later + /* + if (op.isAncestor(end)) { + + for (auto reg2 : op.getRegions()) { + + } + } + */ - for (Block *successor : processedBlock->getSuccessors()) { - if (!visitedBlocks.contains(successor)) { - visitedBlocks.insert(successor); - blocksToVisit.push_back(successor); + return true; + } - if (successor == bBlock) - return true; + // Otherwise if we've seen the end op, this path is over as the route we + // found here didn't first find a check. + if (op.isAncestor(end)) { + seenEnd = true; + break; } } + + if (seenEnd) + continue; + + // If we didn't find the end, try all successors + for (auto succ : cur->getSuccessors()) { + blocksToVisit.push_back(succ); + } } + return false; } -template -Operation *findNearestDominatingOpByUse(Operation *op, Value v) { +bool mayExecuteBetween(Operation *start, Operation *check, Operation *end) { + + for (auto op = start->getNextNode(); op != nullptr; op = op->getNextNode()) { + // This check op has been found after start in its block + if (op->isAncestor(check)) { + return true; + } + if (op->isAncestor(end)) { + return false; + } + } + + Block *blk = start->getBlock(); + + auto reg = blk->getParent(); + if (reg->isAncestor(end->getParentRegion())) { + return mayExecuteBefore(blk, check, end); + } + + // If the check is in the parent op, but the end is not, assume + // we may execute that parent op part before going to any later ops + if (reg->isAncestor(check->getParentRegion())) { + return true; + } + + return mayExecuteBetween(start->getParentOp(), check, end); +} + +// TODO this isn't necessarily correct. This is because there could be a +// non dominating use bewteen the dominating one and the op, causing +// correctness issues when not seen. In interim, be conservative and only +// succeed if these have the same parent block, and no other ops in path +template +T findNearestDominatingOpByUse(Operation *op, Value v) { DominanceInfo dInfo; + PostDominanceInfo pdInfo; - Operation *closestSetOp = nullptr; + SmallVector options; + SmallVector conflicts; for (Operation *userSet : v.getUsers()) { if (auto setOp = dyn_cast(userSet)) { - if (dInfo.dominates(userSet, op)) { - if (closestSetOp == nullptr) { - closestSetOp = userSet; - } else if (dInfo.dominates(closestSetOp, userSet)) { - closestSetOp = userSet; - } + options.push_back(setOp); + conflicts.push_back(setOp); + continue; + } + if (auto setOp = dyn_cast(userSet)) { + conflicts.push_back(setOp); + continue; + } + } + + for (auto opt : options) { + if (!dInfo.dominates(opt, op)) + continue; + bool conflict = false; + for (auto opt2 : conflicts) { + if (opt == opt2) + continue; + if (opt2 == op) + continue; + + if (!mayExecuteBetween(opt, opt2, op)) { + continue; } + + conflict = true; + } + if (!conflict) { + return opt; } } - return closestSetOp; + + return nullptr; } +struct PopSimplify : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(enzyme::PopOp pop, + PatternRewriter &rewriter) const final { + + auto init = pop.getCache().getDefiningOp(); + if (!init) + return failure(); + + SmallVector pops; + SmallVector pushes; + for (Operation *userSet : init.getResult().getUsers()) { + if (auto push = dyn_cast(userSet)) { + pushes.push_back(push); + continue; + } + if (auto pop = dyn_cast(userSet)) { + pops.push_back(pop); + continue; + } + return failure(); + } + + if (auto push = findNearestDominatingOpByUse( + pop, init)) { + // Do the block check to conservatively avoid multi execute push/pop + if (pop->getBlock() == push->getBlock()) { + rewriter.replaceOp(pop, push.getValue()); + rewriter.eraseOp(push); + return success(); + } + } + + return failure(); + } +}; + +struct GetSimplify : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(enzyme::GetOp get, + PatternRewriter &rewriter) const final { + + auto init = get.getGradient().getDefiningOp(); + if (!init) + return failure(); + + for (Operation *userSet : init.getResult().getUsers()) { + if (isa(userSet)) + continue; + if (isa(userSet)) + continue; + return failure(); + } + + if (auto set = findNearestDominatingOpByUse(get, init)) { + rewriter.replaceOp(get, set.getValue()); + return success(); + } + return failure(); + } +}; + +struct SetSimplify : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(enzyme::SetOp get, + PatternRewriter &rewriter) const final { + + auto init = get.getGradient().getDefiningOp(); + if (!init) + return failure(); + + for (Operation *userSet : init.getResult().getUsers()) { + if (isa(userSet)) + continue; + return failure(); + } + + rewriter.eraseOp(get); + return success(); + } +}; + +struct PushSimplify : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(enzyme::PushOp get, + PatternRewriter &rewriter) const final { + + auto init = get.getCache().getDefiningOp(); + if (!init) + return failure(); + + for (Operation *userSet : init.getResult().getUsers()) { + if (isa(userSet)) + continue; + return failure(); + } + + rewriter.eraseOp(get); + return success(); + } +}; + +struct InitSimplify : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(enzyme::InitOp get, + PatternRewriter &rewriter) const final { + + if (get.use_empty()) { + rewriter.eraseOp(get); + return success(); + } + return failure(); + } +}; + struct RemoveUnusedEnzymeOpsPass : public enzyme::RemoveUnusedEnzymeOpsPassBase { void runOnOperation() override { - getOperation()->walk([&](Operation *op) { - DominanceInfo dInfo; - if (auto initOp = dyn_cast(op)) { - Value v = initOp; - if (auto type = dyn_cast(initOp.getType())) { - bool replaceable = true; - for (Operation *userSet : v.getUsers()) { - if (auto setOp = dyn_cast(userSet)) { - for (Operation *userGet : v.getUsers()) { - if (auto getOp = dyn_cast(userGet)) { - // We can safely delete an enzyme.gradient op if each pair of - // enzyme.set and enzyme.get ops are either not reachable or - // are reachable and do not exist inside a loop - bool relatedButNotInLoop = - dInfo.dominates(userSet, userGet) && - !reachable(getOp, setOp); - bool unrelated = !reachable(setOp, getOp); - if (!(relatedButNotInLoop || unrelated)) { - replaceable = false; - } - } - } - } - } - if (replaceable) { - // Do replacing - for (Operation *userGet : v.getUsers()) { - if (auto getOp = dyn_cast(userGet)) { - Operation *closestSetOp = - findNearestDominatingOpByUse(userGet, v); - auto setOp = dyn_cast(closestSetOp); - getOp.replaceAllUsesWith(setOp.getValue()); - } - } - for (Operation *userGet : v.getUsers()) { - userGet->erase(); - } - op->erase(); - } - } else if (auto type = dyn_cast(initOp.getType())) { - bool replaceable = true; - for (Operation *userPush : v.getUsers()) { - if (auto pushOp = dyn_cast(userPush)) { - // There should only be exactly one push per pop - if (reachable(userPush, userPush)) { - replaceable = false; - } - int numAssociatedPops = 0; - for (Operation *user : v.getUsers()) { - if (auto popOp = dyn_cast(user)) { - if (reachable(userPush, user)) { - // Pops always need to be dominated by the push - if (dInfo.dominates(userPush, user)) { - numAssociatedPops++; - } else { - replaceable = false; - } - } - } - if (auto getOp = dyn_cast(user)) { - if (reachable(userPush, user)) { - // Gets always need to be dominated by the push - if (!dInfo.dominates(userPush, user)) { - replaceable = false; - } - } - } - } - // There should only be one pop per push - if (numAssociatedPops > 1) { - replaceable = false; - } - } - } - if (replaceable) { - // Do replacing - for (Operation *user : v.getUsers()) { - if (auto popOp = dyn_cast(user)) { - Operation *closestPushOp = - findNearestDominatingOpByUse(user, v); - auto pushOp = dyn_cast(closestPushOp); - popOp.replaceAllUsesWith(pushOp.getValue()); - } - if (auto getOp = dyn_cast(user)) { - Operation *closestPushOp = - findNearestDominatingOpByUse(user, v); - auto pushOp = dyn_cast(closestPushOp); - getOp.replaceAllUsesWith(pushOp.getValue()); - } - } - for (Operation *user : v.getUsers()) { - user->erase(); - } - op->erase(); - } - } - } - }); - }; + RewritePatternSet patterns(&getContext()); + patterns.insert(&getContext()); + + GreedyRewriteConfig config; + (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), + config); + } }; + } // end anonymous namespace namespace mlir { diff --git a/enzyme/Enzyme/MLIR/Passes/ShadowedGradientToCache.cpp b/enzyme/Enzyme/MLIR/Passes/ShadowedGradientToCache.cpp deleted file mode 100644 index b8a42f339280..000000000000 --- a/enzyme/Enzyme/MLIR/Passes/ShadowedGradientToCache.cpp +++ /dev/null @@ -1,74 +0,0 @@ -//===- ShadowedGradientToCache.cpp - Lower Shadowed Gradient ops -//------------------ // -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file implements a pass to lower custom ops generated by the Enzyme AD -// procedure to the MemRef dialect. -//===----------------------------------------------------------------------===// - -#include "Dialect/Dialect.h" -#include "Dialect/Ops.h" -#include "PassDetails.h" -#include "Passes/Passes.h" -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/Transforms/DialectConversion.h" - -#include "mlir/Rewrite/PatternApplicator.h" - -#include "llvm/Support/raw_ostream.h" - -using namespace mlir; -using namespace enzyme; -using llvm::errs; -namespace { -struct ShadowedGradientToCachePass - : public enzyme::ShadowedGradientToCachePassBase< - ShadowedGradientToCachePass> { - void runOnOperation() override { - MLIRContext *context = &getContext(); - ConversionPatternRewriter rewriter(context); - - getOperation()->walk([&](Operation *op) { - if (auto initOp = dyn_cast(op)) { - if (auto type = - dyn_cast(initOp.getType())) { - Type cacheType = CacheType::get(op->getContext(), type.getBasetype()); - - OpBuilder builder(op); - Value buffer = builder.create(op->getLoc(), cacheType); - - initOp.replaceAllUsesWith(buffer); - initOp->erase(); - } - } - if (auto clearOp = dyn_cast(op)) { - if (auto type = - dyn_cast(clearOp.getCache().getType())) { - OpBuilder builder(op); - builder.create(op->getLoc(), type.getType(), - clearOp.getCache()); - - clearOp->erase(); - } - } - }); - }; -}; -} // end anonymous namespace - -namespace mlir { -namespace enzyme { -std::unique_ptr createShadowedGradientToCachePass() { - return std::make_unique(); -} -} // namespace enzyme -} // namespace mlir diff --git a/enzyme/Enzyme/MLIR/Passes/SimplifyMath.cpp b/enzyme/Enzyme/MLIR/Passes/SimplifyMath.cpp new file mode 100644 index 000000000000..de04163fea8f --- /dev/null +++ b/enzyme/Enzyme/MLIR/Passes/SimplifyMath.cpp @@ -0,0 +1,88 @@ +//===- SimpliyMath.cpp - Simplify Mathematical operations ------------------ // +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a pass to lower custom ops generated by the Enzyme AD +// procedure to the MemRef dialect. +//===----------------------------------------------------------------------===// + +#include "Dialect/Ops.h" +#include "PassDetails.h" +#include "Passes/Passes.h" +#include "mlir/Dialect/Arith/IR/Arith.h" + +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +using namespace mlir; +using namespace enzyme; +using llvm::errs; +namespace { + +struct AddSimplify : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(arith::AddFOp op, + PatternRewriter &rewriter) const final { + + if (matchPattern(op.getLhs(), m_AnyZeroFloat())) { + rewriter.replaceOp(op, op.getRhs()); + return success(); + } + + if (matchPattern(op.getRhs(), m_AnyZeroFloat())) { + rewriter.replaceOp(op, op.getLhs()); + return success(); + } + + return failure(); + } +}; + +struct SubSimplify : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(arith::SubFOp op, + PatternRewriter &rewriter) const final { + + if (matchPattern(op.getRhs(), m_AnyZeroFloat())) { + rewriter.replaceOp(op, op.getLhs()); + return success(); + } + + if (matchPattern(op.getLhs(), m_AnyZeroFloat())) { + rewriter.replaceOpWithNewOp(op, op.getRhs()); + return success(); + } + + return failure(); + } +}; + +struct MathematicSimplification + : public enzyme::MathematicSimplificationPassBase< + MathematicSimplification> { + void runOnOperation() override { + + RewritePatternSet patterns(&getContext()); + patterns.insert(&getContext()); + + GreedyRewriteConfig config; + (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), + config); + }; +}; +} // end anonymous namespace + +namespace mlir { +namespace enzyme { +std::unique_ptr createMathematicSimplificationPass() { + return std::make_unique(); +} +} // namespace enzyme +} // namespace mlir diff --git a/enzyme/Enzyme/MLIR/enzymemlir-opt.cpp b/enzyme/Enzyme/MLIR/enzymemlir-opt.cpp index 589ffa610a28..4a7b9231d1ee 100644 --- a/enzyme/Enzyme/MLIR/enzymemlir-opt.cpp +++ b/enzyme/Enzyme/MLIR/enzymemlir-opt.cpp @@ -97,12 +97,7 @@ int main(int argc, char **argv) { }); // Register the autodiff interface implementations for upstream dialects. - enzyme::registerArithDialectAutoDiffInterface(registry); - enzyme::registerBuiltinDialectAutoDiffInterface(registry); - enzyme::registerLLVMDialectAutoDiffInterface(registry); - enzyme::registerMemRefDialectAutoDiffInterface(registry); - enzyme::registerSCFDialectAutoDiffInterface(registry); - enzyme::registerLinalgDialectAutoDiffInterface(registry); + enzyme::registerCoreDialectAutodiffInterfaces(registry); return mlir::asMainReturnCode(mlir::MlirOptMain( argc, argv, "Enzyme modular optimizer driver", registry)); diff --git a/enzyme/Enzyme/PreserveNVVM.cpp b/enzyme/Enzyme/PreserveNVVM.cpp index 73d13a009abb..84b8b5b9540a 100644 --- a/enzyme/Enzyme/PreserveNVVM.cpp +++ b/enzyme/Enzyme/PreserveNVVM.cpp @@ -56,6 +56,25 @@ using namespace llvm; #define addAttribute addAttributeAtIndex #endif +//! Returns whether changed. +bool preserveLinkage(bool Begin, Function &F, bool Inlining = true) { + if (Begin && !F.hasFnAttribute("prev_fixup")) { + F.addFnAttr("prev_fixup"); + if (F.hasFnAttribute(Attribute::AlwaysInline)) + F.addFnAttr("prev_always_inline"); + if (F.hasFnAttribute(Attribute::NoInline)) + F.addFnAttr("prev_no_inline"); + if (Inlining) { + F.removeFnAttr(Attribute::AlwaysInline); + F.addFnAttr(Attribute::NoInline); + } + F.addFnAttr("prev_linkage", std::to_string(F.getLinkage())); + F.setLinkage(Function::LinkageTypes::ExternalLinkage); + return true; + } + return false; +} + template static void handleCustomDerivative(llvm::Module &M, llvm::GlobalVariable &g, @@ -237,26 +256,31 @@ handleCustomDerivative(llvm::Module &M, llvm::GlobalVariable &g, Fs[fn] = NewF; } + preserveLinkage(true, *Fs[1], false); Fs[0]->setMetadata( "enzyme_augment", llvm::MDTuple::get(Fs[0]->getContext(), {llvm::ValueAsMetadata::get(Fs[1])})); + preserveLinkage(true, *Fs[2], false); Fs[0]->setMetadata( "enzyme_gradient", llvm::MDTuple::get(Fs[0]->getContext(), {llvm::ValueAsMetadata::get(Fs[2])})); } else if (Mode == DerivativeMode::ForwardMode) { assert(numargs == 2); + preserveLinkage(true, *Fs[1], false); Fs[0]->setMetadata( "enzyme_derivative", llvm::MDTuple::get(Fs[0]->getContext(), {llvm::ValueAsMetadata::get(Fs[1])})); } else if (Mode == DerivativeMode::ForwardModeSplit) { assert(numargs == 3); + preserveLinkage(true, *Fs[1], false); Fs[0]->setMetadata( "enzyme_augment", llvm::MDTuple::get(Fs[0]->getContext(), {llvm::ValueAsMetadata::get(Fs[1])})); + preserveLinkage(true, *Fs[2], false); Fs[0]->setMetadata( "enzyme_splitderivative", llvm::MDTuple::get(Fs[0]->getContext(), @@ -282,22 +306,6 @@ handleCustomDerivative(llvm::Module &M, llvm::GlobalVariable &g, } globalsToErase.push_back(&g); } -//! Returns whether changed. -bool preserveLinkage(bool Begin, Function &F) { - if (Begin && !F.hasFnAttribute("prev_fixup")) { - F.addFnAttr("prev_fixup"); - if (F.hasFnAttribute(Attribute::AlwaysInline)) - F.addFnAttr("prev_always_inline"); - if (F.hasFnAttribute(Attribute::NoInline)) - F.addFnAttr("prev_no_inline"); - F.addFnAttr("prev_linkage", std::to_string(F.getLinkage())); - F.setLinkage(Function::LinkageTypes::ExternalLinkage); - F.addFnAttr(Attribute::NoInline); - F.removeFnAttr(Attribute::AlwaysInline); - return true; - } - return false; -} bool preserveNVVM(bool Begin, Function &F) { bool changed = false; diff --git a/enzyme/Enzyme/TypeAnalysis/BaseType.h b/enzyme/Enzyme/TypeAnalysis/BaseType.h index 9ba5bd2bec5c..71d6e0910408 100644 --- a/enzyme/Enzyme/TypeAnalysis/BaseType.h +++ b/enzyme/Enzyme/TypeAnalysis/BaseType.h @@ -26,7 +26,6 @@ #define ENZYME_TYPE_ANALYSIS_BASE_TYPE_H 1 #include "llvm/ADT/StringRef.h" -#include "llvm/Support/ErrorHandling.h" #include /// Categories of potential types @@ -57,7 +56,8 @@ static inline std::string to_string(BaseType t) { case BaseType::Unknown: return "Unknown"; } - llvm_unreachable("unknown inttype"); + assert(0 && "unknown inttype"); + return ""; } /// Convert string to BaseType @@ -72,6 +72,7 @@ static inline BaseType parseBaseType(llvm::StringRef str) { return BaseType::Anything; if (str == "Unknown") return BaseType::Unknown; - llvm_unreachable("Unknown BaseType string"); + assert(0 && "Unknown BaseType string"); + return BaseType::Unknown; } #endif diff --git a/enzyme/Enzyme/TypeAnalysis/RustDebugInfo.cpp b/enzyme/Enzyme/TypeAnalysis/RustDebugInfo.cpp index 899ce46d9f9e..2376c9b23353 100644 --- a/enzyme/Enzyme/TypeAnalysis/RustDebugInfo.cpp +++ b/enzyme/Enzyme/TypeAnalysis/RustDebugInfo.cpp @@ -83,7 +83,6 @@ TypeTree parseDIType(DICompositeType &Type, Instruction &I, DataLayout &DL) { assert(0 && "There shouldn't be non-constant-size arrays in Rust"); } } - return Result; } else if (Type.getTag() == dwarf::DW_TAG_structure_type || Type.getTag() == dwarf::DW_TAG_union_type) { DINodeArray Elements = Type.getElements(); @@ -108,11 +107,11 @@ TypeTree parseDIType(DICompositeType &Type, Instruction &I, DataLayout &DL) { firstSubTT = !firstSubTT; } } - return Result; } else { assert(0 && "Composite types other than arrays, structs and unions are not " "supported by Rust debug info parser"); } + return Result; } TypeTree parseDIType(DIDerivedType &Type, Instruction &I, DataLayout &DL) { @@ -134,6 +133,7 @@ TypeTree parseDIType(DIDerivedType &Type, Instruction &I, DataLayout &DL) { assert(0 && "Derived types other than pointers and members are not " "supported by Rust debug info parser"); } + return {}; } TypeTree parseDIType(DIType &Type, Instruction &I, DataLayout &DL) { @@ -151,6 +151,7 @@ TypeTree parseDIType(DIType &Type, Instruction &I, DataLayout &DL) { assert(0 && "Types other than floating-points, integers, arrays, pointers, " "slices, and structs are not supported by debug info parser"); } + return {}; } bool isU8PointerType(DIType &type) { diff --git a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp index f2a3dc56cdf8..2515169ae0e4 100644 --- a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp +++ b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp @@ -61,6 +61,12 @@ #include +#if LLVM_VERSION_MAJOR >= 14 +#define getAttribute getAttributeAtIndex +#define hasAttribute hasAttributeAtIndex +#define addAttribute addAttributeAtIndex +#endif + using namespace llvm; extern "C" { @@ -157,6 +163,7 @@ const llvm::StringMap LIBM_FUNCTIONS = { {"__fd_sincos_1", Intrinsic::not_intrinsic}, {"sincospi", Intrinsic::not_intrinsic}, + {"cmplx_inv", Intrinsic::not_intrinsic}, // bessel functions {"j0", Intrinsic::not_intrinsic}, @@ -167,6 +174,7 @@ const llvm::StringMap LIBM_FUNCTIONS = { {"yn", Intrinsic::not_intrinsic}, {"tgamma", Intrinsic::not_intrinsic}, {"lgamma", Intrinsic::not_intrinsic}, + {"logabsgamma", Intrinsic::not_intrinsic}, {"ceil", Intrinsic::ceil}, {"__nv_ceil", Intrinsic::ceil}, {"floor", Intrinsic::floor}, @@ -174,6 +182,7 @@ const llvm::StringMap LIBM_FUNCTIONS = { {"trunc", Intrinsic::trunc}, {"round", Intrinsic::round}, {"rint", Intrinsic::rint}, + {"nearbyint", Intrinsic::nearbyint}, {"remainder", Intrinsic::not_intrinsic}, {"copysign", Intrinsic::copysign}, {"nextafter", Intrinsic::not_intrinsic}, @@ -746,6 +755,9 @@ void getConstantAnalysis(Constant *Val, TypeAnalyzer &TA, delete g2; int Off = (int)ai.getLimitedValue(); + if (auto VT = dyn_cast(Val->getType())) + if (VT->getElementType()->isIntegerTy(1)) + Off = i / 8; getConstantAnalysis(Op, TA, analysis); auto mid = analysis[Op]; @@ -1201,18 +1213,57 @@ void TypeAnalyzer::considerTBAA() { } if (CallBase *call = dyn_cast(&I)) { +#if LLVM_VERSION_MAJOR >= 14 + size_t num_args = call->arg_size(); +#else + size_t num_args = call->getNumArgOperands(); +#endif + + if (call->getAttributes().hasAttribute(AttributeList::ReturnIndex, + "enzyme_type")) { + auto attr = call->getAttributes().getAttribute( + AttributeList::ReturnIndex, "enzyme_type"); + auto TT = + TypeTree::parse(attr.getValueAsString(), call->getContext()); + updateAnalysis(call, TT, call); + } + for (size_t i = 0; i < num_args; i++) { + if (call->getAttributes().hasParamAttr(i, "enzyme_type")) { + auto attr = call->getAttributes().getParamAttr(i, "enzyme_type"); + auto TT = + TypeTree::parse(attr.getValueAsString(), call->getContext()); + updateAnalysis(call->getArgOperand(i), TT, call); + } + } + Function *F = call->getCalledFunction(); + + if (F) { + if (F->getAttributes().hasAttribute(AttributeList::ReturnIndex, + "enzyme_type")) { + auto attr = F->getAttributes().getAttribute( + AttributeList::ReturnIndex, "enzyme_type"); + auto TT = + TypeTree::parse(attr.getValueAsString(), call->getContext()); + updateAnalysis(call, TT, call); + } + size_t f_num_args = F->arg_size(); + for (size_t i = 0; i < f_num_args; i++) { + if (F->getAttributes().hasParamAttr(i, "enzyme_type")) { + auto attr = F->getAttributes().getParamAttr(i, "enzyme_type"); + auto TT = + TypeTree::parse(attr.getValueAsString(), call->getContext()); + updateAnalysis(call->getArgOperand(i), TT, call); + } + } + } + if (auto castinst = dyn_cast(call->getCalledOperand())) { if (castinst->isCast()) if (auto fn = dyn_cast(castinst->getOperand(0))) { F = fn; } } -#if LLVM_VERSION_MAJOR >= 14 - size_t num_args = call->arg_size(); -#else - size_t num_args = call->getNumArgOperands(); -#endif if (F && F->getName().contains("__enzyme_float")) { assert(num_args == 1 || num_args == 2); assert(call->getArgOperand(0)->getType()->isPointerTy()); @@ -1854,6 +1905,7 @@ void TypeAnalyzer::visitGEPOperator(GEPOperator &gep) { MapVector VariableOffsets; bool legalOffset = collectOffset(&gep, DL, BitWidth, VariableOffsets, constOffset); + (void)legalOffset; assert(legalOffset); SmallVector, 4> idnext; @@ -2932,7 +2984,32 @@ void TypeAnalyzer::visitBinaryOperation(const DataLayout &dl, llvm::Type *T, // If ^ against 0b10000000000, the result is a float bool validXor = containsOnlyAtMostTopBit(Args[i], FT, dl); if (validXor) { - ((i == 0) ? RHS : LHS) |= TypeTree(FT).Only(-1, nullptr); + bool Legal = true; + ((i == 0) ? RHS : LHS) + .checkedOrIn(TypeTree(FT).Only(-1, nullptr), + /*pointerintsame*/ false, Legal); + + if (!Legal) { + std::string str; + raw_string_ostream ss(str); + if (!CustomErrorHandler) { + llvm::errs() << *fntypeinfo.Function->getParent() << "\n"; + llvm::errs() << *fntypeinfo.Function << "\n"; + dump(ss); + } + ss << "Illegal updateBinop (xor up) Analysis " << *origin << "\n"; + ss << " (i=" << i << ") " << (i == 0 ? "RHS" : "LHS") << " " + << ((i == 0) ? RHS : LHS).str() << " FT from ret: " << *FT + << "\n"; + if (CustomErrorHandler) { + CustomErrorHandler(str.c_str(), wrap(origin), + ErrorType::IllegalTypeAnalysis, (void *)this, + wrap(origin), nullptr); + } + EmitFailure("IllegalUpdateAnalysis", origin->getDebugLoc(), + origin, ss.str()); + report_fatal_error("Performed illegal updateAnalysis"); + } } } break; @@ -3861,7 +3938,34 @@ void TypeAnalyzer::visitIntrinsicInst(llvm::IntrinsicInst &I) { .Only(-1, &I), &I); return; - +#if LLVM_VERSION_MAJOR >= 12 + case Intrinsic::smax: + case Intrinsic::smin: + case Intrinsic::umax: + case Intrinsic::umin: + if (direction & UP) { + auto returnType = getAnalysis(&I)[{-1}]; + if (returnType == BaseType::Integer || returnType == BaseType::Pointer) { + updateAnalysis(I.getOperand(0), TypeTree(returnType).Only(-1, &I), &I); + updateAnalysis(I.getOperand(1), TypeTree(returnType).Only(-1, &I), &I); + } + } + if (direction & DOWN) { + auto opType0 = getAnalysis(I.getOperand(0))[{-1}]; + auto opType1 = getAnalysis(I.getOperand(1))[{-1}]; + if (opType0 == opType1 && + (opType0 == BaseType::Integer || opType0 == BaseType::Pointer)) { + updateAnalysis(&I, TypeTree(opType0).Only(-1, &I), &I); + } else if (opType0 == BaseType::Integer && + opType1 == BaseType::Anything) { + updateAnalysis(&I, TypeTree(BaseType::Integer).Only(-1, &I), &I); + } else if (opType1 == BaseType::Integer && + opType0 == BaseType::Anything) { + updateAnalysis(&I, TypeTree(BaseType::Integer).Only(-1, &I), &I); + } + } + return; +#endif case Intrinsic::umul_with_overflow: case Intrinsic::smul_with_overflow: case Intrinsic::ssub_with_overflow: @@ -4297,23 +4401,26 @@ void TypeAnalyzer::visitCallBase(CallBase &call) { } } + if (call.hasFnAttr("enzyme_ta_norecur")) + return; + Function *ci = getFunctionFromCall(&call); if (ci) { + if (ci->getAttributes().hasAttribute(AttributeList::FunctionIndex, + "enzyme_ta_norecur")) + return; + StringRef funcName = getFuncNameFromCall(&call); auto blasMetaData = extractBLAS(funcName); -#if LLVM_VERSION_MAJOR >= 16 - if (blasMetaData.has_value()) { - BlasInfo blas = blasMetaData.value(); + if (blasMetaData) { + BlasInfo blas = *blasMetaData; #include "BlasTA.inc" } -#else - if (blasMetaData.hasValue()) { - BlasInfo blas = blasMetaData.getValue(); -#include "BlasTA.inc" - } -#endif + + // Manual TT specification is non-interprocedural and already handled once + // at the start. // When compiling Enzyme against standard LLVM, and not Intel's // modified version of LLVM, the intrinsic `llvm.intel.subscript` is @@ -4513,8 +4620,10 @@ void TypeAnalyzer::visitCallBase(CallBase &call) { return; } - if (startsWith(funcName, "_ZNKSt3__112basic_stringIcNS_11char_traitsIcEENS_" - "9allocatorIcEEE13__get_pointer")) { + if (startsWith(funcName, "_ZNKSt3__112basic_string") || + startsWith(funcName, "_ZNSt3__112basic_string") || + startsWith(funcName, "_ZNSt3__112__hash_table") || + startsWith(funcName, "_ZNKSt3__115basic_stringbuf")) { return; } @@ -4943,11 +5052,7 @@ void TypeAnalyzer::visitCallBase(CallBase &call) { } if (auto opidx = getAllocationIndexFromCall(&call)) { auto ptr = TypeTree(BaseType::Pointer); -#if LLVM_VERSION_MAJOR >= 15 - unsigned index = (size_t)opidx.value(); -#else - unsigned index = (size_t)opidx.getValue(); -#endif + unsigned index = (size_t)*opidx; if (auto CI = dyn_cast(call.getOperand(index))) { auto &DL = call.getParent()->getParent()->getParent()->getDataLayout(); auto LoadSize = CI->getZExtValue(); @@ -5231,6 +5336,8 @@ void TypeAnalyzer::visitCallBase(CallBase &call) { CONSIDER(frexpl) CONSIDER2(ldexp, double, double, int) CONSIDER2(modf, double, double, double *) + CONSIDER(modff) + CONSIDER(modfl) CONSIDER2(remquo, double, double, double, int *) CONSIDER(remquof) @@ -5305,23 +5412,52 @@ void TypeAnalyzer::visitCallBase(CallBase &call) { } else if (T->isVoidTy()) { } else if (auto ST = dyn_cast(T)) { assert(ST->getNumElements() >= 1); - for (size_t i = 1; i < ST->getNumElements(); ++i) { - assert(ST->getTypeAtIndex((unsigned)0) == ST->getTypeAtIndex(i)); - } - if (ST->getTypeAtIndex((unsigned)0)->isFloatingPointTy()) - updateAnalysis( - &call, - TypeTree(ConcreteType( - ST->getTypeAtIndex((unsigned)0)->getScalarType())) - .Only(-1, &call), - &call); - else if (ST->getTypeAtIndex((unsigned)0)->isIntegerTy()) { - updateAnalysis(&call, TypeTree(BaseType::Integer).Only(-1, &call), - &call); - } else { - llvm::errs() << *T << " - " << call << "\n"; - llvm_unreachable("Unknown type for libm"); + TypeTree TT; + auto &DL = call.getParent()->getParent()->getParent()->getDataLayout(); + for (size_t i = 0; i < ST->getNumElements(); ++i) { + auto T = ST->getTypeAtIndex(i); + ConcreteType CT(BaseType::Unknown); + + Value *vec[2] = { + ConstantInt::get(Type::getInt64Ty(call.getContext()), 0), + ConstantInt::get(Type::getInt32Ty(call.getContext()), i)}; + auto ud = UndefValue::get(PointerType::getUnqual(ST)); + auto g2 = GetElementPtrInst::Create(ST, ud, vec); + APInt ai(DL.getIndexSizeInBits(0), 0); + g2->accumulateConstantOffset(DL, ai); + delete g2; + size_t Offset = ai.getZExtValue(); + + size_t nextOffset; + if (i + 1 == ST->getNumElements()) + nextOffset = (DL.getTypeSizeInBits(ST) + 7) / 8; + else { + Value *vec[2] = { + ConstantInt::get(Type::getInt64Ty(call.getContext()), 0), + ConstantInt::get(Type::getInt32Ty(call.getContext()), i + 1)}; + auto ud = UndefValue::get(PointerType::getUnqual(ST)); + auto g2 = GetElementPtrInst::Create(ST, ud, vec); + APInt ai(DL.getIndexSizeInBits(0), 0); + g2->accumulateConstantOffset(DL, ai); + delete g2; + nextOffset = ai.getZExtValue(); + } + + if (T->isFloatingPointTy()) { + CT = T; + } else if (T->isIntegerTy()) { + CT = BaseType::Integer; + } + if (CT != BaseType::Unknown) { + TypeTree mid = TypeTree(CT).Only(-1, &call); + TT |= mid.ShiftIndices(DL, /*init offset*/ 0, + /*maxSize*/ nextOffset - Offset, + /*addOffset*/ Offset); + } } + auto Size = (DL.getTypeSizeInBits(ST) + 7) / 8; + TT.CanonicalizeInPlace(Size, DL); + updateAnalysis(&call, TT, &call); } else if (auto AT = dyn_cast(T)) { assert(AT->getNumElements() >= 1); if (AT->getElementType()->isFloatingPointTy()) @@ -5515,7 +5651,7 @@ bool TypeAnalyzer::mustRemainInteger(Value *val, bool *returned) { FnTypeInfo TypeAnalyzer::getCallInfo(CallBase &call, Function &fn) { FnTypeInfo typeInfo(&fn); - int argnum = 0; + size_t argnum = 0; for (auto &arg : fn.args()) { if (argnum >= call.arg_size()) { typeInfo.Arguments.insert( @@ -5665,20 +5801,10 @@ TypeResults TypeAnalysis::analyzeFunction(const FnTypeInfo &fn) { return TypeResults(analysis); } -#ifdef __clang__ -#pragma clang diagnostic push -#pragma clang diagnostic ignored "-Wnull-dereference" -#else -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wnull-dereference" -#endif + if (fn.Function->empty()) - return TypeResults(*(TypeAnalyzer *)nullptr); -#ifdef __clang__ -#pragma clang diagnostic pop -#else -#pragma GCC diagnostic pop -#endif + return TypeResults(nullptr); + auto res = analyzedFunctions.emplace(fn, new TypeAnalyzer(fn, *this)); auto &analysis = *res.first->second; @@ -5728,33 +5854,92 @@ TypeResults TypeAnalysis::analyzeFunction(const FnTypeInfo &fn) { return TypeResults(analysis); } -TypeResults::TypeResults(TypeAnalyzer &analyzer) : analyzer(analyzer) {} +TypeResults::TypeResults(TypeAnalyzer &analyzer) : analyzer(&analyzer) {} +TypeResults::TypeResults(std::nullptr_t) : analyzer(nullptr) {} FnTypeInfo TypeResults::getAnalyzedTypeInfo() const { - FnTypeInfo res(analyzer.fntypeinfo.Function); - for (auto &arg : analyzer.fntypeinfo.Function->args()) { + FnTypeInfo res(analyzer->fntypeinfo.Function); + for (auto &arg : analyzer->fntypeinfo.Function->args()) { res.Arguments.insert(std::pair(&arg, query(&arg))); } res.Return = getReturnAnalysis(); - res.KnownValues = analyzer.fntypeinfo.KnownValues; + res.KnownValues = analyzer->fntypeinfo.KnownValues; return res; } FnTypeInfo TypeResults::getCallInfo(CallBase &CI, Function &fn) const { - return analyzer.getCallInfo(CI, fn); + return analyzer->getCallInfo(CI, fn); } TypeTree TypeResults::query(Value *val) const { +#ifndef NDEBUG if (auto inst = dyn_cast(val)) { - assert(inst->getParent()->getParent() == analyzer.fntypeinfo.Function); + assert(inst->getParent()->getParent() == analyzer->fntypeinfo.Function); } if (auto arg = dyn_cast(val)) { - assert(arg->getParent() == analyzer.fntypeinfo.Function); + assert(arg->getParent() == analyzer->fntypeinfo.Function); + } +#endif + return analyzer->getAnalysis(val); +} + +bool TypeResults::anyFloat(Value *val) const { + assert(val); + assert(val->getType()); + auto q = query(val); + auto dt = q[{-1}]; + if (dt != BaseType::Anything && dt != BaseType::Unknown) + return dt.isFloat(); + + size_t ObjSize = 1; + auto &dl = analyzer->fntypeinfo.Function->getParent()->getDataLayout(); + if (val->getType()->isSized()) + ObjSize = (dl.getTypeSizeInBits(val->getType()) + 7) / 8; + + for (size_t i = 0; i < ObjSize;) { + dt = q[{(int)i}]; + if (dt == BaseType::Integer) { + i++; + continue; + } + if (dt == BaseType::Pointer) { + i += dl.getPointerSize(0); + continue; + } + return true; + } + return false; +} + +bool TypeResults::anyPointer(Value *val) const { + assert(val); + assert(val->getType()); + auto q = query(val); + auto dt = q[{-1}]; + if (dt != BaseType::Anything && dt != BaseType::Unknown) + return dt == BaseType::Pointer; + + size_t ObjSize = 1; + auto &dl = analyzer->fntypeinfo.Function->getParent()->getDataLayout(); + if (val->getType()->isSized()) + ObjSize = (dl.getTypeSizeInBits(val->getType()) + 7) / 8; + + for (size_t i = 0; i < ObjSize;) { + dt = q[{(int)i}]; + if (dt == BaseType::Integer) { + i++; + continue; + } + if (auto FT = dt.isFloat()) { + i += (dl.getTypeSizeInBits(FT) + 7) / 8; + continue; + } + return true; } - return analyzer.getAnalysis(val); + return false; } -void TypeResults::dump(llvm::raw_ostream &ss) const { analyzer.dump(ss); } +void TypeResults::dump(llvm::raw_ostream &ss) const { analyzer->dump(ss); } ConcreteType TypeResults::intType(size_t num, Value *val, bool errIfNotFound, bool pointerIntSame) const { @@ -5777,7 +5962,7 @@ ConcreteType TypeResults::intType(size_t num, Value *val, bool errIfNotFound, if (auto inst = dyn_cast(val)) { llvm::errs() << *inst->getParent()->getParent()->getParent() << "\n"; llvm::errs() << *inst->getParent()->getParent() << "\n"; - for (auto &pair : analyzer.analysis) { + for (auto &pair : analyzer->analysis) { llvm::errs() << "val: " << *pair.first << " - " << pair.second.str() << "\n"; } @@ -5812,7 +5997,7 @@ ConcreteType TypeResults::firstPointer(size_t num, Value *val, Instruction *I, assert(val->getType()); auto q = query(val).Data0(); if (!(val->getType()->isPointerTy() || q[{}] == BaseType::Pointer)) { - llvm::errs() << *analyzer.fntypeinfo.Function << "\n"; + llvm::errs() << *analyzer->fntypeinfo.Function << "\n"; dump(); llvm::errs() << "val: " << *val << "\n"; } @@ -5837,7 +6022,7 @@ ConcreteType TypeResults::firstPointer(size_t num, Value *val, Instruction *I, } if (errIfNotFound && (!dt.isKnown() || dt == BaseType::Anything)) { - auto &res = analyzer; + auto &res = *analyzer; if (auto inst = dyn_cast(val)) { llvm::errs() << *inst->getParent()->getParent()->getParent() << "\n"; llvm::errs() << *inst->getParent()->getParent() << "\n"; @@ -5862,23 +6047,25 @@ ConcreteType TypeResults::firstPointer(size_t num, Value *val, Instruction *I, if (auto arg = dyn_cast(val)) { llvm::errs() << *arg->getParent() << "\n"; for (auto &pair : res.analysis) { +#ifndef NDEBUG if (auto in = dyn_cast(pair.first)) assert(in->getParent()->getParent() == arg->getParent()); +#endif llvm::errs() << "val: " << *pair.first << " - " << pair.second.str() << " int: " + to_string(res.knownIntegralValues(pair.first)) << "\n"; } } - llvm::errs() << "fn: " << *analyzer.fntypeinfo.Function << "\n"; + llvm::errs() << "fn: " << *analyzer->fntypeinfo.Function << "\n"; dump(); llvm::errs() << "could not deduce type of integer " << *val << " num:" << num << " q:" << q.str() << " \n"; llvm::DiagnosticLocation loc = - analyzer.fntypeinfo.Function->getSubprogram(); + analyzer->fntypeinfo.Function->getSubprogram(); Instruction *codeLoc = - &*analyzer.fntypeinfo.Function->getEntryBlock().begin(); + &*analyzer->fntypeinfo.Function->getEntryBlock().begin(); if (auto inst = dyn_cast(val)) { loc = inst->getDebugLoc(); codeLoc = inst; @@ -5981,15 +6168,15 @@ TypeTree defaultTypeTreeForLLVM(llvm::Type *ET, llvm::Instruction *I, } Function *TypeResults::getFunction() const { - return analyzer.fntypeinfo.Function; + return analyzer->fntypeinfo.Function; } TypeTree TypeResults::getReturnAnalysis() const { - return analyzer.getReturnAnalysis(); + return analyzer->getReturnAnalysis(); } std::set TypeResults::knownIntegralValues(Value *val) const { - return analyzer.knownIntegralValues(val); + return analyzer->knownIntegralValues(val); } std::set TypeAnalyzer::knownIntegralValues(Value *val) { diff --git a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.h b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.h index fc7bc1d24dd7..19c8e2d4d19f 100644 --- a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.h +++ b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.h @@ -156,9 +156,10 @@ class TypeAnalysis; /// on a given function class TypeResults { public: - TypeAnalyzer &analyzer; + TypeAnalyzer *analyzer; public: + TypeResults(std::nullptr_t); TypeResults(TypeAnalyzer &analyzer); ConcreteType intType(size_t num, llvm::Value *val, bool errIfNotFound = true, bool pointerIntSame = false) const; @@ -174,6 +175,16 @@ class TypeResults { /// The TypeTree of a particular Value TypeTree query(llvm::Value *val) const; + /// Whether any part of the top level register can contain a float + /// e.g. { i64, float } can contain a float, but { i64, i8* } would not. + // Of course, here we compute with type analysis rather than llvm type + bool anyFloat(llvm::Value *val) const; + + /// Whether any part of the top level register can contain a pointer + /// e.g. { i64, i8* } can contain a pointer, but { i64, float } would not. + // Of course, here we compute with type analysis rather than llvm type + bool anyPointer(llvm::Value *val) const; + /// The TypeInfo calling convention FnTypeInfo getAnalyzedTypeInfo() const; @@ -248,6 +259,8 @@ class TypeAnalyzer : public llvm::InstVisitor { FnTypeInfo getCallInfo(llvm::CallBase &CI, llvm::Function &fn); + TypeAnalyzer(TypeAnalysis &TA); + TypeAnalyzer(const FnTypeInfo &fn, TypeAnalysis &TA, uint8_t direction = BOTH); diff --git a/enzyme/Enzyme/TypeAnalysis/TypeTree.h b/enzyme/Enzyme/TypeAnalysis/TypeTree.h index b69b83e5b355..22a4c25d9361 100644 --- a/enzyme/Enzyme/TypeAnalysis/TypeTree.h +++ b/enzyme/Enzyme/TypeAnalysis/TypeTree.h @@ -83,6 +83,92 @@ class TypeTree : public std::enable_shared_from_this { } } + static TypeTree parse(llvm::StringRef str, llvm::LLVMContext &ctx) { + using namespace llvm; + assert(str[0] == '{'); + str = str.substr(1); + + TypeTree Result; + while (true) { + while (str[0] == ' ') + str = str.substr(1); + if (str[0] == '}') + break; + + assert(str[0] == '['); + str = str.substr(1); + + std::vector idxs; + while (true) { + while (str[0] == ' ') + str = str.substr(1); + if (str[0] == ']') { + str = str.substr(1); + break; + } + + int idx; + bool failed = str.consumeInteger(10, idx); + (void)failed; + assert(!failed); + idxs.push_back(idx); + + while (str[0] == ' ') + str = str.substr(1); + + if (str[0] == ',') { + str = str.substr(1); + } + } + + while (str[0] == ' ') + str = str.substr(1); + + assert(str[0] == ':'); + str = str.substr(1); + + while (str[0] == ' ') + str = str.substr(1); + + auto endval = str.find(','); + auto endval2 = str.find('}'); + auto endval3 = str.find(' '); + + if (endval2 != StringRef::npos && + (endval == StringRef::npos || endval2 < endval)) + endval = endval2; + if (endval3 != StringRef::npos && + (endval == StringRef::npos || endval3 < endval)) + endval = endval3; + assert(endval != StringRef::npos); + + auto tystr = str.substr(0, endval); + str = str.substr(endval); + + ConcreteType CT(tystr, ctx); + Result.mapping.emplace(idxs, CT); + if (Result.minIndices.size() < idxs.size()) { + for (size_t i = Result.minIndices.size(), end = idxs.size(); i < end; + ++i) { + Result.minIndices.push_back(idxs[i]); + } + } + for (size_t i = 0, end = idxs.size(); i < end; ++i) { + if (idxs[i] < Result.minIndices[i]) + Result.minIndices[i] = idxs[i]; + } + + while (str[0] == ' ') + str = str.substr(1); + + if (str[0] == ',') { + str = str.substr(1); + } + } + + return Result; + } + /// Utility helper to lookup the mapping const ConcreteTypeMapType &getMapping() const { return mapping; } @@ -343,11 +429,13 @@ class TypeTree : public std::enable_shared_from_this { /// Whether this TypeTree contains any information bool isKnown() const { +#ifndef NDEBUG for (const auto &pair : mapping) { // we should assert here as we shouldn't keep any unknown maps for // efficiency assert(pair.second.isKnown()); } +#endif return mapping.size() != 0; } @@ -614,22 +702,22 @@ class TypeTree : public std::enable_shared_from_this { staging[next][pair.second].insert(pair.first[0]); } - mapping.clear(); + // TypeTree mappings which did not get combined + std::map, ConcreteType> unCombinedToAdd; - for (auto &pair : staging) { + // TypeTree mappings which did get combined into an outer -1 + std::map, ConcreteType> combinedToAdd; + + for (const auto &pair : staging) { auto &pnext = pair.first; - for (auto &pair2 : pair.second) { + for (const auto &pair2 : pair.second) { auto dt = pair2.first; const auto &set = pair2.second; - // llvm::errs() << " - set: {"; - // for(auto s : set) llvm::errs() << s << ", "; - // llvm::errs() << "} len=" << len << "\n"; - - bool legalCombine = set.count(-1); + bool legalCombine = false; // See if we can canonicalize the outermost index into a -1 - if (!legalCombine) { + if (!set.count(-1)) { size_t chunk = 1; if (pnext.size() > 0) { chunk = dl.getPointerSizeInBits() / 8; @@ -657,15 +745,38 @@ class TypeTree : public std::enable_shared_from_this { next.push_back(v); if (legalCombine) { - insert(next, dt, /*intsAreLegalPointerSub*/ true); + combinedToAdd.emplace(next, dt); } else { for (auto e : set) { next[0] = e; - insert(next, dt); + unCombinedToAdd.emplace(next, dt); } } } } + + // If we combined nothing, just return since there are no + // changes. + if (combinedToAdd.size() == 0) { + return; + } + + // Non-combined ones do not conflict, since they were already in + // a TT which we can assume contained no conflicts. + mapping = std::move(unCombinedToAdd); + minIndices[0] = -1; + + // Fusing several terms into a minus one can create a conflict + // if the prior minus one was already in the map + // time, or also generated by fusion. + // E.g. {-1:Anything, [0]:Pointer} on 8 -> create a [-1]:Pointer + // which conflicts + // Alternatively [-1,-1,-1]:Pointer, and generated a [-1,0,-1] fusion + for (const auto &pair : combinedToAdd) { + insert(pair.first, pair.second); + } + + return; } /// Keep only pointers (or anything's) to a repeated value (represented by -1) @@ -722,17 +833,51 @@ class TypeTree : public std::enable_shared_from_this { } /// Replace mappings in the range in [offset, offset+maxSize] with those in - // [addOffset, addOffset + maxSize]. In other worse, select all mappings in + // [addOffset, addOffset + maxSize]. In other words, select all mappings in // [offset, offset+maxSize] then add `addOffset` TypeTree ShiftIndices(const llvm::DataLayout &dl, const int offset, const int maxSize, size_t addOffset = 0) const { + + // If we have no terms 1+ layer deep return the current result as a shift + // won't change anything. This also makes the latercode simpler as it + // can assume at least a first index exists. + if (minIndices.size() == 0) + return *this; + + // If we have no size in return, simply return an empty type tree. Again + // this simplifies later code which can assume that a minus one expantion + // will always result in an added variable (which would not be the case + // on a size == 0). + if (maxSize == 0) + return TypeTree(); + TypeTree Result; + // The normal orIn / insert methods do collision checking, which is slow + // (and presently O(n)). This is because an expansion of a -1 which could + // conflict with a fixed value. Consider calling this + // ShiftIndicies(offset=0, maxSize=2, addOffset=0, tt={[-1]:Integer, + // [1]:Anything}) the -1 would expand to [0]:Int, [1]:Int, which would need + // to be merged with [1]:Anything + // + // The only possible values which can cause a conflict are minus -1's. + // As a result, we start with a fast insertion (aka without check) of + // non-expanded values, since they just do a literal shift which needs no + // extra checking, besides bounds checks. + // + // Since we're doing things manually, we also need to manually preserve TT + // invariants. Specifically, TT limits all values to have offsets < + // MAX_OFFSET, unless it is the smallest offset at that depth. (e.g. so we + // can still hava typetree {[123456]:Int}, even if limit is 100). + // + // First compute the minimum 0th index to be kept. + Result.minIndices.resize(minIndices.size(), INT_MAX); + for (const auto &pair : mapping) { if (pair.first.size() == 0) { if (pair.second == BaseType::Pointer || pair.second == BaseType::Anything) { - Result.insert(pair.first, pair.second); + Result.mapping.emplace(pair.first, pair.second); continue; } @@ -741,55 +886,152 @@ class TypeTree : public std::enable_shared_from_this { llvm_unreachable("ShiftIndices called on a nonpointer/anything"); } - std::vector next(pair.first); + int next0 = pair.first[0]; - if (next[0] == -1) { + if (next0 == -1) { if (maxSize == -1) { // Max size does not clip the next index // If we have a follow up offset add, we lose the -1 since we only // represent [0, inf) with -1 not the [addOffset, inf) required here if (addOffset != 0) { - next[0] = addOffset; + next0 = addOffset; } } else { - // This needs to become 0...maxSize as seen below + // We're going to insert addOffset + 0...maxSize so the new minIndex + // is addOffset + Result.minIndices[0] = addOffset; + for (size_t i = 1, sz = pair.first.size(); i < sz; i++) + if (pair.first[i] < Result.minIndices[i]) + Result.minIndices[i] = pair.first[i]; + continue; + } + } else { + // Too small for range + if (next0 < offset) { + continue; + } + next0 -= offset; + + if (maxSize != -1) { + if (next0 >= maxSize) + continue; + } + + next0 += addOffset; + } + if (next0 < Result.minIndices[0]) + Result.minIndices[0] = next0; + for (size_t i = 1, sz = pair.first.size(); i < sz; i++) + if (pair.first[i] < Result.minIndices[i]) + Result.minIndices[i] = pair.first[i]; + } + + // Max depth of actual inserted values + size_t maxInsertedDepth = 0; + + // Insert all + for (const auto &pair : mapping) { + if (pair.first.size() == 0) + continue; + + int next0 = pair.first[0]; + + if (next0 == -1) { + if (maxSize == -1) { + // Max size does not clip the next index + + // If we have a follow up offset add, we lose the -1 since we only + // represent [0, inf) with -1 not the [addOffset, inf) required here + if (addOffset != 0) { + next0 = addOffset; + } + + } else { + // This needs to become 0...maxSize handled separately as it is the + // only insertion that could have collisions + continue; } } else { // Too small for range - if (next[0] < offset) { + if (next0 < offset) { continue; } - next[0] -= offset; + next0 -= offset; if (maxSize != -1) { - if (next[0] >= maxSize) + if (next0 >= maxSize) continue; } - next[0] += addOffset; + next0 += addOffset; } - size_t chunk = 1; - auto op = operator[]({pair.first[0]}); - if (auto flt = op.isFloat()) { - chunk = dl.getTypeSizeInBits(flt) / 8; - } else if (op == BaseType::Pointer) { - chunk = dl.getPointerSizeInBits() / 8; + // If after moving this would not merit being kept for being a min index + // or being within the max type offset, skip it. + if (next0 > MaxTypeOffset) { + bool minIndex = next0 == Result.minIndices[0]; + if (!minIndex) + for (size_t i = 1; i < pair.first.size(); i++) { + if (pair.first[i] == Result.minIndices[i]) { + minIndex = true; + break; + } + } + if (!minIndex) + continue; } - if (next[0] == -1 && maxSize != -1) { + std::vector next(pair.first); + next[0] = next0; + Result.mapping.emplace(next, pair.second); + if (next.size() > maxInsertedDepth) + maxInsertedDepth = next.size(); + } + + // Insert and expand the minus one, if needed + if (maxSize != -1) + for (const auto &pair : mapping) { + if (pair.first.size() == 0) + continue; + if (pair.first[0] != -1) + continue; + + size_t chunk = 1; + std::vector next(pair.first); + auto op = operator[]({next[0]}); + if (auto flt = op.isFloat()) { + chunk = dl.getTypeSizeInBits(flt) / 8; + } else if (op == BaseType::Pointer) { + chunk = dl.getPointerSizeInBits() / 8; + } auto offincr = (chunk - offset % chunk) % chunk; + bool inserted = false; for (int i = offincr; i < maxSize; i += chunk) { next[0] = i + addOffset; - Result.orIn(next, pair.second); + ConcreteType prev(pair.second); + // We can use faster checks here, since we know there can be no + // -1's that we would conflict with, only conflicts from previous + // fixed value insertions. + auto found = Result.mapping.find(next); + if (found != Result.mapping.end()) { + // orIn returns if changed, update the value in the map if so + // with the new value. + if (prev.orIn(found->second, /*pointerIntSame*/ false)) + found->second = prev; + } else { + Result.mapping.emplace(next, pair.second); + } + inserted = true; } - } else { - Result.orIn(next, pair.second); + if (inserted && next.size() > maxInsertedDepth) + maxInsertedDepth = next.size(); } - } + // Resize minIndices down if we dropped any higher-depth indices for being + // out of scope. + Result.minIndices.resize(maxInsertedDepth); return Result; } @@ -1095,7 +1337,6 @@ class TypeTree : public std::enable_shared_from_this { if (found != RHS.mapping.end()) { RightCT = found->second; } - bool SubLegal = true; changed |= CT.binopIn(SubLegal, RightCT, Op); if (!SubLegal) { diff --git a/enzyme/Enzyme/Utils.cpp b/enzyme/Enzyme/Utils.cpp index ff44cbaa715d..283460673922 100644 --- a/enzyme/Enzyme/Utils.cpp +++ b/enzyme/Enzyme/Utils.cpp @@ -2264,7 +2264,258 @@ bool writesToMemoryReadBy(llvm::AAResults &AA, llvm::TargetLibraryInfo &TLI, llvm_unreachable("unknown inst2"); } -Function *GetFunctionFromValue(Value *fn) { +// Find the base pointer of ptr and the offset in bytes from the start of +// the returned base pointer to this value. +AllocaInst *getBaseAndOffset(Value *ptr, size_t &offset) { + offset = 0; + while (true) { + if (auto CI = dyn_cast(ptr)) { + ptr = CI->getOperand(0); + continue; + } + if (auto CI = dyn_cast(ptr)) { + auto &DL = CI->getParent()->getParent()->getParent()->getDataLayout(); + MapVector VariableOffsets; + auto width = sizeof(size_t) * 8; + APInt Offset(width, 0); + bool success = collectOffset(cast(CI), DL, width, + VariableOffsets, Offset); + if (!success || VariableOffsets.size() != 0 || Offset.isNegative()) { + return nullptr; + } + offset += Offset.getZExtValue(); + ptr = CI->getOperand(0); + continue; + } + if (isa(ptr)) { + break; + } + if (auto LI = dyn_cast(ptr)) { + if (auto S = simplifyLoad(LI)) { + ptr = S; + continue; + } + } + return nullptr; + } + return cast(ptr); +} + +// Find all user instructions of AI, returning tuples of Unlike a simple get users, this will recurse through any +// constant gep offsets and casts +SmallVector, 1> +findAllUsersOf(Value *AI) { + SmallVector, 1> todo; + todo.emplace_back(AI, 0); + + SmallVector, 1> users; + while (todo.size()) { + auto pair = todo.pop_back_val(); + Value *ptr = pair.first; + size_t suboff = pair.second; + + for (auto U : ptr->users()) { + if (auto CI = dyn_cast(U)) { + todo.emplace_back(CI, suboff); + continue; + } + if (auto CI = dyn_cast(U)) { + auto &DL = CI->getParent()->getParent()->getParent()->getDataLayout(); + MapVector VariableOffsets; + auto width = sizeof(size_t) * 8; + APInt Offset(width, 0); + bool success = collectOffset(cast(CI), DL, width, + VariableOffsets, Offset); + + if (!success || VariableOffsets.size() != 0 || Offset.isNegative()) { + users.emplace_back(cast(U), ptr, suboff); + continue; + } + todo.emplace_back(CI, suboff + Offset.getZExtValue()); + continue; + } + users.emplace_back(cast(U), ptr, suboff); + continue; + } + } + return users; +} + +// Given a pointer, find all values of size `valSz` which could be loaded from +// that pointer when indexed at offset. If it is impossible to guarantee that +// the set contains all such values, set legal to false +SmallVector getAllLoadedValuesFrom(AllocaInst *ptr0, size_t offset, + size_t valSz, bool &legal) { + SmallVector options; + + auto todo = findAllUsersOf(ptr0); + std::set> seen; + + while (todo.size()) { + auto pair = todo.pop_back_val(); + if (seen.count(pair)) + continue; + seen.insert(pair); + Instruction *U = std::get<0>(pair); + Value *ptr = std::get<1>(pair); + size_t suboff = std::get<2>(pair); + + // Read only users do not set the memory inside of ptr + if (isa(U)) { + continue; + } + if (auto MTI = dyn_cast(U)) + if (MTI->getOperand(0) != ptr) { + continue; + } + if (auto I = dyn_cast(U)) { + if (!I->mayWriteToMemory() && I->getType()->isVoidTy()) + continue; + } + + if (auto SI = dyn_cast(U)) { + auto &DL = SI->getParent()->getParent()->getParent()->getDataLayout(); + + // We are storing into the ptr + if (SI->getPointerOperand() == ptr) { + auto storeSz = + (DL.getTypeStoreSizeInBits(SI->getValueOperand()->getType()) + 7) / + 8; + // If store is before the load would start + if (storeSz + suboff <= offset) + continue; + // if store starts after load would start + if (offset + valSz <= suboff) + continue; + + if (valSz == storeSz) { + options.push_back(SI->getValueOperand()); + continue; + } + } + + // We capture our pointer of interest, if it is stored into an alloca, + // all loads of said alloca would potentially store into. + if (SI->getValueOperand() == ptr) { + if (suboff == 0) { + size_t mid_offset = 0; + if (auto AI2 = + getBaseAndOffset(SI->getPointerOperand(), mid_offset)) { + bool sublegal = true; + auto ptrSz = (DL.getTypeStoreSizeInBits(ptr->getType()) + 7) / 8; + auto subPtrs = + getAllLoadedValuesFrom(AI2, mid_offset, ptrSz, sublegal); + if (!sublegal) { + legal = false; + return options; + } + for (auto subPtr : subPtrs) { + for (const auto &pair3 : findAllUsersOf(subPtr)) { + todo.emplace_back(pair3); + } + } + continue; + } + } + } + } + + if (auto II = dyn_cast(U)) { + if (II->getIntrinsicID() == Intrinsic::lifetime_start || + II->getIntrinsicID() == Intrinsic::lifetime_end) + continue; + } + + // If we copy into the ptr at a location that includes the offset, consider + // all sub uses + if (auto MTI = dyn_cast(U)) { + if (auto CI = dyn_cast(MTI->getLength())) { + if (MTI->getOperand(0) == ptr && suboff == 0 && + CI->getValue().uge(offset + valSz)) { + size_t midoffset = 0; + auto AI2 = getBaseAndOffset(MTI->getOperand(1), midoffset); + if (!AI2) { + legal = false; + return options; + } + if (midoffset != 0) { + legal = false; + return options; + } + for (const auto &pair3 : findAllUsersOf(AI2)) { + todo.emplace_back(pair3); + } + continue; + } + } + } + + legal = false; + return options; + } + + return options; +} + +// Perform mem2reg/sroa to identify the innermost value being represented. +Value *simplifyLoad(Value *V, size_t valSz) { + if (auto LI = dyn_cast(V)) { + if (valSz == 0) { + auto &DL = LI->getParent()->getParent()->getParent()->getDataLayout(); + valSz = (DL.getTypeStoreSizeInBits(LI->getType()) + 7) / 8; + } + + Value *ptr = LI->getPointerOperand(); + size_t offset = 0; + + if (auto ptr2 = simplifyLoad(ptr)) { + ptr = ptr2; + } + auto AI = getBaseAndOffset(ptr, offset); + if (!AI) { + return nullptr; + } + + bool legal = true; + auto opts = getAllLoadedValuesFrom(AI, offset, valSz, legal); + + if (!legal) { + return nullptr; + } + std::set res; + for (auto opt : opts) { + Value *v2 = simplifyLoad(opt, valSz); + if (v2) + res.insert(v2); + else + res.insert(opt); + } + if (res.size() != 1) { + return nullptr; + } + Value *retval = *res.begin(); + return retval; + } + if (auto EVI = dyn_cast(V)) { + bool allZero = true; + for (auto idx : EVI->getIndices()) { + if (idx != 0) + allZero = false; + } + if (valSz == 0) { + auto &DL = EVI->getParent()->getParent()->getParent()->getDataLayout(); + valSz = (DL.getTypeStoreSizeInBits(EVI->getType()) + 7) / 8; + } + if (allZero) + if (auto LI = dyn_cast(EVI->getAggregateOperand())) { + return simplifyLoad(LI, valSz); + } + } + return nullptr; +} + +Value *GetFunctionValFromValue(Value *fn) { while (!isa(fn)) { if (auto ci = dyn_cast(fn)) { fn = ci->getOperand(0); @@ -2294,6 +2545,7 @@ Function *GetFunctionFromValue(Value *fn) { } if (ret.size() == 1) { auto val = *ret.begin(); + val = GetFunctionValFromValue(val); if (isa(val)) { fn = val; continue; @@ -2315,6 +2567,14 @@ Function *GetFunctionFromValue(Value *fn) { } if (ret.size() == 1) { auto val = *ret.begin(); + while (isa(val)) { + auto v2 = simplifyLoad(val); + if (v2) { + val = v2; + continue; + } + break; + } if (isa(val)) { fn = val; continue; @@ -2326,73 +2586,18 @@ Function *GetFunctionFromValue(Value *fn) { } } } - if (auto LI = dyn_cast(fn)) { - auto obj = getBaseObject(LI->getPointerOperand()); - if (isa(obj)) { - std::set> done; - SmallVector, 1> todo; - Value *stored = nullptr; - bool legal = true; - for (auto U : obj->users()) { - if (auto I = dyn_cast(U)) - todo.push_back(std::make_pair(I, obj)); - else { - legal = false; - break; - } - } - while (legal && todo.size()) { - auto tup = todo.pop_back_val(); - if (done.count(tup)) - continue; - done.insert(tup); - auto cur = tup.first; - auto prev = tup.second; - if (auto SI = dyn_cast(cur)) - if (SI->getPointerOperand() == prev) { - if (stored == SI->getValueOperand()) - continue; - else if (stored == nullptr) { - stored = SI->getValueOperand(); - continue; - } else { - legal = false; - break; - } - } - - if (isPointerArithmeticInst(cur, /*includephi*/ true)) { - for (auto U : cur->users()) { - if (auto I = dyn_cast(U)) - todo.push_back(std::make_pair(I, cur)); - else { - legal = false; - break; - } - } - continue; - } - - if (isa(cur)) - continue; - - if (!cur->mayWriteToMemory() && cur->getType()->isVoidTy()) - continue; - - legal = false; - break; - } - - if (legal && stored) { - fn = stored; - continue; - } - } + if (auto S = simplifyLoad(fn)) { + fn = S; + continue; } break; } - return dyn_cast(fn); + return fn; +} + +Function *GetFunctionFromValue(Value *fn) { + return dyn_cast(GetFunctionValFromValue(fn)); } #if LLVM_VERSION_MAJOR >= 16 diff --git a/enzyme/Enzyme/Utils.h b/enzyme/Enzyme/Utils.h index 7b46fd9a2e83..5a4b3c31cef7 100644 --- a/enzyme/Enzyme/Utils.h +++ b/enzyme/Enzyme/Utils.h @@ -318,11 +318,14 @@ enum class ReturnType { /// Potential differentiable argument classifications enum class DIFFE_TYPE { - OUT_DIFF = 0, // add differential to an output struct - DUP_ARG = 1, // duplicate the argument and store differential inside - CONSTANT = 2, // no differential + OUT_DIFF = 0, // add differential to an output struct. Only for scalar values + // in ReverseMode variants. + DUP_ARG = 1, // duplicate the argument and store differential inside. + // For references, pointers, or integers in ReverseMode variants. + // For all types in ForwardMode variants. + CONSTANT = 2, // no differential. Usable everywhere. DUP_NONEED = 3 // duplicate this argument and store differential inside, but - // don't need the forward + // don't need the forward. Same as DUP_ARG otherwise. }; enum class BATCH_TYPE { @@ -1155,6 +1158,7 @@ static inline llvm::Optional getAllocationIndexFromCall(T *op) bool b = AttrList.getAttribute("enzyme_allocator") .getValueAsString() .getAsInteger(10, res); + (void)b; assert(!b); #if LLVM_VERSION_MAJOR >= 16 return std::optional(res); @@ -1169,6 +1173,7 @@ static inline llvm::Optional getAllocationIndexFromCall(T *op) bool b = called->getFnAttribute("enzyme_allocator") .getValueAsString() .getAsInteger(10, res); + (void)b; assert(!b); #if LLVM_VERSION_MAJOR >= 16 return std::optional(res); @@ -1225,6 +1230,7 @@ static inline std::vector getDeallocationIndicesFromCall(T *op) { for (auto ind : inds) { ssize_t Result; bool b = ind.getAsInteger(10, Result); + (void)b; assert(!b); vinds.push_back(Result); } @@ -1242,6 +1248,8 @@ void ErrorIfRuntimeInactive(llvm::IRBuilder<> &B, llvm::Value *primal, llvm::Function *GetFunctionFromValue(llvm::Value *fn); +llvm::Value *simplifyLoad(llvm::Value *LI, size_t valSz = 0); + static inline bool shouldDisableNoWrite(const llvm::CallInst *CI) { auto F = getFunctionFromCall(CI); auto funcName = getFuncNameFromCall(CI); @@ -1352,10 +1360,11 @@ static inline llvm::Value *getBaseObject(llvm::Value *V) { auto AttrList = Call->getAttributes().getAttributes( llvm::AttributeList::FunctionIndex); if (AttrList.hasAttribute("enzyme_pointermath")) { - size_t res; + size_t res = 0; bool failed = AttrList.getAttribute("enzyme_pointermath") .getValueAsString() .getAsInteger(10, res); + (void)failed; assert(!failed); V = Call->getArgOperand(res); continue; @@ -1383,10 +1392,11 @@ static inline llvm::Value *getBaseObject(llvm::Value *V) { auto AttrList = fn->getAttributes().getAttributes( llvm::AttributeList::FunctionIndex); if (AttrList.hasAttribute("enzyme_pointermath")) { - size_t res; + size_t res = 0; bool failed = AttrList.getAttribute("enzyme_pointermath") .getValueAsString() .getAsInteger(10, res); + (void)failed; assert(!failed); V = Call->getArgOperand(res); continue; diff --git a/enzyme/test/ActivityAnalysis/integration.ll b/enzyme/test/ActivityAnalysis/integration.ll new file mode 100644 index 000000000000..28ce009a9b5b --- /dev/null +++ b/enzyme/test/ActivityAnalysis/integration.ll @@ -0,0 +1,93 @@ +; RUN: if [ %llvmver -ge 15 ]; then %opt < %s %OPnewLoadEnzyme -passes="print-activity-analysis" -activity-analysis-func=f.preprocess -S | FileCheck %s; fi + +declare void @free(ptr) + +declare ptr @malloc(i64) + +; This function just returns 2*input, its derivate should be 2.0. +define void @f.preprocess(ptr %param, i64 %mallocsize, ptr %res) { + + ; arithmetic block, changing anything here makes the bug go away + %buffer1 = call ptr @malloc(i64 %mallocsize) + %tmp = call ptr @malloc(i64 72) + %ptrtoint = ptrtoint ptr %tmp to i64 + %and = and i64 %ptrtoint, -64 + %inttoptr = inttoptr i64 %and to ptr + %loadarg = load double, ptr %param + %storedargmul = fmul double %loadarg, 4.000000e+00 + store double %storedargmul, ptr %inttoptr + call void @free(ptr %tmp) + store double %storedargmul, ptr %buffer1 + + ; prep arg 0 by setting the aligned pointer to the input + %arg0 = alloca { ptr, ptr, i64 } + %arg0_aligned = getelementptr inbounds { ptr, ptr, i64 }, ptr %arg0, i64 0, i32 1 + store ptr %param, ptr %arg0_aligned + + ; prep arg 1 by setting the aligned pointer to buffer1 + %arg1 = alloca { ptr, ptr, i64, [1 x i64], [1 x i64] } + %arg1_aligned = getelementptr inbounds { ptr, ptr, i64, [1 x i64], [1 x i64] }, ptr %arg1, i64 0, i32 1 + store ptr %buffer1, ptr %arg1_aligned + + ; prep arg 2 by setting the aligned pointer to buffer2 + %arg2 = alloca { ptr, ptr, i64 } + %arg2_aligned = getelementptr inbounds { ptr, ptr, i64 }, ptr %arg2, i64 0, i32 1 + %buffer2 = call ptr @malloc(i64 8) + store ptr %buffer2, ptr %arg2_aligned + + ; nested call, required for bug + call void @nested(ptr %arg0, ptr %arg1, ptr %arg2) + + ; return a result from this function, needs to be positioned after arithmetic block for bug + %x = load double, ptr %param + %y = fmul double %x, 2.0 + store double %y, ptr %res + + ret void +} + +; Identity function, 2nd argument required for bug (but not used) +define void @nested(ptr %arg0, ptr %arg1, ptr %arg2) { + + ; load aligned pointer from %arg0 & load argument value + %loadarg = load { ptr, ptr, i64 }, ptr %arg0 + %extractarg = extractvalue { ptr, ptr, i64 } %loadarg, 1 + %loadextractarg = load double, ptr %extractarg + + ; load aligned pointer from %arg2 & store result value + %loadarg2 = load { ptr, ptr, i64 }, ptr %arg2 + %extractarg2 = extractvalue { ptr, ptr, i64 } %loadarg2, 1 + store double %loadextractarg, ptr %extractarg2 + + ret void +} + +; CHECK: ptr %param: icv:0 +; CHECK-NEXT: i64 %mallocsize: icv:1 +; CHECK-NEXT: ptr %res: icv:0 + +; CHECK: %buffer1 = call ptr @malloc(i64 %mallocsize): icv:0 ici:1 +; CHECK-NEXT: %tmp = call ptr @malloc(i64 72): icv:1 ici:1 +; CHECK-NEXT: %ptrtoint = ptrtoint ptr %tmp to i64: icv:1 ici:1 +; CHECK-NEXT: %and = and i64 %ptrtoint, -64: icv:1 ici:1 +; CHECK-NEXT: %inttoptr = inttoptr i64 %and to ptr: icv:1 ici:1 +; CHECK-NEXT: %loadarg = load double, ptr %param, align 8: icv:0 ici:0 +; CHECK-NEXT: %storedargmul = fmul double %loadarg, 4.000000e+00: icv:0 ici:0 +; CHECK-NEXT: store double %storedargmul, ptr %inttoptr, align 8: icv:1 ici:1 +; CHECK-NEXT: call void @free(ptr %tmp): icv:1 ici:1 +; CHECK-NEXT: store double %storedargmul, ptr %buffer1, align 8: icv:1 ici:0 +; CHECK-NEXT: %arg0 = alloca { ptr, ptr, i64 }, align 8: icv:0 ici:1 +; CHECK-NEXT: %arg0_aligned = getelementptr inbounds { ptr, ptr, i64 }, ptr %arg0, i64 0, i32 1: icv:0 ici:1 +; CHECK-NEXT: store ptr %param, ptr %arg0_aligned, align 8: icv:1 ici:0 +; CHECK-NEXT: %arg1 = alloca { ptr, ptr, i64, [1 x i64], [1 x i64] }, align 8: icv:0 ici:1 +; CHECK-NEXT: %arg1_aligned = getelementptr inbounds { ptr, ptr, i64, [1 x i64], [1 x i64] }, ptr %arg1, i64 0, i32 1: icv:0 ici:1 +; CHECK-NEXT: store ptr %buffer1, ptr %arg1_aligned, align 8: icv:1 ici:0 +; CHECK-NEXT: %arg2 = alloca { ptr, ptr, i64 }, align 8: icv:0 ici:1 +; CHECK-NEXT: %arg2_aligned = getelementptr inbounds { ptr, ptr, i64 }, ptr %arg2, i64 0, i32 1: icv:0 ici:1 +; CHECK-NEXT: %buffer2 = call ptr @malloc(i64 8): icv:0 ici:1 +; CHECK-NEXT: store ptr %buffer2, ptr %arg2_aligned, align 8: icv:1 ici:0 +; CHECK-NEXT: call void @nested(ptr %arg0, ptr %arg1, ptr %arg2): icv:1 ici:0 +; CHECK-NEXT: %x = load double, ptr %param, align 8: icv:0 ici:0 +; CHECK-NEXT: %y = fmul double %x, 2.000000e+00: icv:0 ici:0 +; CHECK-NEXT: store double %y, ptr %res, align 8: icv:1 ici:0 +; CHECK-NEXT: ret void: icv:1 ici:1 diff --git a/enzyme/test/BUILD b/enzyme/test/BUILD new file mode 100644 index 000000000000..47143966810d --- /dev/null +++ b/enzyme/test/BUILD @@ -0,0 +1,32 @@ +# Enzyme tests. + +load("@llvm-project//llvm:lit_test.bzl", "package_path") +load("@bazel_skylib//rules:expand_template.bzl", "expand_template") + +# Generates lit config input file by applying path placeholder substitutions +# similar to the configure_lit_site_cfg CMake macro. +expand_template( + name = "lit_site_cfg_py", + testonly = True, + out = "lit.site.cfg.py", + substitutions = { + "@LLVM_VERSION_MAJOR@": "18", + "@LIT_SITE_CFG_IN_HEADER@": "# Autogenerated, do not edit.", + "@LLVM_BINARY_DIR@": package_path("@llvm-project//llvm:BUILD"), + "@LLVM_TOOLS_BINARY_DIR@": package_path("@llvm-project//llvm:BUILD"), + "@LLVM_LIBS_DIR@": package_path("@llvm-project//llvm:BUILD"), + "@ENZYME_SOURCE_DIR@": "", + "@ENZYME_BINARY_DIR@": "", + "@TARGET_TRIPLE@": "", + "@TARGETS_TO_BUILD@": "ALL", + "@LLVM_SHLIBEXT@": ".so", + }, + template = "lit.site.cfg.py.in", + visibility = [":__subpackages__"], +) + +exports_files( + ["lit.cfg.py"], + visibility = [":__subpackages__"], +) + diff --git a/enzyme/test/Enzyme/ForwardMode/acosh.ll b/enzyme/test/Enzyme/ForwardMode/acosh.ll new file mode 100644 index 000000000000..90108c74de17 --- /dev/null +++ b/enzyme/test/Enzyme/ForwardMode/acosh.ll @@ -0,0 +1,31 @@ +; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -S | FileCheck %s; fi +; RUN: %opt < %s %newLoadEnzyme -passes="enzyme" -enzyme-preopt=false -S | FileCheck %s + +; Function Attrs: nounwind readnone uwtable +define double @tester(double %x) { +entry: + %0 = tail call fast double @acosh(double %x) + ret double %0 +} + +define double @test_derivative(double %x) { +entry: + %0 = tail call double (double (double)*, ...) @__enzyme_fwddiff(double (double)* nonnull @tester, double %x, double 1.0) + ret double %0 +} + +; Function Attrs: nounwind readnone speculatable +declare double @acosh(double) + +; Function Attrs: nounwind +declare double @__enzyme_fwddiff(double (double)*, ...) + +; CHECK: define internal double @fwddiffetester(double %x, double %"x'") +; CHECK-NEXT: entry: +; CHECK-NEXT: %0 = fmul fast double %x, %x +; CHECK-NEXT: %1 = fsub fast double %0, 1.000000e+00 +; CHECK-NEXT: %2 = call fast double @llvm.sqrt.f64(double %1) +; CHECK-NEXT: %3 = fdiv fast double %"x'", %2 +; CHECK-NEXT: ret double %3 +; CHECK-NEXT: } + diff --git a/enzyme/test/Enzyme/ForwardMode/asinh.ll b/enzyme/test/Enzyme/ForwardMode/asinh.ll new file mode 100644 index 000000000000..73377c6edfd5 --- /dev/null +++ b/enzyme/test/Enzyme/ForwardMode/asinh.ll @@ -0,0 +1,31 @@ +; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -S | FileCheck %s; fi +; RUN: %opt < %s %newLoadEnzyme -passes="enzyme" -enzyme-preopt=false -S | FileCheck %s + +; Function Attrs: nounwind readnone uwtable +define double @tester(double %x) { +entry: + %0 = tail call fast double @asinh(double %x) + ret double %0 +} + +define double @test_derivative(double %x) { +entry: + %0 = tail call double (double (double)*, ...) @__enzyme_fwddiff(double (double)* nonnull @tester, double %x, double 1.0) + ret double %0 +} + +; Function Attrs: nounwind readnone speculatable +declare double @asinh(double) + +; Function Attrs: nounwind +declare double @__enzyme_fwddiff(double (double)*, ...) + +; CHECK: define internal double @fwddiffetester(double %x, double %"x'") +; CHECK-NEXT: entry: +; CHECK-NEXT: %0 = fmul fast double %x, %x +; CHECK-NEXT: %1 = fadd fast double %0, 1.000000e+00 +; CHECK-NEXT: %2 = call fast double @llvm.sqrt.f64(double %1) +; CHECK-NEXT: %3 = fdiv fast double %"x'", %2 +; CHECK-NEXT: ret double %3 +; CHECK-NEXT: } + diff --git a/enzyme/test/Enzyme/ForwardMode/asinhf.ll b/enzyme/test/Enzyme/ForwardMode/asinhf.ll new file mode 100644 index 000000000000..edd9959120d9 --- /dev/null +++ b/enzyme/test/Enzyme/ForwardMode/asinhf.ll @@ -0,0 +1,31 @@ +; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -S | FileCheck %s; fi +; RUN: %opt < %s %newLoadEnzyme -passes="enzyme" -enzyme-preopt=false -S | FileCheck %s + +; Function Attrs: nounwind readnone uwtable +define float @tester(float %x) { +entry: + %0 = tail call fast float @asinhf(float %x) + ret float %0 +} + +define float @test_derivative(float %x) { +entry: + %0 = tail call float (float (float)*, ...) @__enzyme_fwddiff(float (float)* nonnull @tester, float %x, float 1.0) + ret float %0 +} + +; Function Attrs: nounwind readnone speculatable +declare float @asinhf(float) + +; Function Attrs: nounwind +declare float @__enzyme_fwddiff(float (float)*, ...) + +; CHECK: define internal float @fwddiffetester(float %x, float %"x'") +; CHECK-NEXT: entry: +; CHECK-NEXT: %0 = fmul fast float %x, %x +; CHECK-NEXT: %1 = fadd fast float %0, 1.000000e+00 +; CHECK-NEXT: %2 = call fast float @llvm.sqrt.f32(float %1) +; CHECK-NEXT: %3 = fdiv fast float %"x'", %2 +; CHECK-NEXT: ret float %3 +; CHECK-NEXT: } + diff --git a/enzyme/test/Enzyme/ForwardMode/asinhl.ll b/enzyme/test/Enzyme/ForwardMode/asinhl.ll new file mode 100644 index 000000000000..80929fd665c2 --- /dev/null +++ b/enzyme/test/Enzyme/ForwardMode/asinhl.ll @@ -0,0 +1,31 @@ +; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -S | FileCheck %s; fi +; RUN: %opt < %s %newLoadEnzyme -passes="enzyme" -enzyme-preopt=false -S | FileCheck %s + +; Function Attrs: nounwind readnone uwtable +define x86_fp80 @tester(x86_fp80 %x) { +entry: + %0 = tail call fast x86_fp80 @asinhl(x86_fp80 %x) + ret x86_fp80 %0 +} + +define x86_fp80 @test_derivative(x86_fp80 %x) { +entry: + %0 = tail call x86_fp80 (x86_fp80 (x86_fp80)*, ...) @__enzyme_fwddiff(x86_fp80 (x86_fp80)* nonnull @tester, x86_fp80 %x, x86_fp80 0xK3FFF8000000000000000) + ret x86_fp80 %0 +} + +; Function Attrs: nounwind readnone speculatable +declare x86_fp80 @asinhl(x86_fp80) + +; Function Attrs: nounwind +declare x86_fp80 @__enzyme_fwddiff(x86_fp80 (x86_fp80)*, ...) + +; CHECK: define internal x86_fp80 @fwddiffetester(x86_fp80 %x, x86_fp80 %"x'") +; CHECK-NEXT: entry: +; CHECK-NEXT: %0 = fmul fast x86_fp80 %x, %x +; CHECK-NEXT: %1 = fadd fast x86_fp80 %0, 0xK3FFF8000000000000000 +; CHECK-NEXT: %2 = call fast x86_fp80 @llvm.sqrt.f80(x86_fp80 %1) +; CHECK-NEXT: %3 = fdiv fast x86_fp80 %"x'", %2 +; CHECK-NEXT: ret x86_fp80 %3 +; CHECK-NEXT: } + diff --git a/enzyme/test/Enzyme/ForwardMode/atanh.ll b/enzyme/test/Enzyme/ForwardMode/atanh.ll new file mode 100644 index 000000000000..358ab107e140 --- /dev/null +++ b/enzyme/test/Enzyme/ForwardMode/atanh.ll @@ -0,0 +1,30 @@ +; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -S | FileCheck %s; fi +; RUN: %opt < %s %newLoadEnzyme -passes="enzyme" -enzyme-preopt=false -S | FileCheck %s + +; Function Attrs: nounwind readnone uwtable +define double @tester(double %x) { +entry: + %0 = tail call fast double @atanh(double %x) + ret double %0 +} + +define double @test_derivative(double %x) { +entry: + %0 = tail call double (double (double)*, ...) @__enzyme_fwddiff(double (double)* nonnull @tester, double %x, double 1.0) + ret double %0 +} + +; Function Attrs: nounwind readnone speculatable +declare double @atanh(double) + +; Function Attrs: nounwind +declare double @__enzyme_fwddiff(double (double)*, ...) + +; CHECK: define internal double @fwddiffetester(double %x, double %"x'") +; CHECK-NEXT: entry: +; CHECK-NEXT: %0 = fmul fast double %x, %x +; CHECK-NEXT: %1 = fsub fast double 1.000000e+00, %0 +; CHECK-NEXT: %2 = fdiv fast double %"x'", %1 +; CHECK-NEXT: ret double %2 +; CHECK-NEXT: } + diff --git a/enzyme/test/Enzyme/ForwardMode/coshf.ll b/enzyme/test/Enzyme/ForwardMode/coshf.ll new file mode 100644 index 000000000000..b94270380e52 --- /dev/null +++ b/enzyme/test/Enzyme/ForwardMode/coshf.ll @@ -0,0 +1,28 @@ +; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -S | FileCheck %s; fi +; RUN: %opt < %s %newLoadEnzyme -passes="enzyme" -enzyme-preopt=false -S | FileCheck %s + +; Function Attrs: nounwind readnone uwtable +define float @tester(float %x) { +entry: + %0 = tail call fast float @coshf(float %x) + ret float %0 +} + +define float @test_derivative(float %x) { +entry: + %0 = tail call float (float (float)*, ...) @__enzyme_fwddiff(float (float)* nonnull @tester, float %x, float 1.0) + ret float %0 +} + +; Function Attrs: nounwind readnone speculatable +declare float @coshf(float) + +; Function Attrs: nounwind +declare float @__enzyme_fwddiff(float (float)*, ...) + +; CHECK: define internal float @fwddiffetester(float %x, float %"x'") +; CHECK-NEXT: entry: +; CHECK-NEXT: %0 = call fast float @sinhf(float %x) +; CHECK-NEXT: %1 = fmul fast float %"x'", %0 +; CHECK-NEXT: ret float %1 +; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ForwardMode/coshl.ll b/enzyme/test/Enzyme/ForwardMode/coshl.ll new file mode 100644 index 000000000000..4a6e3054f0fc --- /dev/null +++ b/enzyme/test/Enzyme/ForwardMode/coshl.ll @@ -0,0 +1,28 @@ +; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -S | FileCheck %s; fi +; RUN: %opt < %s %newLoadEnzyme -passes="enzyme" -enzyme-preopt=false -S | FileCheck %s + +; Function Attrs: nounwind readnone uwtable +define x86_fp80 @tester(x86_fp80 %x) { +entry: + %0 = tail call fast x86_fp80 @coshl(x86_fp80 %x) + ret x86_fp80 %0 +} + +define x86_fp80 @test_derivative(x86_fp80 %x) { +entry: + %0 = tail call x86_fp80 (x86_fp80 (x86_fp80)*, ...) @__enzyme_fwddiff(x86_fp80 (x86_fp80)* nonnull @tester, x86_fp80 %x, x86_fp80 0xK3FFF8000000000000000) + ret x86_fp80 %0 +} + +; Function Attrs: nounwind readnone speculatable +declare x86_fp80 @coshl(x86_fp80) + +; Function Attrs: nounwind +declare x86_fp80 @__enzyme_fwddiff(x86_fp80 (x86_fp80)*, ...) + +; CHECK: define internal x86_fp80 @fwddiffetester(x86_fp80 %x, x86_fp80 %"x'") +; CHECK-NEXT: entry: +; CHECK-NEXT: %0 = call fast x86_fp80 @sinhl(x86_fp80 %x) +; CHECK-NEXT: %1 = fmul fast x86_fp80 %"x'", %0 +; CHECK-NEXT: ret x86_fp80 %1 +; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ForwardMode/fdim.ll b/enzyme/test/Enzyme/ForwardMode/fdim.ll new file mode 100644 index 000000000000..7d333686f7f4 --- /dev/null +++ b/enzyme/test/Enzyme/ForwardMode/fdim.ll @@ -0,0 +1,32 @@ +; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -S | FileCheck %s; fi +; RUN: if [ %llvmver -ge 12 ]; then %opt < %s %newLoadEnzyme -passes="enzyme" -enzyme-preopt=false -S | FileCheck %s ; fi + +declare double @fdim(double, double) + +define double @tester(double %x, double %y) { +entry: + %0 = call double @fdim(double %x, double %y) + ret double %0 +} + +define double @test_derivative(double %x, double %y) { +entry: + %0 = tail call double (double (double, double)*, ...) @__enzyme_fwddiff(double (double, double)* nonnull @tester, double %x, double 10.0, double %y, double 1.0) + ret double %0 +} + +; Function Attrs: nounwind +declare double @__enzyme_fwddiff(double (double, double)*, ...) + +; CHECK: define internal double @fwddiffetester(double %x, double %"x'", double %y, double %"y'") +; CHECK-NEXT: entry: +; CHECK-NEXT: %0 = fcmp fast olt double %x, %y +; CHECK-NEXT: %1 = select fast i1 %0, double 0.000000e+00, double %"x'" +; CHECK-NEXT: %2 = fcmp fast olt double %x, %y +; CHECK-NEXT: %3 = fneg fast double %"y'" +; CHECK-NEXT: %4 = select fast i1 %2, double 0.000000e+00, double %3 +; CHECK-NEXT: %5 = fadd fast double %1, %4 +; CHECK-NEXT: ret double %5 +; CHECK-NEXT: } + + diff --git a/enzyme/test/Enzyme/ForwardMode/frexp.ll b/enzyme/test/Enzyme/ForwardMode/frexp.ll index 9e421ec21e3a..f9f0916a09b9 100644 --- a/enzyme/test/Enzyme/ForwardMode/frexp.ll +++ b/enzyme/test/Enzyme/ForwardMode/frexp.ll @@ -1,10 +1,11 @@ ; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -S | FileCheck %s; fi ; RUN: %opt < %s %newLoadEnzyme -passes="enzyme" -enzyme-preopt=false -S | FileCheck %s -declare double @frexp(double, i32*) declare double @__enzyme_fwddiff(i8*, ...) declare float @__enzyme_fwddifff(i8*, ...) +declare x86_fp80 @__enzyme_fwddiffl(i8*, ...) +declare double @frexp(double, i32*) define double @test(double %x) { entry: %exp = alloca i32, align 4 @@ -32,6 +33,20 @@ entry: ret float %call } +declare x86_fp80 @frexpl(x86_fp80, i32*) +define x86_fp80 @testl(x86_fp80 %x) { +entry: + %exp = alloca i32, align 4 + %call = call x86_fp80 @frexpl(x86_fp80 %x, i32* %exp) + ret x86_fp80 %call +} + +define x86_fp80 @dtestl(x86_fp80 %x, x86_fp80 %dx) { +entry: + %call = call x86_fp80 (i8*, ...) @__enzyme_fwddiffl(i8* bitcast (x86_fp80 (x86_fp80)* @testl to i8*), x86_fp80 %x, x86_fp80 %dx) + ret x86_fp80 %call +} + ; CHECK: define internal double @fwddiffetest(double %x, double %"x'") ; CHECK-NEXT: entry: ; CHECK-NEXT: %0 = bitcast double %x to i64 @@ -51,3 +66,13 @@ entry: ; CHECK-NEXT: %4 = fdiv fast float %"x'", %3 ; CHECK-NEXT: ret float %4 ; CHECK-NEXT: } + +; CHECK: define internal x86_fp80 @fwddiffetestl(x86_fp80 %x, x86_fp80 %"x'") +; CHECK-NEXT: entry: +; CHECK-NEXT: %0 = bitcast x86_fp80 %x to i80 +; CHECK-NEXT: %1 = and i80 604453686435277732577280, %0 +; CHECK-NEXT: %2 = bitcast i80 %1 to x86_fp80 +; CHECK-NEXT: %3 = fmul fast x86_fp80 %2, 0xK40008000000000000000 +; CHECK-NEXT: %4 = fdiv fast x86_fp80 %"x'", %3 +; CHECK-NEXT: ret x86_fp80 %4 +; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ForwardMode/insertdiffuse.ll b/enzyme/test/Enzyme/ForwardMode/insertdiffuse.ll new file mode 100644 index 000000000000..44f246cdfbe0 --- /dev/null +++ b/enzyme/test/Enzyme/ForwardMode/insertdiffuse.ll @@ -0,0 +1,25 @@ +; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -S | FileCheck %s; fi +; RUN: %opt < %s %newLoadEnzyme -passes="enzyme" -enzyme-preopt=false -S | FileCheck %s + +define { double, i64 } @julia_logabsgamma_3264_inner.1(double %x, i64 %z) { +entry: + %iadd = add i64 %z, 1 + %.fca.0.insert = insertvalue { double, i64 } undef, double %x, 0 + %.fca.1.insert = insertvalue { double, i64 } %.fca.0.insert, i64 %iadd, 1 + ret { double, i64 } %.fca.1.insert +} + +declare { double, i64 } @__enzyme_fwddiff(...) + +define { double, i64 } @ad(double %x, double %dx) { + %m = call { double, i64 } (...) @__enzyme_fwddiff({ double, i64 } (double, i64)* @julia_logabsgamma_3264_inner.1, double %x, double %dx, i64 1) + ret { double, i64 } %m +} + +; CHECK: define internal { double, i64 } @fwddiffejulia_logabsgamma_3264_inner.1(double %x, double %"x'", i64 %z) +; CHECK-NEXT: entry: +; CHECK-NEXT: %iadd = add i64 %z, 1 +; CHECK-NEXT: %".fca.0.insert'ipiv" = insertvalue { double, i64 } zeroinitializer, double %"x'", 0 +; CHECK-NEXT: %".fca.1.insert'ipiv" = insertvalue { double, i64 } %".fca.0.insert'ipiv", i64 %iadd, 1 +; CHECK-NEXT: ret { double, i64 } %".fca.1.insert'ipiv" +; CHECK-NEXT: } \ No newline at end of file diff --git a/enzyme/test/Enzyme/ForwardMode/logabsgamma.ll b/enzyme/test/Enzyme/ForwardMode/logabsgamma.ll new file mode 100644 index 000000000000..280d20db590e --- /dev/null +++ b/enzyme/test/Enzyme/ForwardMode/logabsgamma.ll @@ -0,0 +1,28 @@ +; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme-preopt=false -enzyme -S | FileCheck %s; fi +; RUN: %opt < %s %newLoadEnzyme -enzyme-preopt=false -passes="enzyme" -S | FileCheck %s + +; Function Attrs: nounwind readnone uwtable +define { double, i64 } @tester(double %x) { +entry: + %a = call { double, i64 } @logabsgamma(double %x) + ret { double, i64 } %a +} + +define { double, i64 } @test_derivative(double %x, double %dx) { +entry: + %0 = tail call { double, i64 } (...) @__enzyme_fwddiff({ double, i64 } (double)* nonnull @tester, double %x, double %dx) + ret { double, i64 } %0 +} + +declare { double, i64 } @logabsgamma(double) + +; Function Attrs: nounwind +declare { double, i64 } @__enzyme_fwddiff(...) + +; CHECK: define internal { double, i64 } @fwddiffetester(double %x, double %"x'") +; CHECK-NEXT: entry: +; CHECK-NEXT: %0 = call fast double @digamma(double %x) +; CHECK-NEXT: %1 = fmul fast double %0, %"x'" +; CHECK-NEXT: %2 = insertvalue { double, i64 } undef, double %1, 0 +; CHECK-NEXT: ret { double, i64 } %2 +; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ForwardMode/modf.ll b/enzyme/test/Enzyme/ForwardMode/modf.ll new file mode 100644 index 000000000000..d36c4d659315 --- /dev/null +++ b/enzyme/test/Enzyme/ForwardMode/modf.ll @@ -0,0 +1,120 @@ +; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -S | FileCheck %s; fi +; RUN: %opt < %s %newLoadEnzyme -passes="enzyme" -enzyme-preopt=false -S | FileCheck %s + +declare double @__enzyme_fwddiff(i8*, ...) +declare float @__enzyme_fwddifff(i8*, ...) +declare x86_fp80 @__enzyme_fwddiffl(i8*, ...) + +; double +declare double @modf(double, double*) +define double @testint(double %x) { +entry: + %integral_part = alloca double, align 8 + %fractional_part = call double @modf(double %x, double* %integral_part) + %ret = load double, double* %integral_part, align 8 + ret double %ret +} +define double @testfrac(double %x) { +entry: + %integral_part = alloca double, align 8 + %fractional_part = call double @modf(double %x, double* %integral_part) + ret double %fractional_part +} + +define double @dtestint(double %x, double %dx) { +entry: + %call = call double (i8*, ...) @__enzyme_fwddiff(i8* bitcast (double (double)* @testint to i8*), double %x, double %dx) + ret double %call +} +define double @dtestfrac(double %x, double %dx) { +entry: + %call = call double (i8*, ...) @__enzyme_fwddiff(i8* bitcast (double (double)* @testfrac to i8*), double %x, double %dx) + ret double %call +} + +; float +declare float @modff(float, float*) +define float @testintf(float %x) { +entry: + %integral_part = alloca float, align 4 + %fractional_part = call float @modff(float %x, float* %integral_part) + %ret = load float, float* %integral_part, align 4 + ret float %ret +} +define float @testfracf(float %x) { +entry: + %integral_part = alloca float, align 4 + %fractional_part = call float @modff(float %x, float* %integral_part) + ret float %fractional_part +} + +define float @dtestintf(float %x, float %dx) { +entry: + %call = call float (i8*, ...) @__enzyme_fwddifff(i8* bitcast (float (float)* @testintf to i8*), float %x, float %dx) + ret float %call +} +define float @dtestfracf(float %x, float %dx) { +entry: + %call = call float (i8*, ...) @__enzyme_fwddifff(i8* bitcast (float (float)* @testfracf to i8*), float %x, float %dx) + ret float %call +} + +; x86_fp80 +declare x86_fp80 @modfl(x86_fp80, x86_fp80*) +define x86_fp80 @testintl(x86_fp80 %x) { +entry: + %integral_part = alloca x86_fp80, align 8 + %fractional_part = call x86_fp80 @modfl(x86_fp80 %x, x86_fp80* %integral_part) + %ret = load x86_fp80, x86_fp80* %integral_part, align 8 + ret x86_fp80 %ret +} +define x86_fp80 @testfracl(x86_fp80 %x) { +entry: + %integral_part = alloca x86_fp80, align 8 + %fractional_part = call x86_fp80 @modfl(x86_fp80 %x, x86_fp80* %integral_part) + ret x86_fp80 %fractional_part +} + +define x86_fp80 @dtestintl(x86_fp80 %x, x86_fp80 %dx) { +entry: + %call = call x86_fp80 (i8*, ...) @__enzyme_fwddiffl(i8* bitcast (x86_fp80 (x86_fp80)* @testintl to i8*), x86_fp80 %x, x86_fp80 %dx) + ret x86_fp80 %call +} +define x86_fp80 @dtestfracl(x86_fp80 %x, x86_fp80 %dx) { +entry: + %call = call x86_fp80 (i8*, ...) @__enzyme_fwddiffl(i8* bitcast (x86_fp80 (x86_fp80)* @testfracl to i8*), x86_fp80 %x, x86_fp80 %dx) + ret x86_fp80 %call +} + +; tests + +; CHECK: define internal double @fwddiffetestint(double %x, double %"x'") +; CHECK-NEXT: entry: +; CHECK-NEXT: ret double 0.000000e+00 +; CHECK-NEXT: } + +; CHECK: define internal double @fwddiffetestfrac(double %x, double %"x'") +; CHECK-NEXT: entry: +; CHECK-NEXT: ret double %"x'" +; CHECK-NEXT: } + +; CHECK: define internal float @fwddiffetestintf(float %x, float %"x'") +; CHECK-NEXT: entry: +; CHECK-NEXT: ret float 0.000000e+00 +; CHECK-NEXT: } + +; CHECK: define internal float @fwddiffetestfracf(float %x, float %"x'") +; CHECK-NEXT: entry: +; CHECK-NEXT: ret float %"x'" +; CHECK-NEXT: } + +; CHECK: define internal x86_fp80 @fwddiffetestintl(x86_fp80 %x, x86_fp80 %"x'") +; CHECK-NEXT: entry: +; CHECK-NEXT: ret x86_fp80 0xK00000000000000000000 +; CHECK-NEXT: } + +; CHECK: define internal x86_fp80 @fwddiffetestfracl(x86_fp80 %x, x86_fp80 %"x'") +; CHECK-NEXT: entry: +; CHECK-NEXT: ret x86_fp80 %"x'" +; CHECK-NEXT: } + diff --git a/enzyme/test/Enzyme/ReverseMode/Faddeeva_erf.ll b/enzyme/test/Enzyme/ReverseMode/Faddeeva_erf.ll index 4722ef548bd8..716cba079e61 100644 --- a/enzyme/test/Enzyme/ReverseMode/Faddeeva_erf.ll +++ b/enzyme/test/Enzyme/ReverseMode/Faddeeva_erf.ll @@ -20,6 +20,9 @@ declare { double, double } @__enzyme_autodiff({ double, double } ({ double, doub ; CHECK: define internal { { double, double } } @diffetester({ double, double } %in, { double, double } %differeturn) ; CHECK-NEXT: entry: +; CHECK-NEXT: %[[a16:.+]] = extractvalue { double, double } %differeturn, 0 +; CHECK-NEXT: %[[a17:.+]] = extractvalue { double, double } %differeturn, 1 +; CHECK-NEXT: %[[conj:.+]] = {{(fsub fast double \-0.000000e\+00,|fneg fast double)}} %[[a17]] ; CHECK-NEXT: %[[a0:.+]] = extractvalue { double, double } %in, 0 ; CHECK-NEXT: %[[a1:.+]] = extractvalue { double, double } %in, 1 ; CHECK-DAG: %[[a2:.+]] = fmul fast double %[[a1]], %[[a1]] @@ -36,16 +39,15 @@ declare { double, double } @__enzyme_autodiff({ double, double } ({ double, doub ; CHECK-NEXT: %[[a13:.+]] = fmul fast double %[[a9]], %[[a12]] ; CHECK-NEXT: %[[a14:.+]] = fmul fast double 0x3FF20DD750429B6D, %[[a11]] ; CHECK-NEXT: %[[a15:.+]] = fmul fast double 0x3FF20DD750429B6D, %[[a13]] -; CHECK-NEXT: %[[a16:.+]] = extractvalue { double, double } %differeturn, 0 -; CHECK-NEXT: %[[a17:.+]] = extractvalue { double, double } %differeturn, 1 ; CHECK-DAG: %[[a19:.+]] = fmul fast double %[[a16]], %[[a14]] -; CHECK-DAG: %[[a18:.+]] = fmul fast double %[[a17]], %[[a15]] +; CHECK-DAG: %[[a18:.+]] = fmul fast double %[[conj]], %[[a15]] ; CHECK-NEXT: %[[a20:.+]] = fsub fast double %[[a19]], %[[a18]] ; CHECK-DAG: %[[a22:.+]] = fmul fast double %[[a16]], %[[a15]] -; CHECK-DAG: %[[a21:.+]] = fmul fast double %[[a14]], %[[a17]] +; CHECK-DAG: %[[a21:.+]] = fmul fast double %[[a14]], %[[conj]] ; CHECK-NEXT: %[[a23:.+]] = fadd fast double %[[a22]], %[[a21]] +; CHECK-NEXT: %[[conj2:.+]] = {{(fsub fast double \-0.000000e\+00,|fneg fast double)}} %[[a23]] ; CHECK-NEXT: %[[insert5:.+]] = insertvalue { double, double } {{(undef|poison)}}, double %[[a20]], 0 -; CHECK-NEXT: %[[insert8:.+]] = insertvalue { double, double } %[[insert5]], double %[[a23]], 1 +; CHECK-NEXT: %[[insert8:.+]] = insertvalue { double, double } %[[insert5]], double %[[conj2]], 1 ; CHECK-NEXT: %[[a24:.+]] = insertvalue { { double, double } } undef, { double, double } %[[insert8]], 0 ; CHECK-NEXT: ret { { double, double } } %[[a24]] ; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ReverseMode/Faddeeva_erfc.ll b/enzyme/test/Enzyme/ReverseMode/Faddeeva_erfc.ll index 04c8fee4de92..dabd2674b7ba 100644 --- a/enzyme/test/Enzyme/ReverseMode/Faddeeva_erfc.ll +++ b/enzyme/test/Enzyme/ReverseMode/Faddeeva_erfc.ll @@ -20,6 +20,9 @@ declare { double, double } @__enzyme_autodiff({ double, double } ({ double, doub ; CHECK: define internal { { double, double } } @diffetester({ double, double } %in, { double, double } %differeturn) ; CHECK-NEXT: entry: +; CHECK-NEXT: %[[a16:.+]] = extractvalue { double, double } %differeturn, 0 +; CHECK-NEXT: %[[a17:.+]] = extractvalue { double, double } %differeturn, 1 +; CHECK-NEXT: %[[conj:.+]] = {{(fsub fast double \-0.000000e\+00,|fneg fast double)}} %[[a17]] ; CHECK-NEXT: %[[a0:.+]] = extractvalue { double, double } %in, 0 ; CHECK-NEXT: %[[a1:.+]] = extractvalue { double, double } %in, 1 ; CHECK-DAG: %[[a2:.+]] = fmul fast double %[[a1]], %[[a1]] @@ -36,16 +39,15 @@ declare { double, double } @__enzyme_autodiff({ double, double } ({ double, doub ; CHECK-NEXT: %[[a13:.+]] = fmul fast double %[[a9]], %[[a12]] ; CHECK-NEXT: %[[a14:.+]] = fmul fast double 0xBFF20DD750429B6D, %[[a11]] ; CHECK-NEXT: %[[a15:.+]] = fmul fast double 0xBFF20DD750429B6D, %[[a13]] -; CHECK-NEXT: %[[a16:.+]] = extractvalue { double, double } %differeturn, 0 -; CHECK-NEXT: %[[a17:.+]] = extractvalue { double, double } %differeturn, 1 ; CHECK-DAG: %[[a19:.+]] = fmul fast double %[[a16]], %[[a14]] -; CHECK-DAG: %[[a18:.+]] = fmul fast double %[[a17]], %[[a15]] +; CHECK-DAG: %[[a18:.+]] = fmul fast double %[[conj]], %[[a15]] ; CHECK-NEXT: %[[a20:.+]] = fsub fast double %[[a19]], %[[a18]] ; CHECK-DAG: %[[a22:.+]] = fmul fast double %[[a16]], %[[a15]] -; CHECK-DAG: %[[a21:.+]] = fmul fast double %[[a14]], %[[a17]] +; CHECK-DAG: %[[a21:.+]] = fmul fast double %[[a14]], %[[conj]] ; CHECK-NEXT: %[[a23:.+]] = fadd fast double %[[a22]], %[[a21]] +; CHECK-NEXT: %[[conj2:.+]] = {{(fsub fast double \-0.000000e\+00,|fneg fast double)}} %[[a23]] ; CHECK-NEXT: %[[insert5:.+]] = insertvalue { double, double } {{(undef|poison)}}, double %[[a20]], 0 -; CHECK-NEXT: %[[insert8:.+]] = insertvalue { double, double } %[[insert5]], double %[[a23]], 1 +; CHECK-NEXT: %[[insert8:.+]] = insertvalue { double, double } %[[insert5]], double %[[conj2]], 1 ; CHECK-NEXT: %[[a24:.+]] = insertvalue { { double, double } } undef, { double, double } %[[insert8]], 0 ; CHECK-NEXT: ret { { double, double } } %[[a24]] ; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ReverseMode/Faddeeva_erfi.ll b/enzyme/test/Enzyme/ReverseMode/Faddeeva_erfi.ll index 40932ffb5cf2..c5e5cfba0a90 100644 --- a/enzyme/test/Enzyme/ReverseMode/Faddeeva_erfi.ll +++ b/enzyme/test/Enzyme/ReverseMode/Faddeeva_erfi.ll @@ -20,6 +20,9 @@ declare { double, double } @__enzyme_autodiff({ double, double } ({ double, doub ; CHECK: define internal { { double, double } } @diffetester({ double, double } %in, { double, double } %differeturn) ; CHECK-NEXT: entry: +; CHECK-NEXT: %[[i16:.+]] = extractvalue { double, double } %differeturn, 0 +; CHECK-NEXT: %[[i17:.+]] = extractvalue { double, double } %differeturn, 1 +; CHECK-NEXT: %[[conj:.+]] = {{(fsub fast double \-0.000000e\+00,|fneg fast double)}} %[[i17]] ; CHECK-NEXT: %[[a0:.+]] = extractvalue { double, double } %in, 0 ; CHECK-NEXT: %[[a1:.+]] = extractvalue { double, double } %in, 1 ; CHECK-NEXT: %[[a3:.+]] = fmul fast double %[[a0]], %[[a0]] @@ -34,16 +37,15 @@ declare { double, double } @__enzyme_autodiff({ double, double } ({ double, doub ; CHECK-NEXT: %[[i13:.+]] = fmul fast double %[[i9]], %[[i12]] ; CHECK-NEXT: %[[i14:.+]] = fmul fast double 0x3FF20DD750429B6D, %[[i11]] ; CHECK-NEXT: %[[i15:.+]] = fmul fast double 0x3FF20DD750429B6D, %[[i13]] -; CHECK-NEXT: %[[i16:.+]] = extractvalue { double, double } %differeturn, 0 -; CHECK-NEXT: %[[i17:.+]] = extractvalue { double, double } %differeturn, 1 ; CHECK-NEXT: %[[i19:.+]] = fmul fast double %[[i16]], %[[i14]] -; CHECK-NEXT: %[[i18:.+]] = fmul fast double %[[i17]], %[[i15]] +; CHECK-NEXT: %[[i18:.+]] = fmul fast double %[[conj]], %[[i15]] ; CHECK-NEXT: %[[i20:.+]] = fsub fast double %[[i19]], %[[i18]] ; CHECK-NEXT: %[[i22:.+]] = fmul fast double %[[i16]], %[[i15]] -; CHECK-NEXT: %[[i21:.+]] = fmul fast double %[[i14]], %[[i17]] +; CHECK-NEXT: %[[i21:.+]] = fmul fast double %[[i14]], %[[conj]] ; CHECK-NEXT: %[[i23:.+]] = fadd fast double %[[i22]], %[[i21]] +; CHECK-NEXT: %[[conj2:.+]] = {{(fsub fast double \-0.000000e\+00,|fneg fast double)}} %[[i23]] ; CHECK-NEXT: %[[insert5:.+]] = insertvalue { double, double } {{(undef|poison)}}, double %[[i20]], 0 -; CHECK-NEXT: %[[insert8:.+]] = insertvalue { double, double } %[[insert5]], double %[[i23]], 1 +; CHECK-NEXT: %[[insert8:.+]] = insertvalue { double, double } %[[insert5]], double %[[conj2]], 1 ; CHECK-NEXT: %[[i24:.+]] = insertvalue { { double, double } } undef, { double, double } %[[insert8]], 0 ; CHECK-NEXT: ret { { double, double } } %[[i24]] ; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ReverseMode/acosh.ll b/enzyme/test/Enzyme/ReverseMode/acosh.ll new file mode 100644 index 000000000000..b6ff83ea4f5f --- /dev/null +++ b/enzyme/test/Enzyme/ReverseMode/acosh.ll @@ -0,0 +1,46 @@ +; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -S | FileCheck %s; fi +; RUN: %opt < %s %newLoadEnzyme -passes="enzyme" -enzyme-preopt=false -S | FileCheck %s + +; Function Attrs: nounwind readnone uwtable +define double @tester(double %x) { +entry: + %0 = tail call fast double @acosh(double %x) + ret double %0 +} + +define double @test_derivative(double %x) { +entry: + %0 = tail call double (double (double)*, ...) @__enzyme_autodiff(double (double)* nonnull @tester, double %x) + ret double %0 +} + +; Function Attrs: nounwind readnone speculatable +declare double @acosh(double) + +; Function Attrs: nounwind +declare double @__enzyme_autodiff(double (double)*, ...) + +; CHECK: define internal { double } @diffetester(double %x, double %differeturn) +; CHECK-NEXT: entry: +; CHECK-NEXT: %"'de" = alloca double, align 8 +; CHECK-NEXT: store double 0.000000e+00, double* %"'de", align 8 +; CHECK-NEXT: %"x'de" = alloca double, align 8 +; CHECK-NEXT: store double 0.000000e+00, double* %"x'de", align 8 +; CHECK-NEXT: br label %invertentry + +; CHECK: invertentry: +; CHECK-NEXT: store double %differeturn, double* %"'de", align 8 +; CHECK-NEXT: %0 = load double, double* %"'de", align 8 +; CHECK-NEXT: store double 0.000000e+00, double* %"'de", align 8 +; CHECK-NEXT: %1 = fmul fast double %x, %x +; CHECK-NEXT: %2 = fsub fast double %1, 1.000000e+00 +; CHECK-NEXT: %3 = call fast double @llvm.sqrt.f64(double %2) +; CHECK-NEXT: %4 = fdiv fast double %0, %3 +; CHECK-NEXT: %5 = load double, double* %"x'de", align 8 +; CHECK-NEXT: %6 = fadd fast double %5, %4 +; CHECK-NEXT: store double %6, double* %"x'de", align 8 +; CHECK-NEXT: %7 = load double, double* %"x'de", align 8 +; CHECK-NEXT: %8 = insertvalue { double } undef, double %7, 0 +; CHECK-NEXT: ret { double } %8 +; CHECK-NEXT: } + diff --git a/enzyme/test/Enzyme/ReverseMode/asinh.ll b/enzyme/test/Enzyme/ReverseMode/asinh.ll new file mode 100644 index 000000000000..3c34fabc25bf --- /dev/null +++ b/enzyme/test/Enzyme/ReverseMode/asinh.ll @@ -0,0 +1,46 @@ +; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -S | FileCheck %s; fi +; RUN: %opt < %s %newLoadEnzyme -passes="enzyme" -enzyme-preopt=false -S | FileCheck %s + +; Function Attrs: nounwind readnone uwtable +define double @tester(double %x) { +entry: + %0 = tail call fast double @asinh(double %x) + ret double %0 +} + +define double @test_derivative(double %x) { +entry: + %0 = tail call double (double (double)*, ...) @__enzyme_autodiff(double (double)* nonnull @tester, double %x) + ret double %0 +} + +; Function Attrs: nounwind readnone speculatable +declare double @asinh(double) + +; Function Attrs: nounwind +declare double @__enzyme_autodiff(double (double)*, ...) + +; CHECK: define internal { double } @diffetester(double %x, double %differeturn) +; CHECK-NEXT: entry: +; CHECK-NEXT: %"'de" = alloca double, align 8 +; CHECK-NEXT: store double 0.000000e+00, double* %"'de", align 8 +; CHECK-NEXT: %"x'de" = alloca double, align 8 +; CHECK-NEXT: store double 0.000000e+00, double* %"x'de", align 8 +; CHECK-NEXT: br label %invertentry + +; CHECK: invertentry: +; CHECK-NEXT: store double %differeturn, double* %"'de", align 8 +; CHECK-NEXT: %0 = load double, double* %"'de", align 8 +; CHECK-NEXT: store double 0.000000e+00, double* %"'de", align 8 +; CHECK-NEXT: %1 = fmul fast double %x, %x +; CHECK-NEXT: %2 = fadd fast double %1, 1.000000e+00 +; CHECK-NEXT: %3 = call fast double @llvm.sqrt.f64(double %2) +; CHECK-NEXT: %4 = fdiv fast double %0, %3 +; CHECK-NEXT: %5 = load double, double* %"x'de", align 8 +; CHECK-NEXT: %6 = fadd fast double %5, %4 +; CHECK-NEXT: store double %6, double* %"x'de", align 8 +; CHECK-NEXT: %7 = load double, double* %"x'de", align 8 +; CHECK-NEXT: %8 = insertvalue { double } undef, double %7, 0 +; CHECK-NEXT: ret { double } %8 +; CHECK-NEXT: } + diff --git a/enzyme/test/Enzyme/ReverseMode/asinhf.ll b/enzyme/test/Enzyme/ReverseMode/asinhf.ll new file mode 100644 index 000000000000..c113c111abc1 --- /dev/null +++ b/enzyme/test/Enzyme/ReverseMode/asinhf.ll @@ -0,0 +1,46 @@ +; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -S | FileCheck %s; fi +; RUN: %opt < %s %newLoadEnzyme -passes="enzyme" -enzyme-preopt=false -S | FileCheck %s + +; Function Attrs: nounwind readnone uwtable +define float @tester(float %x) { +entry: + %0 = tail call fast float @asinhf(float %x) + ret float %0 +} + +define float @test_derivative(float %x) { +entry: + %0 = tail call float (float (float)*, ...) @__enzyme_autodiff(float (float)* nonnull @tester, float %x) + ret float %0 +} + +; Function Attrs: nounwind readnone speculatable +declare float @asinhf(float) + +; Function Attrs: nounwind +declare float @__enzyme_autodiff(float (float)*, ...) + +; CHECK: define internal { float } @diffetester(float %x, float %differeturn) +; CHECK-NEXT: entry: +; CHECK-NEXT: %"'de" = alloca float, align 4 +; CHECK-NEXT: store float 0.000000e+00, float* %"'de", align 4 +; CHECK-NEXT: %"x'de" = alloca float, align 4 +; CHECK-NEXT: store float 0.000000e+00, float* %"x'de", align 4 +; CHECK-NEXT: br label %invertentry + +; CHECK: invertentry: +; CHECK-NEXT: store float %differeturn, float* %"'de", align 4 +; CHECK-NEXT: %0 = load float, float* %"'de", align 4 +; CHECK-NEXT: store float 0.000000e+00, float* %"'de", align 4 +; CHECK-NEXT: %1 = fmul fast float %x, %x +; CHECK-NEXT: %2 = fadd fast float %1, 1.000000e+00 +; CHECK-NEXT: %3 = call fast float @llvm.sqrt.f32(float %2) +; CHECK-NEXT: %4 = fdiv fast float %0, %3 +; CHECK-NEXT: %5 = load float, float* %"x'de", align 4 +; CHECK-NEXT: %6 = fadd fast float %5, %4 +; CHECK-NEXT: store float %6, float* %"x'de", align 4 +; CHECK-NEXT: %7 = load float, float* %"x'de", align 4 +; CHECK-NEXT: %8 = insertvalue { float } undef, float %7, 0 +; CHECK-NEXT: ret { float } %8 +; CHECK-NEXT: } + diff --git a/enzyme/test/Enzyme/ReverseMode/asinhl.ll b/enzyme/test/Enzyme/ReverseMode/asinhl.ll new file mode 100644 index 000000000000..759d05bf3236 --- /dev/null +++ b/enzyme/test/Enzyme/ReverseMode/asinhl.ll @@ -0,0 +1,46 @@ +; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -S | FileCheck %s; fi +; RUN: %opt < %s %newLoadEnzyme -passes="enzyme" -enzyme-preopt=false -S | FileCheck %s + +; Function Attrs: nounwind readnone uwtable +define x86_fp80 @tester(x86_fp80 %x) { +entry: + %0 = tail call fast x86_fp80 @asinhl(x86_fp80 %x) + ret x86_fp80 %0 +} + +define x86_fp80 @test_derivative(x86_fp80 %x) { +entry: + %0 = tail call x86_fp80 (x86_fp80 (x86_fp80)*, ...) @__enzyme_autodiff(x86_fp80 (x86_fp80)* nonnull @tester, x86_fp80 %x) + ret x86_fp80 %0 +} + +; Function Attrs: nounwind readnone speculatable +declare x86_fp80 @asinhl(x86_fp80) + +; Function Attrs: nounwind +declare x86_fp80 @__enzyme_autodiff(x86_fp80 (x86_fp80)*, ...) + +; CHECK: define internal { x86_fp80 } @diffetester(x86_fp80 %x, x86_fp80 %differeturn) +; CHECK-NEXT: entry: +; CHECK-NEXT: %"'de" = alloca x86_fp80, align 16 +; CHECK-NEXT: store x86_fp80 0xK00000000000000000000, x86_fp80* %"'de", align 16 +; CHECK-NEXT: %"x'de" = alloca x86_fp80, align 16 +; CHECK-NEXT: store x86_fp80 0xK00000000000000000000, x86_fp80* %"x'de", align 16 +; CHECK-NEXT: br label %invertentry + +; CHECK: invertentry: +; CHECK-NEXT: store x86_fp80 %differeturn, x86_fp80* %"'de", align 16 +; CHECK-NEXT: %0 = load x86_fp80, x86_fp80* %"'de", align 16 +; CHECK-NEXT: store x86_fp80 0xK00000000000000000000, x86_fp80* %"'de", align 16 +; CHECK-NEXT: %1 = fmul fast x86_fp80 %x, %x +; CHECK-NEXT: %2 = fadd fast x86_fp80 %1, 0xK3FFF8000000000000000 +; CHECK-NEXT: %3 = call fast x86_fp80 @llvm.sqrt.f80(x86_fp80 %2) +; CHECK-NEXT: %4 = fdiv fast x86_fp80 %0, %3 +; CHECK-NEXT: %5 = load x86_fp80, x86_fp80* %"x'de", align 16 +; CHECK-NEXT: %6 = fadd fast x86_fp80 %5, %4 +; CHECK-NEXT: store x86_fp80 %6, x86_fp80* %"x'de", align 16 +; CHECK-NEXT: %7 = load x86_fp80, x86_fp80* %"x'de", align 16 +; CHECK-NEXT: %8 = insertvalue { x86_fp80 } undef, x86_fp80 %7, 0 +; CHECK-NEXT: ret { x86_fp80 } %8 +; CHECK-NEXT: } + diff --git a/enzyme/test/Enzyme/ReverseMode/atanh.ll b/enzyme/test/Enzyme/ReverseMode/atanh.ll new file mode 100644 index 000000000000..708c4da03bf3 --- /dev/null +++ b/enzyme/test/Enzyme/ReverseMode/atanh.ll @@ -0,0 +1,45 @@ +; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -S | FileCheck %s; fi +; RUN: %opt < %s %newLoadEnzyme -passes="enzyme" -enzyme-preopt=false -S | FileCheck %s + +; Function Attrs: nounwind readnone uwtable +define double @tester(double %x) { +entry: + %0 = tail call fast double @atanh(double %x) + ret double %0 +} + +define double @test_derivative(double %x) { +entry: + %0 = tail call double (double (double)*, ...) @__enzyme_autodiff(double (double)* nonnull @tester, double %x) + ret double %0 +} + +; Function Attrs: nounwind readnone speculatable +declare double @atanh(double) + +; Function Attrs: nounwind +declare double @__enzyme_autodiff(double (double)*, ...) + +; CHECK: define internal { double } @diffetester(double %x, double %differeturn) +; CHECK-NEXT: entry: +; CHECK-NEXT: %"'de" = alloca double, align 8 +; CHECK-NEXT: store double 0.000000e+00, double* %"'de", align 8 +; CHECK-NEXT: %"x'de" = alloca double, align 8 +; CHECK-NEXT: store double 0.000000e+00, double* %"x'de", align 8 +; CHECK-NEXT: br label %invertentry + +; CHECK: invertentry: +; CHECK-NEXT: store double %differeturn, double* %"'de", align 8 +; CHECK-NEXT: %0 = load double, double* %"'de", align 8 +; CHECK-NEXT: store double 0.000000e+00, double* %"'de", align 8 +; CHECK-NEXT: %1 = fmul fast double %x, %x +; CHECK-NEXT: %2 = fsub fast double 1.000000e+00, %1 +; CHECK-NEXT: %3 = fdiv fast double %0, %2 +; CHECK-NEXT: %4 = load double, double* %"x'de", align 8 +; CHECK-NEXT: %5 = fadd fast double %4, %3 +; CHECK-NEXT: store double %5, double* %"x'de", align 8 +; CHECK-NEXT: %6 = load double, double* %"x'de", align 8 +; CHECK-NEXT: %7 = insertvalue { double } undef, double %6, 0 +; CHECK-NEXT: ret { double } %7 +; CHECK-NEXT: } + diff --git a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c.ll b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c.ll index e84c12e14701..633feb70203a 100644 --- a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c.ll +++ b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c.ll @@ -212,12 +212,12 @@ entry: ; CHECK-NEXT: %[[r65:.+]] = icmp eq i8 %loaded.trans4, 78 ; CHECK-NEXT: %[[r66:.+]] = icmp eq i8 %loaded.trans4, 110 ; CHECK-NEXT: %[[r67:.+]] = or i1 %[[r66]], %[[r65]] -; CHECK-NEXT: %[[r68:.+]] = select i1 %[[r67]], i8* %n_p, i8* %k_p +; CHECK-NEXT: %[[r68:.+]] = select i1 %[[r67]], i8* %k_p, i8* %n_p ; CHECK-NEXT: %loaded.trans5 = load i8, i8* %transb, align 1 ; CHECK-NEXT: %[[r69:.+]] = icmp eq i8 %loaded.trans5, 78 ; CHECK-NEXT: %[[r70:.+]] = icmp eq i8 %loaded.trans5, 110 ; CHECK-NEXT: %[[r71:.+]] = or i1 %[[r70]], %[[r69]] -; CHECK-NEXT: %[[r72:.+]] = select i1 %[[r71]], i8* %n_p, i8* %k_p +; CHECK-NEXT: %[[r72:.+]] = select i1 %[[r71]], i8* %k_p, i8* %n_p ; CHECK-NEXT: %ld.row.trans6 = load i8, i8* %transa, align 1 ; CHECK-NEXT: %[[r73:.+]] = icmp eq i8 %ld.row.trans6, 110 ; CHECK-NEXT: %[[r74:.+]] = icmp eq i8 %ld.row.trans6, 78 @@ -251,12 +251,12 @@ entry: ; CHECK-NEXT: %[[r87:.+]] = icmp eq i8 %[[loaded_trans10]], 78 ; CHECK-NEXT: %[[r88:.+]] = icmp eq i8 %[[loaded_trans10]], 110 ; CHECK-NEXT: %[[r89:.+]] = or i1 %[[r88]], %[[r87]] -; CHECK-NEXT: %[[r90:.+]] = select i1 %[[r89]], i8* %k_p, i8* %m_p +; CHECK-NEXT: %[[r90:.+]] = select i1 %[[r89]], i8* %m_p, i8* %k_p ; CHECK-NEXT: %[[loaded_trans11:.+]] = load i8, i8* %transa, align 1 ; CHECK-NEXT: %[[r91:.+]] = icmp eq i8 %[[loaded_trans11]], 78 ; CHECK-NEXT: %[[r92:.+]] = icmp eq i8 %[[loaded_trans11]], 110 ; CHECK-NEXT: %[[r93:.+]] = or i1 %[[r92]], %[[r91]] -; CHECK-NEXT: %[[r94:.+]] = select i1 %[[r93]], i8* %k_p, i8* %m_p +; CHECK-NEXT: %[[r94:.+]] = select i1 %[[r93]], i8* %m_p, i8* %k_p ; CHECK-NEXT: %[[ld_row_trans12:.+]] = load i8, i8* %transb, align 1 ; CHECK-NEXT: %[[r95:.+]] = icmp eq i8 %[[ld_row_trans12]], 110 ; CHECK-NEXT: %[[r96:.+]] = icmp eq i8 %[[ld_row_trans12]], 78 diff --git a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_lacpy_runtime_act.ll b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_lacpy_runtime_act.ll index 0918ca43f83f..78fc4959f21e 100644 --- a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_lacpy_runtime_act.ll +++ b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_lacpy_runtime_act.ll @@ -109,7 +109,7 @@ entry: ; CHECK-NEXT: %malloccall = tail call noalias nonnull i8* @malloc(i64 %mallocsize) ; CHECK-NEXT: %cache.A = bitcast i8* %malloccall to double* ; CHECK-NEXT: store i8 0, i8* %[[byrefgarbage]] -; CHECK-NEXT: call void @dlacpy_64_(i8* %byref.copy.garbage, i8* %[[i3]], i8* %[[i4]], i8* %A, i8* %lda_p, double* %cache.A, i8* %[[i4]]) +; CHECK-NEXT: call void @dlacpy_64_(i8* %byref.copy.garbage, i8* %[[i3]], i8* %[[i4]], i8* %A, i8* %lda_p, double* %cache.A, i8* %[[i3]]) ; CHECK-NEXT: %[[i10:.+]] = bitcast i8* %m_p to i64* ; CHECK-NEXT: %[[i11:.+]] = load i64, i64* %[[i10]] ; CHECK-NEXT: %[[i12:.+]] = bitcast i8* %n_p to i64* @@ -119,7 +119,7 @@ entry: ; CHECK-NEXT: %malloccall2 = tail call noalias nonnull i8* @malloc(i64 %mallocsize1) ; CHECK-NEXT: %cache.C = bitcast i8* %malloccall2 to double* ; CHECK-NEXT: store i8 0, i8* %byref.copy.garbage3 -; CHECK-NEXT: call void @dlacpy_64_(i8* %byref.copy.garbage3, i8* %m_p, i8* %n_p, i8* %C, i8* %ldc_p, double* %cache.C, i8* %n_p) +; CHECK-NEXT: call void @dlacpy_64_(i8* %byref.copy.garbage3, i8* %m_p, i8* %n_p, i8* %C, i8* %ldc_p, double* %cache.C, i8* %m_p) ; CHECK-NEXT: call void @dgemm_64_(i8* %transa, i8* %transb, i8* %m_p, i8* %n_p, i8* %k_p, i8* %alpha, i8* %A, i8* %lda_p, i8* %B, i8* %ldb_p, i8* %beta, i8* %C, i8* %ldc_p, i64 1, i64 1) ; CHECK-NEXT: %"ptr'ipc" = bitcast i8* %"A'" to double* ; CHECK-NEXT: %ptr = bitcast i8* %A to double* @@ -152,7 +152,7 @@ entry: ; CHECK-DAG: %[[i41:.+]] = icmp eq i8 %loaded.trans7, 78 ; CHECK-DAG: %[[i42:.+]] = icmp eq i8 %loaded.trans7, 110 ; CHECK-NEXT: %[[i43:.+]] = or i1 %[[i42]], %[[i41]] -; CHECK-NEXT: %[[i44:.+]] = select i1 %[[i43]], i8* %k_p, i8* %m_p +; CHECK-NEXT: %[[i44:.+]] = select i1 %[[i43]], i8* %m_p, i8* %k_p ; CHECK-NEXT: store double 0.000000e+00, double* %byref.constant.fp.0.0 ; CHECK-NEXT: %fpcast.constant.fp.0.0 = bitcast double* %byref.constant.fp.0.0 to i8* ; CHECK-NEXT: call void @dgemm_64_(i8* %transa, i8* %transb, i8* %m_p, i8* %n_p, i8* %k_p, i8* %fpcast.constant.fp.1.0, i8* %[[matA]], i8* %[[i44]], i8* %B, i8* %ldb_p, i8* %fpcast.constant.fp.0.0, i8* %[[i21]], i8* %m_p, i64 1, i64 1) @@ -268,12 +268,12 @@ entry: ; CHECK-NEXT: %[[r83:.+]] = icmp eq i8 %[[loaded_trans14]], 78 ; CHECK-NEXT: %[[r84:.+]] = icmp eq i8 %[[loaded_trans14]], 110 ; CHECK-NEXT: %[[r85:.+]] = or i1 %[[r84]], %[[r83]] -; CHECK-NEXT: %[[r86:.+]] = select i1 %[[r85]], i8* %k_p, i8* %m_p +; CHECK-NEXT: %[[r86:.+]] = select i1 %[[r85]], i8* %m_p, i8* %k_p ; CHECK-NEXT: %[[loaded_trans15:.+]] = load i8, i8* %transa, align 1 ; CHECK-NEXT: %[[r87:.+]] = icmp eq i8 %[[loaded_trans15]], 78 ; CHECK-NEXT: %[[r88:.+]] = icmp eq i8 %[[loaded_trans15]], 110 ; CHECK-NEXT: %[[r89:.+]] = or i1 %[[r88]], %[[r87]] -; CHECK-NEXT: %[[r90:.+]] = select i1 %[[r89]], i8* %k_p, i8* %m_p +; CHECK-NEXT: %[[r90:.+]] = select i1 %[[r89]], i8* %m_p, i8* %k_p ; CHECK-NEXT: %[[ld_row_trans16:.+]] = load i8, i8* %transb, align 1 ; CHECK-NEXT: %[[r91:.+]] = icmp eq i8 %[[ld_row_trans16]], 110 ; CHECK-NEXT: %[[r92:.+]] = icmp eq i8 %[[ld_row_trans16]], 78 diff --git a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_loop.ll b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_loop.ll index c66d4ff88d63..606ecb1e1c5b 100644 --- a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_loop.ll +++ b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_loop.ll @@ -271,12 +271,12 @@ entry: ; CHECK-NEXT: %[[r62:.+]] = icmp eq i8 %loaded.trans30, 78 ; CHECK-NEXT: %[[r63:.+]] = icmp eq i8 %loaded.trans30, 110 ; CHECK-NEXT: %[[r64:.+]] = or i1 %[[r63]], %[[r62]] -; CHECK-NEXT: %[[r65:.+]] = select i1 %[[r64]], i8* %cast.k, i8* %[[r37]] +; CHECK-NEXT: %[[r65:.+]] = select i1 %[[r64]], i8* %[[r37]], i8* %cast.k ; CHECK-NEXT: %loaded.trans31 = load i8, i8* %byref.transa, align 1 ; CHECK-NEXT: %[[r66:.+]] = icmp eq i8 %loaded.trans31, 78 ; CHECK-NEXT: %[[r67:.+]] = icmp eq i8 %loaded.trans31, 110 ; CHECK-NEXT: %[[r68:.+]] = or i1 %[[r67]], %[[r66]] -; CHECK-NEXT: %[[r69:.+]] = select i1 %[[r68]], i8* %cast.k, i8* %[[r37]] +; CHECK-NEXT: %[[r69:.+]] = select i1 %[[r68]], i8* %[[r37]], i8* %cast.k ; CHECK-NEXT: %ld.row.trans32 = load i8, i8* %byref.transb, align 1 ; CHECK-NEXT: %[[r70:.+]] = icmp eq i8 %ld.row.trans32, 110 ; CHECK-NEXT: %[[r71:.+]] = icmp eq i8 %ld.row.trans32, 78 diff --git a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_split.ll b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_split.ll index 0e30a9fdc7da..48589c0dd732 100644 --- a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_split.ll +++ b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_split.ll @@ -265,12 +265,12 @@ entry: ; CHECK-NEXT: %[[r55:.+]] = icmp eq i8 %loaded.trans, 78 ; CHECK-NEXT: %[[r56:.+]] = icmp eq i8 %loaded.trans, 110 ; CHECK-NEXT: %[[r57:.+]] = or i1 %[[r56]], %[[r55]] -; CHECK-NEXT: %[[r58:.+]] = select i1 %[[r57]], i8* %k_p, i8* %m_p +; CHECK-NEXT: %[[r58:.+]] = select i1 %[[r57]], i8* %m_p, i8* %k_p ; CHECK-NEXT: %[[loaded_trans5:.+]] = load i8, i8* %malloccall, align 1 ; CHECK-NEXT: %[[r59:.+]] = icmp eq i8 %[[loaded_trans5]], 78 ; CHECK-NEXT: %[[r60:.+]] = icmp eq i8 %[[loaded_trans5]], 110 ; CHECK-NEXT: %[[r61:.+]] = or i1 %[[r60]], %[[r59]] -; CHECK-NEXT: %[[r62:.+]] = select i1 %[[r61]], i8* %k_p, i8* %m_p +; CHECK-NEXT: %[[r62:.+]] = select i1 %[[r61]], i8* %m_p, i8* %k_p ; CHECK-NEXT: %[[ld_row_trans6:.+]] = load i8, i8* %malloccall1, align 1 ; CHECK-NEXT: %[[r63:.+]] = icmp eq i8 %[[ld_row_trans6]], 110 ; CHECK-NEXT: %[[r64:.+]] = icmp eq i8 %[[ld_row_trans6]], 78 diff --git a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_split_lacpy.ll b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_split_lacpy.ll index f58a4b2e209a..132e5c14ba4d 100644 --- a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_split_lacpy.ll +++ b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_split_lacpy.ll @@ -117,7 +117,7 @@ entry: ; CHECK-NEXT: %cache.A = bitcast i8* %malloccall10 to double* ; CHECK-NEXT: store double* %cache.A, double** %0 ; CHECK-NEXT: store i8 0, i8* %byref.copy.garbage -; CHECK-NEXT: call void @dlacpy_64_(i8* %byref.copy.garbage, i8* %20, i8* %21, i8* %A, i8* %lda_p, double* %cache.A, i8* %21) +; CHECK-NEXT: call void @dlacpy_64_(i8* %byref.copy.garbage, i8* %20, i8* %21, i8* %A, i8* %lda_p, double* %cache.A, i8* %20) ; CHECK-NEXT: call void @dgemm_64_(i8* %malloccall, i8* %malloccall1, i8* %m_p, i8* %n_p, i8* %k_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %C, i8* %ldc_p, i64 1, i64 1) ; CHECK-NEXT: %[[ret:.+]] = load double*, double** %0 ; CHECK-NEXT: ret double* %[[ret]] @@ -239,12 +239,12 @@ entry: ; CHECK-NEXT: %[[r55:.+]] = icmp eq i8 %loaded.trans, 78 ; CHECK-NEXT: %[[r56:.+]] = icmp eq i8 %loaded.trans, 110 ; CHECK-NEXT: %[[r57:.+]] = or i1 %[[r56]], %[[r55]] -; CHECK-NEXT: %[[r58:.+]] = select i1 %[[r57]], i8* %k_p, i8* %m_p +; CHECK-NEXT: %[[r58:.+]] = select i1 %[[r57]], i8* %m_p, i8* %k_p ; CHECK-NEXT: %[[trans5:.+]] = load i8, i8* %malloccall, align 1 ; CHECK-NEXT: %[[r59:.+]] = icmp eq i8 %[[trans5]], 78 ; CHECK-NEXT: %[[r60:.+]] = icmp eq i8 %[[trans5]], 110 ; CHECK-NEXT: %[[r61:.+]] = or i1 %[[r60]], %[[r59]] -; CHECK-NEXT: %[[r62:.+]] = select i1 %[[r61]], i8* %k_p, i8* %m_p +; CHECK-NEXT: %[[r62:.+]] = select i1 %[[r61]], i8* %m_p, i8* %k_p ; CHECK-NEXT: %[[trans6:.+]] = load i8, i8* %malloccall1, align 1 ; CHECK-NEXT: %[[r63:.+]] = icmp eq i8 %[[trans6]], 110 ; CHECK-NEXT: %[[r64:.+]] = icmp eq i8 %[[trans6]], 78 diff --git a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_split_transpose_lacpy.ll b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_split_transpose_lacpy.ll index 2a7bbed40c24..c3a424a956a3 100644 --- a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_split_transpose_lacpy.ll +++ b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_split_transpose_lacpy.ll @@ -117,7 +117,7 @@ entry: ; CHECK-NEXT: %cache.A = bitcast i8* %malloccall10 to double* ; CHECK-NEXT: store double* %cache.A, double** %0 ; CHECK-NEXT: store i8 0, i8* %byref.copy.garbage -; CHECK-NEXT: call void @dlacpy_64_(i8* %byref.copy.garbage, i8* %20, i8* %21, i8* %A, i8* %lda_p, double* %cache.A, i8* %21) +; CHECK-NEXT: call void @dlacpy_64_(i8* %byref.copy.garbage, i8* %20, i8* %21, i8* %A, i8* %lda_p, double* %cache.A, i8* %20) ; CHECK-NEXT: call void @dgemm_64_(i8* %malloccall, i8* %malloccall1, i8* %m_p, i8* %n_p, i8* %k_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %C, i8* %ldc_p, i64 1, i64 1) ; CHECK-NEXT: %[[ret:.+]] = load double*, double** %0 ; CHECK-NEXT: ret double* %[[ret]] @@ -237,12 +237,12 @@ entry: ; CHECK-NEXT: %[[r55:.+]] = icmp eq i8 %loaded.trans, 78 ; CHECK-NEXT: %[[r56:.+]] = icmp eq i8 %loaded.trans, 110 ; CHECK-NEXT: %[[r57:.+]] = or i1 %[[r56]], %[[r55]] -; CHECK-NEXT: %[[r58:.+]] = select i1 %[[r57]], i8* %k_p, i8* %m_p +; CHECK-NEXT: %[[r58:.+]] = select i1 %[[r57]], i8* %m_p, i8* %k_p ; CHECK-NEXT: %[[loaded_trans5:.+]] = load i8, i8* %malloccall, align 1 ; CHECK-NEXT: %[[r59:.+]] = icmp eq i8 %[[loaded_trans5]], 78 ; CHECK-NEXT: %[[r60:.+]] = icmp eq i8 %[[loaded_trans5]], 110 ; CHECK-NEXT: %[[r61:.+]] = or i1 %[[r60]], %[[r59]] -; CHECK-NEXT: %[[r62:.+]] = select i1 %[[r61]], i8* %k_p, i8* %m_p +; CHECK-NEXT: %[[r62:.+]] = select i1 %[[r61]], i8* %m_p, i8* %k_p ; CHECK-NEXT: %[[ld_row_trans6:.+]] = load i8, i8* %malloccall1, align 1 ; CHECK-NEXT: %[[r63:.+]] = icmp eq i8 %[[ld_row_trans6]], 110 ; CHECK-NEXT: %[[r64:.+]] = icmp eq i8 %[[ld_row_trans6]], 78 diff --git a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_transpose_lacpy.ll b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_transpose_lacpy.ll index 964b53b1b925..ffc9f795f395 100644 --- a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_transpose_lacpy.ll +++ b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_c_transpose_lacpy.ll @@ -105,7 +105,7 @@ entry: ; CHECK-NEXT: %malloccall = tail call noalias nonnull i8* @malloc(i64 %mallocsize) ; CHECK-NEXT: %cache.A = bitcast i8* %malloccall to double* ; CHECK-NEXT: store i8 0, i8* %[[byrefgarbage]] -; CHECK-NEXT: call void @dlacpy_64_(i8* %[[byrefgarbage]], i8* %[[i3]], i8* %[[i4]], i8* %A, i8* %lda_p, double* %cache.A, i8* %[[i4]]) +; CHECK-NEXT: call void @dlacpy_64_(i8* %[[byrefgarbage]], i8* %[[i3]], i8* %[[i4]], i8* %A, i8* %lda_p, double* %cache.A, i8* %[[i3]]) ; CHECK-NEXT: %loaded.trans1 = load i8, i8* %transb ; CHECK-DAG: %[[i10:.+]] = icmp eq i8 %loaded.trans1, 78 ; CHECK-DAG: %[[i11:.+]] = icmp eq i8 %loaded.trans1, 110 @@ -121,7 +121,7 @@ entry: ; CHECK-NEXT: %[[malloccall2:.+]] = tail call noalias nonnull i8* @malloc(i64 %[[mallocsize1]]) ; CHECK-NEXT: %cache.B = bitcast i8* %[[malloccall2]] to double* ; CHECK-NEXT: store i8 0, i8* %byref.copy.garbage4 -; CHECK-NEXT: call void @dlacpy_64_(i8* %[[byrefgarbage2]], i8* %[[i13]], i8* %[[i14]], i8* %B, i8* %ldb_p, double* %cache.B, i8* %[[i14]]) +; CHECK-NEXT: call void @dlacpy_64_(i8* %[[byrefgarbage2]], i8* %[[i13]], i8* %[[i14]], i8* %B, i8* %ldb_p, double* %cache.B, i8* %[[i13]]) ; CHECK-NEXT: call void @dgemm_64_(i8* %transa, i8* %transb, i8* %m_p, i8* %n_p, i8* %k_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %C, i8* %ldc_p, i64 1, i64 1) ; CHECK-NEXT: %"ptr'ipc" = bitcast i8* %"A'" to double* ; CHECK-NEXT: %ptr = bitcast i8* %A to double* @@ -156,12 +156,12 @@ entry: ; CHECK-DAG: %[[i40:.+]] = icmp eq i8 %loaded.trans5, 78 ; CHECK-DAG: %[[i41:.+]] = icmp eq i8 %loaded.trans5, 110 ; CHECK-NEXT: %[[i42:.+]] = or i1 %[[i41]], %[[i40]] -; CHECK-NEXT: %[[i43:.+]] = select i1 %[[i42]], i8* %n_p, i8* %k_p +; CHECK-NEXT: %[[i43:.+]] = select i1 %[[i42]], i8* %k_p, i8* %n_p ; CHECK-NEXT: %loaded.trans6 = load i8, i8* %transb, align 1 ; CHECK-NEXT: %[[a49:.+]] = icmp eq i8 %loaded.trans6, 78 ; CHECK-NEXT: %[[a50:.+]] = icmp eq i8 %loaded.trans6, 110 ; CHECK-NEXT: %[[a51:.+]] = or i1 %[[a50]], %[[a49]] -; CHECK-NEXT: %[[a52:.+]] = select i1 %[[a51]], i8* %n_p, i8* %k_p +; CHECK-NEXT: %[[a52:.+]] = select i1 %[[a51]], i8* %k_p, i8* %n_p ; CHECK-NEXT: %ld.row.trans7 = load i8, i8* %transa, align 1 ; CHECK-NEXT: %[[a53:.+]] = icmp eq i8 %ld.row.trans7, 110 ; CHECK-NEXT: %[[a54:.+]] = icmp eq i8 %ld.row.trans7, 78 @@ -195,12 +195,12 @@ entry: ; CHECK-DAG: %[[i54:.+]] = icmp eq i8 %[[cachedtrans2]], 78 ; CHECK-DAG: %[[i55:.+]] = icmp eq i8 %[[cachedtrans2]], 110 ; CHECK-NEXT: %[[i56:.+]] = or i1 %[[i55]], %[[i54]] -; CHECK-NEXT: %[[i57:.+]] = select i1 %[[i56]], i8* %k_p, i8* %m_p +; CHECK-NEXT: %[[i57:.+]] = select i1 %[[i56]], i8* %m_p, i8* %k_p ; CHECK-NEXT: %[[loaded_trans12:.+]] = load i8, i8* %transa, align 1 ; CHECK-NEXT: %[[a71:.+]] = icmp eq i8 %[[loaded_trans12]], 78 ; CHECK-NEXT: %[[a72:.+]] = icmp eq i8 %[[loaded_trans12]], 110 ; CHECK-NEXT: %[[a73:.+]] = or i1 %[[a72]], %[[a71]] -; CHECK-NEXT: %[[a74:.+]] = select i1 %[[a73]], i8* %k_p, i8* %m_p +; CHECK-NEXT: %[[a74:.+]] = select i1 %[[a73]], i8* %m_p, i8* %k_p ; CHECK-NEXT: %[[ld_row_trans13:.+]] = load i8, i8* %transb, align 1 ; CHECK-NEXT: %[[a75:.+]] = icmp eq i8 %[[ld_row_trans13]], 110 ; CHECK-NEXT: %[[a76:.+]] = icmp eq i8 %[[ld_row_trans13]], 78 diff --git a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_change_ld.ll b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_change_ld.ll index a4c00d61bbc4..0c50fa797b3b 100644 --- a/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_change_ld.ll +++ b/enzyme/test/Enzyme/ReverseMode/blas/gemm_f_change_ld.ll @@ -103,7 +103,7 @@ entry: ; CHECK-NEXT: %malloccall = tail call noalias nonnull i8* @malloc(i64 %mallocsize) ; CHECK-NEXT: %cache.B = bitcast i8* %malloccall to double* ; CHECK-NEXT: store i8 0, i8* %byref.copy.garbage -; CHECK-NEXT: call void @dlacpy_64_(i8* %byref.copy.garbage, i8* %[[z3]], i8* %[[z4]], i8* %B, i8* %ldb_p, double* %cache.B, i8* %[[z4]]) +; CHECK-NEXT: call void @dlacpy_64_(i8* %byref.copy.garbage, i8* %[[z3]], i8* %[[z4]], i8* %B, i8* %ldb_p, double* %cache.B, i8* %[[z3]]) ; CHECK-NEXT: call void @dgemm_64_(i8* %transa, i8* %transb, i8* %m_p, i8* %n_p, i8* %k_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %B, i8* %ldb_p, i8* %beta_p, i8* %C, i8* %ldc_p, i64 1, i64 1) ; CHECK-NEXT: %ptr = bitcast i8* %B to double* ; CHECK-NEXT: store double 0.000000e+00, double* %ptr, align 8 @@ -135,12 +135,12 @@ entry: ; CHECK-DAG: %[[r16:.+]] = icmp eq i8 %loaded.trans1, 78 ; CHECK-DAG: %[[r17:.+]] = icmp eq i8 %loaded.trans1, 110 ; CHECK-NEXT: %[[r18:.+]] = or i1 %[[r17]], %[[r16]] -; CHECK-NEXT: %[[r19:.+]] = select i1 %[[r18]], i8* %n_p, i8* %k_p +; CHECK-NEXT: %[[r19:.+]] = select i1 %[[r18]], i8* %k_p, i8* %n_p ; CHECK-NEXT: %loaded.trans2 = load i8, i8* %transb, align 1 ; CHECK-NEXT: %[[a38:.+]] = icmp eq i8 %loaded.trans2, 78 ; CHECK-NEXT: %[[a39:.+]] = icmp eq i8 %loaded.trans2, 110 ; CHECK-NEXT: %[[a40:.+]] = or i1 %[[a39]], %[[a38]] -; CHECK-NEXT: %[[a41:.+]] = select i1 %[[a40]], i8* %n_p, i8* %k_p +; CHECK-NEXT: %[[a41:.+]] = select i1 %[[a40]], i8* %k_p, i8* %n_p ; CHECK-NEXT: %ld.row.trans3 = load i8, i8* %transa, align 1 ; CHECK-NEXT: %[[a42:.+]] = icmp eq i8 %ld.row.trans3, 110 ; CHECK-NEXT: %[[a43:.+]] = icmp eq i8 %ld.row.trans3, 78 diff --git a/enzyme/test/Enzyme/ReverseMode/coshf.ll b/enzyme/test/Enzyme/ReverseMode/coshf.ll new file mode 100644 index 000000000000..b028239a4ecf --- /dev/null +++ b/enzyme/test/Enzyme/ReverseMode/coshf.ll @@ -0,0 +1,29 @@ +; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme-preopt=false -enzyme -mem2reg -simplifycfg -instsimplify -S | FileCheck %s; fi +; RUN: %opt < %s %newLoadEnzyme -enzyme-preopt=false -passes="enzyme,function(mem2reg,%simplifycfg,instsimplify)" -S | FileCheck %s + +; Function Attrs: nounwind readnone uwtable +define float @tester(float %x) { +entry: + %0 = tail call fast float @coshf(float %x) + ret float %0 +} + +define float @test_derivative(float %x) { +entry: + %0 = tail call float (float (float)*, ...) @__enzyme_autodiff(float (float)* nonnull @tester, float %x) + ret float %0 +} + +; Function Attrs: nounwind readnone speculatable +declare float @coshf(float) + +; Function Attrs: nounwind +declare float @__enzyme_autodiff(float (float)*, ...) + +; CHECK: define internal { float } @diffetester(float %x, float %differeturn) +; CHECK-NEXT: entry: +; CHECK-NEXT: %0 = call fast float @sinhf(float %x) +; CHECK-NEXT: %1 = fmul fast float %differeturn, %0 +; CHECK-NEXT: %2 = insertvalue { float } undef, float %1, 0 +; CHECK-NEXT: ret { float } %2 +; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ReverseMode/coshl.ll b/enzyme/test/Enzyme/ReverseMode/coshl.ll new file mode 100644 index 000000000000..dd8af143ff48 --- /dev/null +++ b/enzyme/test/Enzyme/ReverseMode/coshl.ll @@ -0,0 +1,29 @@ +; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme-preopt=false -enzyme -mem2reg -simplifycfg -instsimplify -S | FileCheck %s; fi +; RUN: %opt < %s %newLoadEnzyme -enzyme-preopt=false -passes="enzyme,function(mem2reg,%simplifycfg,instsimplify)" -S | FileCheck %s + +; Function Attrs: nounwind readnone uwtable +define x86_fp80 @tester(x86_fp80 %x) { +entry: + %0 = tail call fast x86_fp80 @coshl(x86_fp80 %x) + ret x86_fp80 %0 +} + +define x86_fp80 @test_derivative(x86_fp80 %x) { +entry: + %0 = tail call x86_fp80 (x86_fp80 (x86_fp80)*, ...) @__enzyme_autodiff(x86_fp80 (x86_fp80)* nonnull @tester, x86_fp80 %x) + ret x86_fp80 %0 +} + +; Function Attrs: nounwind readnone speculatable +declare x86_fp80 @coshl(x86_fp80) + +; Function Attrs: nounwind +declare x86_fp80 @__enzyme_autodiff(x86_fp80 (x86_fp80)*, ...) + +; CHECK: define internal { x86_fp80 } @diffetester(x86_fp80 %x, x86_fp80 %differeturn) +; CHECK-NEXT: entry: +; CHECK-NEXT: %0 = call fast x86_fp80 @sinhl(x86_fp80 %x) +; CHECK-NEXT: %1 = fmul fast x86_fp80 %differeturn, %0 +; CHECK-NEXT: %2 = insertvalue { x86_fp80 } undef, x86_fp80 %1, 0 +; CHECK-NEXT: ret { x86_fp80 } %2 +; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ReverseMode/custom-sret3.ll b/enzyme/test/Enzyme/ReverseMode/custom-sret3.ll index 195470291e2a..da82145bdbd8 100644 --- a/enzyme/test/Enzyme/ReverseMode/custom-sret3.ll +++ b/enzyme/test/Enzyme/ReverseMode/custom-sret3.ll @@ -118,7 +118,7 @@ attributes #4 = { nounwind } !10 = !{!11, !11, i64 0} !11 = !{!"int", !5, i64 0} -; CHECK: define internal void @fixbyval_myblas_cdot_rev(%struct.complex* %arg0, %struct.complex* %arg1, %struct.complex* %arg2, %struct.complex* %arg3, i32 %arg4, i32 %arg5, %struct.complex %arg6, i8* %arg7) +; CHECK: define dso_local void @fixbyval_myblas_cdot_rev(%struct.complex* %arg0, %struct.complex* %arg1, %struct.complex* %arg2, %struct.complex* %arg3, i32 %arg4, i32 %arg5, %struct.complex %arg6, i8* %arg7) ; CHECK-NEXT: entry: ; CHECK-NEXT: %0 = alloca %struct.complex ; CHECK-NEXT: store %struct.complex %arg6, %struct.complex* %0 diff --git a/enzyme/test/Enzyme/ReverseMode/fdim.ll b/enzyme/test/Enzyme/ReverseMode/fdim.ll new file mode 100644 index 000000000000..113532464547 --- /dev/null +++ b/enzyme/test/Enzyme/ReverseMode/fdim.ll @@ -0,0 +1,54 @@ +; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -S | FileCheck %s; fi +; RUN: if [ %llvmver -ge 12 ]; then %opt < %s %newLoadEnzyme -passes="enzyme" -enzyme-preopt=false -S | FileCheck %s ; fi + +declare double @fdim(double, double) + +define double @tester(double %x, double %y) { +entry: + %0 = call double @fdim(double %x, double %y) + ret double %0 +} + +define double @test_derivative(double %x, double %y) { +entry: + %0 = tail call double (double (double, double)*, ...) @__enzyme_autodiff(double (double, double)* nonnull @tester, double %x, double %y) + ret double %0 +} + +; Function Attrs: nounwind +declare double @__enzyme_autodiff(double (double, double)*, ...) + +; CHECK: define internal { double, double } @diffetester(double %x, double %y, double %differeturn) +; CHECK-NEXT: entry: +; CHECK-NEXT: %"'de" = alloca double, align 8 +; CHECK-NEXT: store double 0.000000e+00, double* %"'de", align 8 +; CHECK-NEXT: %"x'de" = alloca double, align 8 +; CHECK-NEXT: store double 0.000000e+00, double* %"x'de", align 8 +; CHECK-NEXT: %"y'de" = alloca double, align 8 +; CHECK-NEXT: store double 0.000000e+00, double* %"y'de", align 8 +; CHECK-NEXT: br label %invertentry + +; CHECK: invertentry: +; CHECK-NEXT: store double %differeturn, double* %"'de", align 8 +; CHECK-NEXT: %0 = load double, double* %"'de", align 8 +; CHECK-NEXT: store double 0.000000e+00, double* %"'de", align 8 +; CHECK-NEXT: %1 = fcmp fast olt double %x, %y +; CHECK-NEXT: %2 = select fast i1 %1, double 0.000000e+00, double %0 +; CHECK-NEXT: %3 = load double, double* %"x'de", align 8 +; CHECK-NEXT: %4 = fadd fast double %3, %0 +; CHECK-NEXT: %5 = select fast i1 %1, double %3, double %4 +; CHECK-NEXT: store double %5, double* %"x'de", align 8 +; CHECK-NEXT: %6 = fcmp fast olt double %x, %y +; CHECK-NEXT: %7 = fneg fast double %0 +; CHECK-NEXT: %8 = select fast i1 %6, double 0.000000e+00, double %7 +; CHECK-NEXT: %9 = load double, double* %"y'de", align 8 +; CHECK-NEXT: %10 = fadd fast double %9, %7 +; CHECK-NEXT: %11 = select fast i1 %6, double %9, double %10 +; CHECK-NEXT: store double %11, double* %"y'de", align 8 +; CHECK-NEXT: %12 = load double, double* %"x'de", align 8 +; CHECK-NEXT: %13 = load double, double* %"y'de", align 8 +; CHECK-NEXT: %14 = insertvalue { double, double } undef, double %12, 0 +; CHECK-NEXT: %15 = insertvalue { double, double } %14, double %13, 1 +; CHECK-NEXT: ret { double, double } %15 +; CHECK-NEXT: } + diff --git a/enzyme/test/Enzyme/ReverseMode/frexp.ll b/enzyme/test/Enzyme/ReverseMode/frexp.ll index 2e74d6914105..fe134bdab00b 100644 --- a/enzyme/test/Enzyme/ReverseMode/frexp.ll +++ b/enzyme/test/Enzyme/ReverseMode/frexp.ll @@ -1,10 +1,11 @@ ; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme-preopt=false -enzyme -mem2reg -simplifycfg -instsimplify -adce -S | FileCheck %s; fi ; RUN: %opt < %s %newLoadEnzyme -enzyme-preopt=false -passes="enzyme,function(mem2reg,%simplifycfg,instsimplify,adce)" -S | FileCheck %s -declare double @frexp(double, i32*) declare double @__enzyme_autodiff(i8*, ...) declare float @__enzyme_autodifff(i8*, ...) +declare x86_fp80 @__enzyme_autodiffl(i8*, ...) +declare double @frexp(double, i32*) define double @test(double %x) { entry: %exp = alloca i32, align 4 @@ -32,6 +33,20 @@ entry: ret float %call } +declare x86_fp80 @frexpl(x86_fp80, i32*) +define x86_fp80 @testl(x86_fp80 %x) { +entry: + %exp = alloca i32, align 4 + %call = call x86_fp80 @frexpl(x86_fp80 %x, i32* %exp) + ret x86_fp80 %call +} + +define x86_fp80 @dtestl(x86_fp80 %x) { +entry: + %call = call x86_fp80 (i8*, ...) @__enzyme_autodiffl(i8* bitcast (x86_fp80 (x86_fp80)* @testl to i8*), x86_fp80 %x) + ret x86_fp80 %call +} + ; CHECK: define internal { double } @diffetest(double %x, double %differeturn) ; CHECK-NEXT: entry: ; CHECK-NEXT: %0 = bitcast double %x to i64 @@ -53,3 +68,14 @@ entry: ; CHECK-NEXT: %5 = insertvalue { float } {{(undef|poison)}}, float %4, 0 ; CHECK-NEXT: ret { float } %5 ; CHECK-NEXT: } + +; CHECK: define internal { x86_fp80 } @diffetestl(x86_fp80 %x, x86_fp80 %differeturn) +; CHECK-NEXT: entry: +; CHECK-NEXT: %0 = bitcast x86_fp80 %x to i80 +; CHECK-NEXT: %1 = and i80 604453686435277732577280, %0 +; CHECK-NEXT: %2 = bitcast i80 %1 to x86_fp80 +; CHECK-NEXT: %3 = fmul fast x86_fp80 %2, 0xK40008000000000000000 +; CHECK-NEXT: %4 = fdiv fast x86_fp80 %differeturn, %3 +; CHECK-NEXT: %5 = insertvalue { x86_fp80 } {{(undef|poison)}}, x86_fp80 %4, 0 +; CHECK-NEXT: ret { x86_fp80 } %5 +; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ReverseMode/logabsgamma.ll b/enzyme/test/Enzyme/ReverseMode/logabsgamma.ll new file mode 100644 index 000000000000..7682abf13dd9 --- /dev/null +++ b/enzyme/test/Enzyme/ReverseMode/logabsgamma.ll @@ -0,0 +1,38 @@ +; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme-preopt=false -enzyme -mem2reg -instsimplify -simplifycfg -S | FileCheck %s; fi +; RUN: %opt < %s %newLoadEnzyme -enzyme-preopt=false -passes="enzyme,function(mem2reg,instsimplify,%simplifycfg)" -S | FileCheck %s + +; Function Attrs: nounwind readnone uwtable +define double @tester(double %x) { +entry: + %a = call { double, i64 } @logabsgamma(double %x) + %b = extractvalue { double, i64 } %a, 0 + ret double %b +} + +define double @test_derivative(double %x) { +entry: + %0 = tail call double (double (double)*, ...) @__enzyme_autodiff(double (double)* nonnull @tester, double %x) + ret double %0 +} + +declare { double, i64 } @logabsgamma(double) + +; Function Attrs: nounwind +declare double @__enzyme_autodiff(double (double)*, ...) + +; CHECK: define internal { double } @diffetester(double %x, double %differeturn) +; CHECK-NEXT: entry: +; CHECK-NEXT: %"a'de" = alloca { double, i64 }, align 8 +; CHECK-NEXT: store { double, i64 } zeroinitializer, { double, i64 }* %"a'de", align 8 +; CHECK-NEXT: %0 = getelementptr inbounds { double, i64 }, { double, i64 }* %"a'de", i32 0, i32 0 +; CHECK-NEXT: %1 = load double, double* %0, align 8 +; CHECK-NEXT: %2 = fadd fast double %1, %differeturn +; CHECK-NEXT: store double %2, double* %0, align 8 +; CHECK-NEXT: %3 = load { double, i64 }, { double, i64 }* %"a'de", align 8 +; CHECK-NEXT: store { double, i64 } zeroinitializer, { double, i64 }* %"a'de", align 8 +; CHECK-NEXT: %4 = call fast double @digamma(double %x) +; CHECK-NEXT: %5 = extractvalue { double, i64 } %3, 0 +; CHECK-NEXT: %6 = fmul fast double %4, %5 +; CHECK-NEXT: %7 = insertvalue { double } undef, double %6, 0 +; CHECK-NEXT: ret { double } %7 +; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ReverseMode/mlirmincut.ll b/enzyme/test/Enzyme/ReverseMode/mlirmincut.ll new file mode 100644 index 000000000000..dc0e9599fe22 --- /dev/null +++ b/enzyme/test/Enzyme/ReverseMode/mlirmincut.ll @@ -0,0 +1,107 @@ +; RUN: if [ %llvmver -ge 15 ]; then %opt < %s %OPnewLoadEnzyme -enzyme-preopt=false -passes="enzyme,function(mem2reg,instsimplify,%simplifycfg)" -S | FileCheck %s; fi + +declare void @__enzyme_autodiff0(...) local_unnamed_addr + +declare void @_mlir_memref_to_llvm_free(ptr) + +declare ptr @_mlir_memref_to_llvm_alloc(i64) + +define void @jit_compiled(ptr %a) { + tail call void (...) @__enzyme_autodiff0(ptr nonnull @f, metadata !"enzyme_const", ptr %a, ptr %a, ptr %a, i64 0, metadata !"enzyme_const", ptr %a, metadata !"enzyme_const", ptr %a, i64 0, metadata !"enzyme_const", i64 1, metadata !"enzyme_const", ptr %a, metadata !"enzyme_dupnoneed", ptr %a, ptr %a, i64 0) + ret void +} + +define void @f(ptr %arg, ptr %arg1, i64 %arg2, ptr %arg3, ptr %arg4, i64 %arg5, i64 %arg6, ptr nocapture readnone %arg7, ptr nocapture writeonly %arg8, i64 %arg9) { + %.idx = shl i64 %arg6, 3 + %i = tail call ptr @_mlir_memref_to_llvm_alloc(i64 %.idx) + %i10 = load double, ptr %arg4, align 8 + %i11 = fcmp ogt double %i10, 1.500000e+00 + %i12 = load double, ptr %arg1, align 8 + %i13 = fmul double %i12, 2.000000e+00 + %storemerge = select i1 %i11, double %i13, double %i12 + store double %storemerge, ptr %i, align 8 + %i14 = alloca { ptr, ptr, i64 }, align 8 + store ptr %arg, ptr %i14, align 8 + %.repack2 = getelementptr inbounds { ptr, ptr, i64 }, ptr %i14, i64 0, i32 1 + store ptr %arg1, ptr %.repack2, align 8 + %.repack4 = getelementptr inbounds { ptr, ptr, i64 }, ptr %i14, i64 0, i32 2 + store i64 %arg2, ptr %.repack4, align 8 + %i15 = alloca { ptr, ptr, i64 }, align 8 + store ptr %arg3, ptr %i15, align 8 + %.repack6 = getelementptr inbounds { ptr, ptr, i64 }, ptr %i15, i64 0, i32 1 + store ptr %arg4, ptr %.repack6, align 8 + %.repack8 = getelementptr inbounds { ptr, ptr, i64 }, ptr %i15, i64 0, i32 2 + store i64 %arg5, ptr %.repack8, align 8 + %i16 = alloca { ptr, ptr, i64, [1 x i64], [1 x i64] }, align 8 + store ptr %i, ptr %i16, align 8 + %.repack10 = getelementptr inbounds { ptr, ptr, i64, [1 x i64], [1 x i64] }, ptr %i16, i64 0, i32 1 + store ptr %i, ptr %.repack10, align 8 + %.repack12 = getelementptr inbounds { ptr, ptr, i64, [1 x i64], [1 x i64] }, ptr %i16, i64 0, i32 2 + store i64 0, ptr %.repack12, align 8 + %.repack14 = getelementptr inbounds { ptr, ptr, i64, [1 x i64], [1 x i64] }, ptr %i16, i64 0, i32 3 + store i64 %arg6, ptr %.repack14, align 8 + %.repack16 = getelementptr inbounds { ptr, ptr, i64, [1 x i64], [1 x i64] }, ptr %i16, i64 0, i32 4 + store i64 1, ptr %.repack16, align 8 + %i17 = tail call ptr @_mlir_memref_to_llvm_alloc(i64 8) + %i18 = alloca { ptr, ptr, i64 }, align 8 + store ptr %i17, ptr %i18, align 8 + %.repack18 = getelementptr inbounds { ptr, ptr, i64 }, ptr %i18, i64 0, i32 1 + store ptr %i17, ptr %.repack18, align 8 + %.repack20 = getelementptr inbounds { ptr, ptr, i64 }, ptr %i18, i64 0, i32 2 + store i64 0, ptr %.repack20, align 8 + %i19 = load double, ptr %i17, align 8 + store double %i19, ptr %arg8, align 8 + ret void +} + +; CHECK: define internal void @diffef(ptr %arg, ptr %arg1, ptr %"arg1'", i64 %arg2, ptr %arg3, ptr %arg4, i64 %arg5, i64 %arg6, ptr nocapture readnone %arg7, ptr nocapture writeonly %arg8, ptr nocapture %"arg8'", i64 %arg9) +; CHECK-NEXT: invert: +; CHECK-NEXT: %.idx = shl i64 %arg6, 3 +; CHECK-NEXT: %i = tail call ptr @_mlir_memref_to_llvm_alloc(i64 %.idx) +; CHECK-NEXT: %i10 = load double, ptr %arg4, align 8 +; CHECK-NEXT: %i11 = fcmp ogt double %i10, 1.500000e+00 +; CHECK-NEXT: %i12 = load double, ptr %arg1, align 8 +; CHECK-NEXT: %i13 = fmul double %i12, 2.000000e+00 +; CHECK-NEXT: %storemerge = select i1 %i11, double %i13, double %i12 +; CHECK-NEXT: store double %storemerge, ptr %i, align 8 +; CHECK-NEXT: %i14 = alloca { ptr, ptr, i64 }, align 8 +; CHECK-NEXT: store ptr %arg, ptr %i14, align 8 +; CHECK-NEXT: %.repack2 = getelementptr inbounds { ptr, ptr, i64 }, ptr %i14, i64 0, i32 1 +; CHECK-NEXT: store ptr %arg1, ptr %.repack2, align 8 +; CHECK-NEXT: %.repack4 = getelementptr inbounds { ptr, ptr, i64 }, ptr %i14, i64 0, i32 2 +; CHECK-NEXT: store i64 %arg2, ptr %.repack4, align 8 +; CHECK-NEXT: %i15 = alloca { ptr, ptr, i64 }, align 8 +; CHECK-NEXT: store ptr %arg3, ptr %i15, align 8 +; CHECK-NEXT: %.repack6 = getelementptr inbounds { ptr, ptr, i64 }, ptr %i15, i64 0, i32 1 +; CHECK-NEXT: store ptr %arg4, ptr %.repack6, align 8 +; CHECK-NEXT: %.repack8 = getelementptr inbounds { ptr, ptr, i64 }, ptr %i15, i64 0, i32 2 +; CHECK-NEXT: store i64 %arg5, ptr %.repack8, align 8 +; CHECK-NEXT: %i16 = alloca { ptr, ptr, i64, [1 x i64], [1 x i64] }, align 8 +; CHECK-NEXT: store ptr %i, ptr %i16, align 8 +; CHECK-NEXT: %.repack10 = getelementptr inbounds { ptr, ptr, i64, [1 x i64], [1 x i64] }, ptr %i16, i64 0, i32 1 +; CHECK-NEXT: store ptr %i, ptr %.repack10, align 8 +; CHECK-NEXT: %.repack12 = getelementptr inbounds { ptr, ptr, i64, [1 x i64], [1 x i64] }, ptr %i16, i64 0, i32 2 +; CHECK-NEXT: store i64 0, ptr %.repack12, align 8 +; CHECK-NEXT: %.repack14 = getelementptr inbounds { ptr, ptr, i64, [1 x i64], [1 x i64] }, ptr %i16, i64 0, i32 3 +; CHECK-NEXT: store i64 %arg6, ptr %.repack14, align 8 +; CHECK-NEXT: %.repack16 = getelementptr inbounds { ptr, ptr, i64, [1 x i64], [1 x i64] }, ptr %i16, i64 0, i32 4 +; CHECK-NEXT: store i64 1, ptr %.repack16, align 8 +; CHECK-NEXT: %"i17'mi" = tail call noalias nonnull ptr @_mlir_memref_to_llvm_alloc(i64 8) +; CHECK-NEXT: call void @llvm.memset.p0.i64(ptr nonnull dereferenceable(8) dereferenceable_or_null(8) %"i17'mi", i8 0, i64 8, i1 false) +; CHECK-NEXT: %i17 = tail call ptr @_mlir_memref_to_llvm_alloc(i64 8) +; CHECK-NEXT: %i18 = alloca { ptr, ptr, i64 }, align 8 +; CHECK-NEXT: store ptr %i17, ptr %i18, align 8 +; CHECK-NEXT: %.repack18 = getelementptr inbounds { ptr, ptr, i64 }, ptr %i18, i64 0, i32 1 +; CHECK-NEXT: store ptr %i17, ptr %.repack18, align 8 +; CHECK-NEXT: %.repack20 = getelementptr inbounds { ptr, ptr, i64 }, ptr %i18, i64 0, i32 2 +; CHECK-NEXT: store i64 0, ptr %.repack20, align 8 +; CHECK-NEXT: %0 = load double, ptr %"arg8'", align 8 +; CHECK-NEXT: store double 0.000000e+00, ptr %"arg8'", align 8 +; CHECK-NEXT: %1 = load double, ptr %"i17'mi", align 8 +; CHECK-NEXT: %2 = fadd fast double %1, %0 +; CHECK-NEXT: store double %2, ptr %"i17'mi", align 8 +; CHECK-NEXT: call void @_mlir_memref_to_llvm_free(ptr nonnull %"i17'mi") +; CHECK-NEXT: call void @_mlir_memref_to_llvm_free(ptr %i17) +; CHECK-NEXT: call void @_mlir_memref_to_llvm_free(ptr %i) +; CHECK-NEXT: ret void +; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ReverseMode/modf.ll b/enzyme/test/Enzyme/ReverseMode/modf.ll new file mode 100644 index 000000000000..5c601a39e340 --- /dev/null +++ b/enzyme/test/Enzyme/ReverseMode/modf.ll @@ -0,0 +1,189 @@ +; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -S | FileCheck %s; fi +; RUN: %opt < %s %newLoadEnzyme -passes="enzyme" -enzyme-preopt=false -S | FileCheck %s + +declare double @__enzyme_autodiff(i8*, ...) +declare float @__enzyme_autodifff(i8*, ...) +declare x86_fp80 @__enzyme_autodiffl(i8*, ...) + +; double +declare double @modf(double, double*) +define double @testint(double %x) { +entry: + %integral_part = alloca double, align 8 + %fractional_part = call double @modf(double %x, double* %integral_part) + %ret = load double, double* %integral_part, align 8 + ret double %ret +} +define double @testfrac(double %x) { +entry: + %integral_part = alloca double, align 8 + %fractional_part = call double @modf(double %x, double* %integral_part) + ret double %fractional_part +} + +define double @dtestint(double %x, double %dx) { +entry: + %call = call double (i8*, ...) @__enzyme_autodiff(i8* bitcast (double (double)* @testint to i8*), double %x) + ret double %call +} +define double @dtestfrac(double %x, double %dx) { +entry: + %call = call double (i8*, ...) @__enzyme_autodiff(i8* bitcast (double (double)* @testfrac to i8*), double %x) + ret double %call +} + +; float +declare float @modff(float, float*) +define float @testintf(float %x) { +entry: + %integral_part = alloca float, align 4 + %fractional_part = call float @modff(float %x, float* %integral_part) + %ret = load float, float* %integral_part, align 4 + ret float %ret +} +define float @testfracf(float %x) { +entry: + %integral_part = alloca float, align 4 + %fractional_part = call float @modff(float %x, float* %integral_part) + ret float %fractional_part +} + +define float @dtestintf(float %x, float %dx) { +entry: + %call = call float (i8*, ...) @__enzyme_autodifff(i8* bitcast (float (float)* @testintf to i8*), float %x) + ret float %call +} +define float @dtestfracf(float %x, float %dx) { +entry: + %call = call float (i8*, ...) @__enzyme_autodifff(i8* bitcast (float (float)* @testfracf to i8*), float %x) + ret float %call +} + +; x86_fp80 +declare x86_fp80 @modfl(x86_fp80, x86_fp80*) +define x86_fp80 @testintl(x86_fp80 %x) { +entry: + %integral_part = alloca x86_fp80, align 8 + %fractional_part = call x86_fp80 @modfl(x86_fp80 %x, x86_fp80* %integral_part) + %ret = load x86_fp80, x86_fp80* %integral_part, align 8 + ret x86_fp80 %ret +} +define x86_fp80 @testfracl(x86_fp80 %x) { +entry: + %integral_part = alloca x86_fp80, align 8 + %fractional_part = call x86_fp80 @modfl(x86_fp80 %x, x86_fp80* %integral_part) + ret x86_fp80 %fractional_part +} + +define x86_fp80 @dtestintl(x86_fp80 %x, x86_fp80 %dx) { +entry: + %call = call x86_fp80 (i8*, ...) @__enzyme_autodiffl(i8* bitcast (x86_fp80 (x86_fp80)* @testintl to i8*), x86_fp80 %x) + ret x86_fp80 %call +} +define x86_fp80 @dtestfracl(x86_fp80 %x, x86_fp80 %dx) { +entry: + %call = call x86_fp80 (i8*, ...) @__enzyme_autodiffl(i8* bitcast (x86_fp80 (x86_fp80)* @testfracl to i8*), x86_fp80 %x) + ret x86_fp80 %call +} + +; double tests + +; CHECK: define internal { double } @diffetestint(double %x, double %differeturn) +; CHECK-NEXT: entry: +; CHECK-NEXT: %"x'de" = alloca double, align 8 +; CHECK-NEXT: store double 0.000000e+00, double* %"x'de", align 8 +; CHECK-NEXT: br label %invertentry + +; CHECK: invertentry: +; CHECK-NEXT: %0 = load double, double* %"x'de", align 8 +; CHECK-NEXT: %1 = insertvalue { double } undef, double %0, 0 +; CHECK-NEXT: ret { double } %1 +; CHECK-NEXT: } + +; CHECK: define internal { double } @diffetestfrac(double %x, double %differeturn) +; CHECK-NEXT: entry: +; CHECK-NEXT: %"fractional_part'de" = alloca double, align 8 +; CHECK-NEXT: store double 0.000000e+00, double* %"fractional_part'de", align 8 +; CHECK-NEXT: %"x'de" = alloca double, align 8 +; CHECK-NEXT: store double 0.000000e+00, double* %"x'de", align 8 +; CHECK-NEXT: br label %invertentry + +; CHECK: invertentry: +; CHECK-NEXT: store double %differeturn, double* %"fractional_part'de", align 8 +; CHECK-NEXT: %0 = load double, double* %"fractional_part'de", align 8 +; CHECK-NEXT: store double 0.000000e+00, double* %"fractional_part'de", align 8 +; CHECK-NEXT: %1 = load double, double* %"x'de", align 8 +; CHECK-NEXT: %2 = fadd fast double %1, %0 +; CHECK-NEXT: store double %2, double* %"x'de", align 8 +; CHECK-NEXT: %3 = load double, double* %"x'de", align 8 +; CHECK-NEXT: %4 = insertvalue { double } undef, double %3, 0 +; CHECK-NEXT: ret { double } %4 +; CHECK-NEXT: } + +; float tests + +; CHECK: define internal { float } @diffetestintf(float %x, float %differeturn) +; CHECK-NEXT: entry: +; CHECK-NEXT: %"x'de" = alloca float, align 4 +; CHECK-NEXT: store float 0.000000e+00, float* %"x'de", align 4 +; CHECK-NEXT: br label %invertentry + +; CHECK: invertentry: +; CHECK-NEXT: %0 = load float, float* %"x'de", align 4 +; CHECK-NEXT: %1 = insertvalue { float } undef, float %0, 0 +; CHECK-NEXT: ret { float } %1 +; CHECK-NEXT: } + +; CHECK: define internal { float } @diffetestfracf(float %x, float %differeturn) +; CHECK-NEXT: entry: +; CHECK-NEXT: %"fractional_part'de" = alloca float, align 4 +; CHECK-NEXT: store float 0.000000e+00, float* %"fractional_part'de", align 4 +; CHECK-NEXT: %"x'de" = alloca float, align 4 +; CHECK-NEXT: store float 0.000000e+00, float* %"x'de", align 4 +; CHECK-NEXT: br label %invertentry + +; CHECK: invertentry: +; CHECK-NEXT: store float %differeturn, float* %"fractional_part'de", align 4 +; CHECK-NEXT: %0 = load float, float* %"fractional_part'de", align 4 +; CHECK-NEXT: store float 0.000000e+00, float* %"fractional_part'de", align 4 +; CHECK-NEXT: %1 = load float, float* %"x'de", align 4 +; CHECK-NEXT: %2 = fadd fast float %1, %0 +; CHECK-NEXT: store float %2, float* %"x'de", align 4 +; CHECK-NEXT: %3 = load float, float* %"x'de", align 4 +; CHECK-NEXT: %4 = insertvalue { float } undef, float %3, 0 +; CHECK-NEXT: ret { float } %4 +; CHECK-NEXT: } + +; x86_fp80 tests + +; CHECK: define internal { x86_fp80 } @diffetestintl(x86_fp80 %x, x86_fp80 %differeturn) +; CHECK-NEXT: entry: +; CHECK-NEXT: %"x'de" = alloca x86_fp80, align 16 +; CHECK-NEXT: store x86_fp80 0xK00000000000000000000, x86_fp80* %"x'de", align 16 +; CHECK-NEXT: br label %invertentry + +; CHECK: invertentry: +; CHECK-NEXT: %0 = load x86_fp80, x86_fp80* %"x'de", align 16 +; CHECK-NEXT: %1 = insertvalue { x86_fp80 } undef, x86_fp80 %0, 0 +; CHECK-NEXT: ret { x86_fp80 } %1 +; CHECK-NEXT: } + +; CHECK: define internal { x86_fp80 } @diffetestfracl(x86_fp80 %x, x86_fp80 %differeturn) +; CHECK-NEXT: entry: +; CHECK-NEXT: %"fractional_part'de" = alloca x86_fp80, align 16 +; CHECK-NEXT: store x86_fp80 0xK00000000000000000000, x86_fp80* %"fractional_part'de", align 16 +; CHECK-NEXT: %"x'de" = alloca x86_fp80, align 16 +; CHECK-NEXT: store x86_fp80 0xK00000000000000000000, x86_fp80* %"x'de", align 16 +; CHECK-NEXT: br label %invertentry + +; CHECK: invertentry: +; CHECK-NEXT: store x86_fp80 %differeturn, x86_fp80* %"fractional_part'de", align 16 +; CHECK-NEXT: %0 = load x86_fp80, x86_fp80* %"fractional_part'de", align 16 +; CHECK-NEXT: store x86_fp80 0xK00000000000000000000, x86_fp80* %"fractional_part'de", align 16 +; CHECK-NEXT: %1 = load x86_fp80, x86_fp80* %"x'de", align 16 +; CHECK-NEXT: %2 = fadd fast x86_fp80 %1, %0 +; CHECK-NEXT: store x86_fp80 %2, x86_fp80* %"x'de", align 16 +; CHECK-NEXT: %3 = load x86_fp80, x86_fp80* %"x'de", align 16 +; CHECK-NEXT: %4 = insertvalue { x86_fp80 } undef, x86_fp80 %3, 0 +; CHECK-NEXT: ret { x86_fp80 } %4 +; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/Truncate/cmp.ll b/enzyme/test/Enzyme/Truncate/cmp.ll index e9c61ebd773e..68f0ef473a9b 100644 --- a/enzyme/test/Enzyme/Truncate/cmp.ll +++ b/enzyme/test/Enzyme/Truncate/cmp.ll @@ -6,22 +6,29 @@ define i1 @f(double %x, double %y) { ret i1 %res } -declare i1 (double, double)* @__enzyme_truncate_func(...) +declare i1 (double, double)* @__enzyme_truncate_mem_func(...) +declare i1 (double, double)* @__enzyme_truncate_op_func(...) define i1 @tester(double %x, double %y) { entry: - %ptr = call i1 (double, double)* (...) @__enzyme_truncate_func(i1 (double, double)* @f, i64 64, i64 32) + %ptr = call i1 (double, double)* (...) @__enzyme_truncate_mem_func(i1 (double, double)* @f, i64 64, i64 32) + %res = call i1 %ptr(double %x, double %y) + ret i1 %res +} +define i1 @tester_op(double %x, double %y) { +entry: + %ptr = call i1 (double, double)* (...) @__enzyme_truncate_op_func(i1 (double, double)* @f, i64 64, i64 32) + %res = call i1 %ptr(double %x, double %y) + ret i1 %res +} +define i1 @tester_op_mpfr(double %x, double %y) { +entry: + %ptr = call i1 (double, double)* (...) @__enzyme_truncate_op_func(i1 (double, double)* @f, i64 64, i64 3, i64 7) %res = call i1 %ptr(double %x, double %y) ret i1 %res } -; CHECK: define i1 @tester(double %x, double %y) { -; CHECK-NEXT: entry: -; CHECK-NEXT: %res = call i1 @trunc_64_32f(double %x, double %y) -; CHECK-NEXT: ret i1 %res -; CHECK-NEXT: } - -; CHECK: define internal i1 @trunc_64_32f(double %x, double %y) { +; CHECK: define internal i1 @__enzyme_done_truncate_mem_func_64_52to32_23_f(double %x, double %y) { ; CHECK-DAG: %1 = alloca double, align 8 ; CHECK-DAG: store double %x, double* %1, align 8 ; CHECK-DAG: %2 = bitcast double* %1 to float* @@ -31,4 +38,9 @@ entry: ; CHECK-DAG: %5 = load float, float* %4, align 4 ; CHECK-DAG: %res = fcmp olt float %3, %5 ; CHECK-DAG: ret i1 %res -; CHECK-NEXT:} + +; CHECK: define internal i1 @__enzyme_done_truncate_op_func_64_52to32_23_f(double %x, double %y) { +; CHECK-DAG: %enzyme_trunc = fptrunc double %x to float +; CHECK-DAG: %enzyme_trunc1 = fptrunc double %y to float +; CHECK-DAG: %res = fcmp olt float %enzyme_trunc, %enzyme_trunc1 +; CHECK-DAG: ret i1 %res diff --git a/enzyme/test/Enzyme/Truncate/intrinsic.ll b/enzyme/test/Enzyme/Truncate/intrinsic.ll index da4457492ce2..2299c9fb1ab3 100644 --- a/enzyme/test/Enzyme/Truncate/intrinsic.ll +++ b/enzyme/test/Enzyme/Truncate/intrinsic.ll @@ -1,11 +1,13 @@ ; RUN: if [ %llvmver -gt 12 ]; then if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -S | FileCheck %s; fi; fi ; RUN: if [ %llvmver -gt 12 ]; then %opt < %s %newLoadEnzyme -passes="enzyme" -S | FileCheck %s; fi +declare double @pow(double %Val, double %Power) declare double @llvm.pow.f64(double %Val, double %Power) declare double @llvm.powi.f64.i16(double %Val, i16 %power) declare void @llvm.nvvm.barrier0() define double @f(double %x, double %y) { + %res0 = call double @pow(double %x, double %y) %res1 = call double @llvm.pow.f64(double %x, double %y) %res2 = call double @llvm.powi.f64.i16(double %x, i16 2) %res = fadd double %res1, %res2 @@ -13,50 +15,102 @@ define double @f(double %x, double %y) { ret double %res } -declare double (double, double)* @__enzyme_truncate_func(...) +declare double (double, double)* @__enzyme_truncate_mem_func(...) +declare double (double, double)* @__enzyme_truncate_op_func(...) define double @tester(double %x, double %y) { entry: - %ptr = call double (double, double)* (...) @__enzyme_truncate_func(double (double, double)* @f, i64 64, i64 32) + %ptr = call double (double, double)* (...) @__enzyme_truncate_mem_func(double (double, double)* @f, i64 64, i64 32) %res = call double %ptr(double %x, double %y) ret double %res } +define double @tester_op(double %x, double %y) { +entry: + %ptr = call double (double, double)* (...) @__enzyme_truncate_op_func(double (double, double)* @f, i64 64, i64 32) + %res = call double %ptr(double %x, double %y) + ret double %res +} +define double @tester_op_mpfr(double %x, double %y) { +entry: + %ptr = call double (double, double)* (...) @__enzyme_truncate_op_func(double (double, double)* @f, i64 64, i64 3, i64 7) + %res = call double %ptr(double %x, double %y) + ret double %res +} + +; CHECK: define internal double @__enzyme_done_truncate_mem_func_64_52to32_23_f(double %x, double %y) { +; CHECK-DAG: %1 = alloca double, align 8 +; CHECK-DAG: store double %x, double* %1, align 8 +; CHECK-DAG: %2 = bitcast double* %1 to float* +; CHECK-DAG: %3 = load float, float* %2, align 4 +; CHECK-DAG: store double %y, double* %1, align 8 +; CHECK-DAG: %4 = bitcast double* %1 to float* +; CHECK-DAG: %5 = load float, float* %4, align 4 +; CHECK-DAG: %res01 = call float @llvm.pow.f32(float %3, float %5) +; CHECK-DAG: %6 = bitcast double* %1 to i64* +; CHECK-DAG: store i64 0, i64* %6, align 4 +; CHECK-DAG: %7 = bitcast double* %1 to float* +; CHECK-DAG: store float %res01, float* %7, align 4 +; CHECK-DAG: %8 = load double, double* %1, align 8 +; CHECK-DAG: store double %x, double* %1, align 8 +; CHECK-DAG: %9 = bitcast double* %1 to float* +; CHECK-DAG: %10 = load float, float* %9, align 4 +; CHECK-DAG: store double %y, double* %1, align 8 +; CHECK-DAG: %11 = bitcast double* %1 to float* +; CHECK-DAG: %12 = load float, float* %11, align 4 +; CHECK-DAG: %res12 = call float @llvm.pow.f32(float %10, float %12) +; CHECK-DAG: %13 = bitcast double* %1 to i64* +; CHECK-DAG: store i64 0, i64* %13, align 4 +; CHECK-DAG: %14 = bitcast double* %1 to float* +; CHECK-DAG: store float %res12, float* %14, align 4 +; CHECK-DAG: %15 = load double, double* %1, align 8 +; CHECK-DAG: store double %x, double* %1, align 8 +; CHECK-DAG: %16 = bitcast double* %1 to float* +; CHECK-DAG: %17 = load float, float* %16, align 4 +; CHECK-DAG: %res23 = call float @llvm.powi.f32.i16(float %17, i16 2) +; CHECK-DAG: %18 = bitcast double* %1 to i64* +; CHECK-DAG: store i64 0, i64* %18, align 4 +; CHECK-DAG: %19 = bitcast double* %1 to float* +; CHECK-DAG: store float %res23, float* %19, align 4 +; CHECK-DAG: %20 = load double, double* %1, align 8 +; CHECK-DAG: store double %15, double* %1, align 8 +; CHECK-DAG: %21 = bitcast double* %1 to float* +; CHECK-DAG: %22 = load float, float* %21, align 4 +; CHECK-DAG: store double %20, double* %1, align 8 +; CHECK-DAG: %23 = bitcast double* %1 to float* +; CHECK-DAG: %24 = load float, float* %23, align 4 +; CHECK-DAG: %res = fadd float %22, %24 +; CHECK-DAG: %25 = bitcast double* %1 to i64* +; CHECK-DAG: store i64 0, i64* %25, align 4 +; CHECK-DAG: %26 = bitcast double* %1 to float* +; CHECK-DAG: store float %res, float* %26, align 4 +; CHECK-DAG: %27 = load double, double* %1, align 8 +; CHECK-DAG: call void @llvm.nvvm.barrier0() +; CHECK-DAG: ret double %27 + +; CHECK: define internal double @__enzyme_done_truncate_op_func_64_52to32_23_f(double %x, double %y) { +; CHECK-DAG: %enzyme_trunc = fptrunc double %x to float +; CHECK-DAG: %enzyme_trunc1 = fptrunc double %y to float +; CHECK-DAG: %res02 = call float @llvm.pow.f32(float %enzyme_trunc, float %enzyme_trunc1) +; CHECK-DAG: %enzyme_exp = fpext float %res02 to double +; CHECK-DAG: %enzyme_trunc3 = fptrunc double %x to float +; CHECK-DAG: %enzyme_trunc4 = fptrunc double %y to float +; CHECK-DAG: %res15 = call float @llvm.pow.f32(float %enzyme_trunc3, float %enzyme_trunc4) +; CHECK-DAG: %enzyme_exp6 = fpext float %res15 to double +; CHECK-DAG: %enzyme_trunc7 = fptrunc double %x to float +; CHECK-DAG: %res28 = call float @llvm.powi.f32.i16(float %enzyme_trunc7, i16 2) +; CHECK-DAG: %enzyme_exp9 = fpext float %res28 to double +; CHECK-DAG: %enzyme_trunc10 = fptrunc double %enzyme_exp6 to float +; CHECK-DAG: %enzyme_trunc11 = fptrunc double %enzyme_exp9 to float +; CHECK-DAG: %res = fadd float %enzyme_trunc10, %enzyme_trunc11 +; CHECK-DAG: %enzyme_exp12 = fpext float %res to double +; CHECK-DAG: call void @llvm.nvvm.barrier0() +; CHECK-DAG: ret double %enzyme_exp12 -; CHECK: define internal double @trunc_64_32f(double %x, double %y) { -; CHECK-NEXT: %1 = alloca double, align 8 -; CHECK-NEXT: store double %x, double* %1, align 8 -; CHECK-NEXT: %2 = bitcast double* %1 to float* -; CHECK-NEXT: %3 = load float, float* %2, align 4 -; CHECK-NEXT: store double %y, double* %1, align 8 -; CHECK-NEXT: %4 = bitcast double* %1 to float* -; CHECK-NEXT: %5 = load float, float* %4, align 4 -; CHECK-NEXT: %res11 = call float @llvm.pow.f32(float %3, float %5) -; CHECK-NEXT: %6 = bitcast double* %1 to i64* -; CHECK-NEXT: store i64 0, i64* %6, align 4 -; CHECK-NEXT: %7 = bitcast double* %1 to float* -; CHECK-NEXT: store float %res11, float* %7, align 4 -; CHECK-NEXT: %8 = load double, double* %1, align 8 -; CHECK-NEXT: store double %x, double* %1, align 8 -; CHECK-NEXT: %9 = bitcast double* %1 to float* -; CHECK-NEXT: %10 = load float, float* %9, align 4 -; CHECK-NEXT: %res22 = call float @llvm.powi.f32.i16(float %10, i16 2) -; CHECK-NEXT: %11 = bitcast double* %1 to i64* -; CHECK-NEXT: store i64 0, i64* %11, align 4 -; CHECK-NEXT: %12 = bitcast double* %1 to float* -; CHECK-NEXT: store float %res22, float* %12, align 4 -; CHECK-NEXT: %13 = load double, double* %1, align 8 -; CHECK-NEXT: store double %8, double* %1, align 8 -; CHECK-NEXT: %14 = bitcast double* %1 to float* -; CHECK-NEXT: %15 = load float, float* %14, align 4 -; CHECK-NEXT: store double %13, double* %1, align 8 -; CHECK-NEXT: %16 = bitcast double* %1 to float* -; CHECK-NEXT: %17 = load float, float* %16, align 4 -; CHECK-NEXT: %res = fadd float %15, %17 -; CHECK-NEXT: %18 = bitcast double* %1 to i64* -; CHECK-NEXT: store i64 0, i64* %18, align 4 -; CHECK-NEXT: %19 = bitcast double* %1 to float* -; CHECK-NEXT: store float %res, float* %19, align 4 -; CHECK-NEXT: %20 = load double, double* %1, align 8 -; CHECK-NEXT: call void @llvm.nvvm.barrier0() -; CHECK-NEXT: ret double %20 -; CHECK-NEXT: } +; CHECK: define internal double @__enzyme_done_truncate_op_func_64_52to11_7_f(double %x, double %y) { +; CHECK-DAG: %1 = call double @__enzyme_mpfr_64_52_func_pow(double %x, double %y, i64 3, i64 7) +; CHECK-DAG: %2 = call double @__enzyme_mpfr_64_52_intr_llvm_pow_f64(double %x, double %y, i64 3, i64 7) +; CHECK-DAG: %3 = call double @__enzyme_mpfr_64_52_intr_llvm_powi_f64_i16(double %x, i16 2, i64 3, i64 7) +; CHECK-DAG: %res = call double @__enzyme_mpfr_64_52_binop_fadd(double %2, double %3, i64 3, i64 7) +; CHECK-DAG: call void @llvm.nvvm.barrier0() +; CHECK-DAG: ret double %res +; CHECK-DAG: } diff --git a/enzyme/test/Enzyme/Truncate/select.ll b/enzyme/test/Enzyme/Truncate/select.ll index 58b4a58ef91b..afc41219fed8 100644 --- a/enzyme/test/Enzyme/Truncate/select.ll +++ b/enzyme/test/Enzyme/Truncate/select.ll @@ -6,22 +6,29 @@ define double @f(double %x, double %y, i1 %cond) { ret double %res } -declare double (double, double, i1)* @__enzyme_truncate_func(...) +declare double (double, double, i1)* @__enzyme_truncate_mem_func(...) +declare double (double, double, i1)* @__enzyme_truncate_op_func(...) define double @tester(double %x, double %y, i1 %cond) { entry: - %ptr = call double (double, double, i1)* (...) @__enzyme_truncate_func(double (double, double, i1)* @f, i64 64, i64 32) + %ptr = call double (double, double, i1)* (...) @__enzyme_truncate_mem_func(double (double, double, i1)* @f, i64 64, i64 32) + %res = call double %ptr(double %x, double %y, i1 %cond) + ret double %res +} + +define double @tester2(double %x, double %y, i1 %cond) { +entry: + %ptr = call double (double, double, i1)* (...) @__enzyme_truncate_op_func(double (double, double, i1)* @f, i64 64, i64 32) %res = call double %ptr(double %x, double %y, i1 %cond) ret double %res } ; CHECK: define double @tester(double %x, double %y, i1 %cond) { ; CHECK-NEXT: entry: -; CHECK-NEXT: %res = call double @trunc_64_32f(double %x, double %y, i1 %cond) +; CHECK-NEXT: %res = call double @__enzyme_done_truncate_mem_func_64_52to32_23_f(double %x, double %y, i1 %cond) ; CHECK-NEXT: ret double %res -; CHECK-NEXT: } -; CHECK: define internal double @trunc_64_32f(double %x, double %y, i1 %cond) { +; CHECK: define internal double @__enzyme_done_truncate_mem_func_64_52to32_23_f(double %x, double %y, i1 %cond) { ; CHECK-DAG: %1 = alloca double, align 8 ; CHECK-DAG: store double %x, double* %1, align 8 ; CHECK-DAG: %2 = bitcast double* %1 to float* @@ -36,4 +43,7 @@ entry: ; CHECK-DAG: store float %res, float* %7, align 4 ; CHECK-DAG: %8 = load double, double* %1, align 8 ; CHECK-DAG: ret double %8 -; CHECK-NEXT: } + +; CHECK: define internal double @__enzyme_done_truncate_op_func_64_52to32_23_f(double %x, double %y, i1 %cond) { +; CHECK-DAG: %res = select i1 %cond, double %x, double %y +; CHECK-DAG: ret double %res diff --git a/enzyme/test/Enzyme/Truncate/simple.ll b/enzyme/test/Enzyme/Truncate/simple.ll index 0f346a26f0d2..a57f33fcdfdb 100644 --- a/enzyme/test/Enzyme/Truncate/simple.ll +++ b/enzyme/test/Enzyme/Truncate/simple.ll @@ -8,22 +8,29 @@ define void @f(double* %x) { ret void } -declare void (double*)* @__enzyme_truncate_func(...) +declare void (double*)* @__enzyme_truncate_mem_func(...) +declare void (double*)* @__enzyme_truncate_op_func(...) define void @tester(double* %data) { entry: - %ptr = call void (double*)* (...) @__enzyme_truncate_func(void (double*)* @f, i64 64, i64 32) + %ptr = call void (double*)* (...) @__enzyme_truncate_mem_func(void (double*)* @f, i64 64, i64 32) + call void %ptr(double* %data) + ret void +} +define void @tester_op(double* %data) { +entry: + %ptr = call void (double*)* (...) @__enzyme_truncate_op_func(void (double*)* @f, i64 64, i64 32) + call void %ptr(double* %data) + ret void +} +define void @tester_op_mpfr(double* %data) { +entry: + %ptr = call void (double*)* (...) @__enzyme_truncate_op_func(void (double*)* @f, i64 64, i64 3, i64 7) call void %ptr(double* %data) ret void } -; CHECK: define void @tester(double* %data) -; CHECK-NEXT: entry: -; CHECK-NEXT: call void @trunc_64_32f(double* %data) -; CHECK-NEXT: ret void -; CHECK-NEXT: } - -; CHECK: define internal void @trunc_64_32f(double* %x) +; CHECK: define internal void @__enzyme_done_truncate_mem_func_64_52to32_23_f(double* %x) ; CHECK-DAG: %1 = alloca double, align 8 ; CHECK-DAG: %y = load double, double* %x, align 8 ; CHECK-DAG: store double %y, double* %1, align 8 @@ -40,3 +47,18 @@ entry: ; CHECK-DAG: %8 = load double, double* %1, align 8 ; CHECK-DAG: store double %8, double* %x, align 8 ; CHECK-DAG: ret void + +; CHECK: define internal void @__enzyme_done_truncate_op_func_64_52to32_23_f(double* %x) { +; CHECK-DAG: %y = load double, double* %x, align 8 +; CHECK-DAG: %enzyme_trunc = fptrunc double %y to float +; CHECK-DAG: %enzyme_trunc1 = fptrunc double %y to float +; CHECK-DAG: %m = fmul float %enzyme_trunc, %enzyme_trunc1 +; CHECK-DAG: %enzyme_exp = fpext float %m to double +; CHECK-DAG: store double %enzyme_exp, double* %x, align 8 +; CHECK-DAG: ret void + +; CHECK: define internal void @__enzyme_done_truncate_op_func_64_52to11_7_f(double* %x) { +; CHECK-DAG: %y = load double, double* %x, align 8 +; CHECK-DAG: %m = call double @__enzyme_mpfr_64_52_binop_fmul(double %y, double %y, i64 3, i64 7) +; CHECK-DAG: store double %m, double* %x, align 8 +; CHECK-DAG: ret void diff --git a/enzyme/test/Enzyme/Truncate/value.ll b/enzyme/test/Enzyme/Truncate/value.ll index 51f00401078d..9f87d00d2173 100644 --- a/enzyme/test/Enzyme/Truncate/value.ll +++ b/enzyme/test/Enzyme/Truncate/value.ll @@ -1,18 +1,18 @@ ; RUN: if [ %llvmver -gt 12 ]; then if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -S | FileCheck %s; fi; fi ; RUN: if [ %llvmver -gt 12 ]; then %opt < %s %newLoadEnzyme -passes="enzyme" -S | FileCheck %s; fi -declare double @__enzyme_truncate_value(double, i64, i64) -declare double @__enzyme_expand_value(double, i64, i64) +declare double @__enzyme_truncate_mem_value(double, i64, i64) +declare double @__enzyme_expand_mem_value(double, i64, i64) define double @expand_tester(double %a, double * %c) { entry: - %b = call double @__enzyme_expand_value(double %a, i64 64, i64 32) + %b = call double @__enzyme_expand_mem_value(double %a, i64 64, i64 32) ret double %b } define double @truncate_tester(double %a) { entry: - %b = call double @__enzyme_truncate_value(double %a, i64 64, i64 32) + %b = call double @__enzyme_truncate_mem_value(double %a, i64 64, i64 32) ret double %b } diff --git a/enzyme/test/Integration/BUILD b/enzyme/test/Integration/BUILD new file mode 100644 index 000000000000..ab14f8fca82e --- /dev/null +++ b/enzyme/test/Integration/BUILD @@ -0,0 +1,29 @@ +# Enzyme integration tests. + +load("@llvm-project//llvm:lit_test.bzl", "lit_test") + +[ + lit_test( + name = "%s.test" % src, + srcs = [src], + data = [ + "//:enzyme-clang", + "//:enzyme-clang++", + "//:enzyme-opt", + "//test:lit.cfg.py", + "//test:lit.site.cfg.py", + "@llvm-project//clang:builtin_headers_gen", + "@llvm-project//llvm:FileCheck", + "@llvm-project//llvm:count", + "@llvm-project//llvm:lli", + "@llvm-project//llvm:not", + ] + glob(["**/*.h"]), + ) + for src in glob( + [ + "**/*.c", + "**/.cpp", + ], + exclude = ["**/*omp*.c"], + ) +] diff --git a/enzyme/test/Integration/CMakeLists.txt b/enzyme/test/Integration/CMakeLists.txt index 7a14214b46cd..98171f188d86 100644 --- a/enzyme/test/Integration/CMakeLists.txt +++ b/enzyme/test/Integration/CMakeLists.txt @@ -3,6 +3,7 @@ add_subdirectory(ForwardModeVector) add_subdirectory(ReverseMode) add_subdirectory(BatchMode) add_subdirectory(Sparse) +add_subdirectory(Truncate) # Run regression and unit tests add_lit_testsuite(check-enzyme-integration "Running enzyme integration tests" diff --git a/enzyme/test/Integration/ForwardMode/loops.c b/enzyme/test/Integration/ForwardMode/loops.c index 612839248cc2..33bef07a2d16 100644 --- a/enzyme/test/Integration/ForwardMode/loops.c +++ b/enzyme/test/Integration/ForwardMode/loops.c @@ -7,10 +7,6 @@ // RUN: %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - -#include -#include -#include - #include "../test_utils.h" double __enzyme_fwddiff(void*, ...); diff --git a/enzyme/test/Integration/ForwardMode/loopsdouble.c b/enzyme/test/Integration/ForwardMode/loopsdouble.c index 8c34740c80d5..4faa0cad74f3 100644 --- a/enzyme/test/Integration/ForwardMode/loopsdouble.c +++ b/enzyme/test/Integration/ForwardMode/loopsdouble.c @@ -7,10 +7,6 @@ // RUN: %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - -#include -#include -#include - #include "../test_utils.h" double __enzyme_fwddiff(void*, ...); diff --git a/enzyme/test/Integration/ForwardMode/loopstriple.c b/enzyme/test/Integration/ForwardMode/loopstriple.c index 060e74d009f7..84e98ab1694e 100644 --- a/enzyme/test/Integration/ForwardMode/loopstriple.c +++ b/enzyme/test/Integration/ForwardMode/loopstriple.c @@ -7,10 +7,6 @@ // RUN: %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - -#include -#include -#include - #include "../test_utils.h" double __enzyme_fwddiff(void*, ...); diff --git a/enzyme/test/Integration/ForwardMode/rwrloop.c b/enzyme/test/Integration/ForwardMode/rwrloop.c index dc71b0eead20..32e548029f6d 100644 --- a/enzyme/test/Integration/ForwardMode/rwrloop.c +++ b/enzyme/test/Integration/ForwardMode/rwrloop.c @@ -7,10 +7,6 @@ // RUN: %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme --enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme --enzyme-inline=1 -S | %lli - -#include -#include -#include - #include "../test_utils.h" double __enzyme_fwddiff(void*, ...); diff --git a/enzyme/test/Integration/ForwardMode/sumtil.c b/enzyme/test/Integration/ForwardMode/sumtil.c index 9e9286f97157..5d6369df4909 100644 --- a/enzyme/test/Integration/ForwardMode/sumtil.c +++ b/enzyme/test/Integration/ForwardMode/sumtil.c @@ -7,10 +7,6 @@ // RUN: %clang -std=c11 -ffast-math -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -ffast-math -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - -#include -#include -#include - #include "../test_utils.h" extern double __enzyme_fwddiff(void*, double*, double*, int); diff --git a/enzyme/test/Integration/ForwardMode/sumtil2.c b/enzyme/test/Integration/ForwardMode/sumtil2.c index 32d289703e69..18428ca371d6 100644 --- a/enzyme/test/Integration/ForwardMode/sumtil2.c +++ b/enzyme/test/Integration/ForwardMode/sumtil2.c @@ -7,10 +7,6 @@ // RUN: %clang -std=c11 -ffast-math -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -ffast-math -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - -#include -#include -#include - #include "../test_utils.h" extern double __enzyme_fwddiff(void*, double*, double*, int); diff --git a/enzyme/test/Integration/ForwardModeVector/binops.c b/enzyme/test/Integration/ForwardModeVector/binops.c index 27c786a86980..a3b66f27ea31 100644 --- a/enzyme/test/Integration/ForwardModeVector/binops.c +++ b/enzyme/test/Integration/ForwardModeVector/binops.c @@ -7,10 +7,7 @@ // RUN: %clang -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - // RUN: %clang -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - -#include -#include -#include -#include +#include "../test_utils.h" /* #ifdef __cplusplus @@ -28,18 +25,6 @@ threshold) { if (fabs(f1-f2) > threshold) return false; return true; #endif */ -#define APPROX_EQ(LHS, RHS, THRES) \ - { \ - if (__builtin_fabs(LHS - RHS) > THRES) { \ - fprintf(stderr, \ - "Assertion Failed: fabs( [%s = %g] - [%s = %g] ) > %g at %s:%d " \ - "(%s)\n", \ - #LHS, LHS, #RHS, RHS, THRES, __FILE__, __LINE__, \ - __PRETTY_FUNCTION__); \ - abort(); \ - } \ - }; - typedef struct { double dx, dy; diff --git a/enzyme/test/Integration/ReverseMode/allocatedtape_err.c b/enzyme/test/Integration/ReverseMode/allocatedtape_err.c index 1236a86388e7..27bd89f36e40 100644 --- a/enzyme/test/Integration/ReverseMode/allocatedtape_err.c +++ b/enzyme/test/Integration/ReverseMode/allocatedtape_err.c @@ -7,8 +7,9 @@ // RUN: if [ %llvmver -ge 12 ]; then %clang -std=c11 -g -O2 %s -S -emit-llvm -o - %newLoadClangEnzyme -Xclang -verify; fi // RUN: if [ %llvmver -ge 12 ]; then %clang -std=c11 -g -O3 %s -S -emit-llvm -o - %newLoadClangEnzyme -Xclang -verify; fi -#include -#include "../test_utils.h" +extern int enzyme_allocated; +extern int enzyme_tape; +double sin(double); void __enzyme_reverse(void*, ...); diff --git a/enzyme/test/Integration/ReverseMode/blas.cpp b/enzyme/test/Integration/ReverseMode/blas.cpp index 8c929ed054d0..7cc7438cc932 100644 --- a/enzyme/test/Integration/ReverseMode/blas.cpp +++ b/enzyme/test/Integration/ReverseMode/blas.cpp @@ -42,6 +42,11 @@ void my_dgemm(char layout, char transA, char transB, int M, int N, int K, double inDerivative = true; } +void ow_dgemm(char layout, char transA, char transB, int M, int N, int K, double alpha, double* A, int lda, double* B, int ldb, double beta, double* C, int ldc) { + cblas_dgemm(layout, transA, transB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc); + inDerivative = true; +} + static void dotTests() { @@ -226,8 +231,8 @@ static void gemvTests() { assert(foundCalls.size() > 2); auto A_cache = (double*)foundCalls[0].pout_arg1; - cblas_dlacpy(layout, '\0', M, N, A, lda, A_cache, N); - inputs[4] = BlasInfo(A_cache, layout, M, N, N); + cblas_dlacpy(layout, '\0', M, N, A, lda, A_cache, M); + inputs[4] = BlasInfo(A_cache, layout, M, N, M); auto B_cache = (double*)foundCalls[1].pout_arg1; cblas_dcopy(trans ? M : N, B, incB, B_cache, 1); inputs[5] = BlasInfo(B_cache, trans ? M : N, 1); @@ -244,7 +249,7 @@ static void gemvTests() { lda); // dB = alpha * trans(A) * dC + dB - cblas_dgemv(layout, (char)transpose(transA), M, N, alpha, A_cache, N, dC, incC, 1.0, dB, incB); + cblas_dgemv(layout, (char)transpose(transA), M, N, alpha, A_cache, M, dC, incC, 1.0, dB, incB); // dY = beta * dY cblas_dscal(trans ? N : M, beta, dC, incC); @@ -374,6 +379,78 @@ static void gemmTests() { // should be the same). checkMemoryTrace(inputs, "Found " + Test, foundCalls); + + Test = "GEMM overwrite"; + + init(); + __enzyme_autodiff((void*) ow_dgemm, + enzyme_const, layout, + enzyme_const, transA, + enzyme_const, transB, + enzyme_const, M, + enzyme_const, N, + enzyme_const, K, + enzyme_const, alpha, + enzyme_dup, A, dA, + enzyme_const, lda, + enzyme_dup, B, dB, + enzyme_const, incB, + enzyme_const, beta, + enzyme_dup, C, dC, + enzyme_const, incC); + foundCalls = calls; + init(); + + assert(foundCalls.size() > 2); + auto A_cache = (double*)foundCalls[0].pout_arg1; + cblas_dlacpy(layout, '\0', (!transA_bool) ? M : K, (!transA_bool) ? K : M, A, lda, A_cache, (!transA_bool) ? M : K); + inputs[4] = BlasInfo(A_cache, layout, (!transA_bool) ? M : K, (!transA_bool) ? K : M, (!transA_bool) ? M : K); + auto B_cache = (double*)foundCalls[1].pout_arg1; + cblas_dlacpy(layout, '\0', (!transB_bool) ? K : N, (!transB_bool) ? N : K, B, incB, B_cache, (!transB_bool) ? K : N); + inputs[5] = BlasInfo(B_cache, layout, (!transB_bool) ? K : N, (!transB_bool) ? N : K, (!transB_bool) ? K : N); + + ow_dgemm(layout, (char)transA, (char)transB, M, N, K, alpha, A, lda, B, incB, beta, C, incC); + + inDerivative = true; + + // dA = + my_dgemm(layout, + transA_bool ? (char)transB : (char)CBLAS_TRANSPOSE::CblasNoTrans, + transA_bool ? (char)CBLAS_TRANSPOSE::CblasTrans : (char)transpose(transB), + transA_bool ? K : M, + transA_bool ? M : K, + N, + alpha, + transA_bool ? B_cache : dC, + transA_bool ? ( (!transB_bool) ? K : N ) : incC, + transA_bool ? dC : B_cache, + transA_bool ? incC : ( (!transB_bool) ? K : N), + 1.0, dA, lda); + + // dB = + my_dgemm(layout, + transB_bool ? (char)CBLAS_TRANSPOSE::CblasTrans : (char)transpose(transA), + transB_bool ? (char)transA : (char)CBLAS_TRANSPOSE::CblasNoTrans, //transB, + transB_bool ? N : K, + transB_bool ? K : N, + M, + alpha, + transB_bool ? dC : A_cache, + transB_bool ? incC : ( (!transA_bool) ? M : K), + transB_bool ? A_cache : dC, + transB_bool ? ( (!transA_bool) ? M : K) : incC, + 1.0, dB, incB); + + cblas_dlascl(layout, 'G', 0, 0, 1.0, beta, M, N, dC, incC, 0 ); + + checkTest(Test); + + // Check memory of primal of expected derivative + checkMemoryTrace(inputs, "Expected " + Test, calls); + + // Check memory of primal of our derivative (if equal above, it + // should be the same). + checkMemoryTrace(inputs, "Found " + Test, foundCalls); } diff --git a/enzyme/test/Integration/ReverseMode/blas_gemm2.c b/enzyme/test/Integration/ReverseMode/blas_gemm2.c index a417e0418cd9..b06474282634 100644 --- a/enzyme/test/Integration/ReverseMode/blas_gemm2.c +++ b/enzyme/test/Integration/ReverseMode/blas_gemm2.c @@ -8,9 +8,6 @@ // RUN: if [ %llvmver -ge 12 ]; then %clang -O2 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-lapack-copy=0 | %lli - ; fi // RUN: if [ %llvmver -ge 12 ]; then %clang -O3 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-lapack-copy=0 | %lli - ; fi -#include -#include -#include #include "../test_utils.h" #include "../blas_inline.h" diff --git a/enzyme/test/Integration/ReverseMode/boundissue.c b/enzyme/test/Integration/ReverseMode/boundissue.c index f297ebbdda97..4d41f2bf482a 100644 --- a/enzyme/test/Integration/ReverseMode/boundissue.c +++ b/enzyme/test/Integration/ReverseMode/boundissue.c @@ -7,10 +7,6 @@ // RUN: %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme --enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme --enzyme-inline=1 -S | %lli - -#include -#include -#include - #include "../test_utils.h" void __enzyme_autodiff(void*, ...); diff --git a/enzyme/test/Integration/ReverseMode/cachefwd.c b/enzyme/test/Integration/ReverseMode/cachefwd.c index ab56e3dd0853..6e1b058a7004 100644 --- a/enzyme/test/Integration/ReverseMode/cachefwd.c +++ b/enzyme/test/Integration/ReverseMode/cachefwd.c @@ -7,10 +7,6 @@ // RUN: %clang -std=c11 -ffast-math -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -ffast-math -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - -#include -#include -#include - #include "../test_utils.h" extern void __enzyme_autodiff(void*, double*, double*, int); diff --git a/enzyme/test/Integration/ReverseMode/cmplx.cpp b/enzyme/test/Integration/ReverseMode/cmplx.cpp index 73559beb94fb..bec8aa0d831b 100644 --- a/enzyme/test/Integration/ReverseMode/cmplx.cpp +++ b/enzyme/test/Integration/ReverseMode/cmplx.cpp @@ -11,10 +11,6 @@ #include "../test_utils.h" -#include -#include - -#include #include // std::complex, std::abs, std::arg void __enzyme_autodiff(...); diff --git a/enzyme/test/Integration/ReverseMode/customcombined.c b/enzyme/test/Integration/ReverseMode/customcombined.c index 5ce2c1df2b07..cbd86d0452ef 100644 --- a/enzyme/test/Integration/ReverseMode/customcombined.c +++ b/enzyme/test/Integration/ReverseMode/customcombined.c @@ -32,7 +32,7 @@ void* augment_square_(const double* src, const double *d_src, double* dest, doub // intentionally incorrect for debugging *dest = 7.0; *d_dest = 11.0; - return NULL; + return (void*)0; } int gradient = 0; diff --git a/enzyme/test/Integration/ReverseMode/customlog1p.c b/enzyme/test/Integration/ReverseMode/customlog1p.c index 95c0697fff77..6f563030f9b3 100644 --- a/enzyme/test/Integration/ReverseMode/customlog1p.c +++ b/enzyme/test/Integration/ReverseMode/customlog1p.c @@ -17,10 +17,6 @@ // RUN: if [ %llvmver -ge 12 ]; then %clang -std=c11 -O2 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-inline=1 -S | %lli - ; fi // RUN: if [ %llvmver -ge 12 ]; then %clang -std=c11 -O3 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-inline=1 -S | %lli - ; fi -#include -#include -#include - #include "../test_utils.h" double __enzyme_autodiff(void*, ...); diff --git a/enzyme/test/Integration/ReverseMode/dbginfo.c b/enzyme/test/Integration/ReverseMode/dbginfo.c index 06767e624849..b094dffb5c6d 100644 --- a/enzyme/test/Integration/ReverseMode/dbginfo.c +++ b/enzyme/test/Integration/ReverseMode/dbginfo.c @@ -7,14 +7,27 @@ // RUN: %clang -std=c11 -ffast-math -O2 %s -S -emit-llvm -o - -g | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -ffast-math -O3 %s -S -emit-llvm -o - -g | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - -//#include - #include "../test_utils.h" #define __builtin_autodiff __enzyme_autodiff double __enzyme_autodiff(void*, double, unsigned); +// May be needed if not built with compiler-rt +double __powidf2(double a, int b) { + const int recip = b < 0; + double r = 1; + while (1) { + if (b & 1) + r *= a; + b /= 2; + if (b == 0) + break; + a *= a; + } + return recip ? 1 / r : r; +} + static double taylorlog(double x, unsigned SINCOSN) { double sum = 0; for(int i=1; i<=SINCOSN; i++) { diff --git a/enzyme/test/Integration/ReverseMode/differential_pointer_return.c b/enzyme/test/Integration/ReverseMode/differential_pointer_return.c index b05a7264ae3c..42daecbaa7ed 100644 --- a/enzyme/test/Integration/ReverseMode/differential_pointer_return.c +++ b/enzyme/test/Integration/ReverseMode/differential_pointer_return.c @@ -7,11 +7,6 @@ // RUN: %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - -#include -#include -#include -#include - #include "../test_utils.h" #define __builtin_autodiff __enzyme_autodiff diff --git a/enzyme/test/Integration/ReverseMode/eigentensor.cpp b/enzyme/test/Integration/ReverseMode/eigentensor.cpp index 28feb236f4db..6e2da6d65818 100644 --- a/enzyme/test/Integration/ReverseMode/eigentensor.cpp +++ b/enzyme/test/Integration/ReverseMode/eigentensor.cpp @@ -16,6 +16,7 @@ #include "../test_utils.h" +#include void memcpy(float* __restrict dst, float* __restrict src, size_t count) { for(size_t i=0; i -#include -#include -#include +// RUN: if [ %llvmver -ge 12 ]; then %clang -std=c11 -O0 %loadClangEnzyme %s -S -emit-llvm -o - | %lli - ; fi +// RUN: if [ %llvmver -ge 12 ]; then %clang -std=c11 -O1 %loadClangEnzyme %s -S -emit-llvm -o - | %lli - ; fi +// RUN: if [ %llvmver -ge 12 ]; then %clang -std=c11 -O2 %loadClangEnzyme %s -S -emit-llvm -o - -mllvm -enzyme-loose-types | %lli - ; fi +// RUN: if [ %llvmver -ge 12 ]; then %clang -std=c11 -O3 %loadClangEnzyme %s -S -emit-llvm -o - -mllvm -enzyme-loose-types | %lli - ; fi +// RUN: if [ %llvmver -ge 12 ]; then %clang -std=c11 -O0 %loadClangEnzyme %s -S -emit-llvm -o - -mllvm -enzyme-inline=1 | %lli - ; fi +// RUN: if [ %llvmver -ge 12 ]; then %clang -std=c11 -O1 %loadClangEnzyme %s -S -emit-llvm -o - -mllvm -enzyme-inline=1 | %lli - ; fi +// RUN: if [ %llvmver -ge 12 ]; then %clang -std=c11 -O2 %loadClangEnzyme %s -S -emit-llvm -o - -mllvm -enzyme-inline=1 -mllvm -enzyme-loose-types | %lli - ; fi +// RUN: if [ %llvmver -ge 12 ]; then %clang -std=c11 -O3 %loadClangEnzyme %s -S -emit-llvm -o - -mllvm -enzyme-inline=1 -mllvm -enzyme-loose-types | %lli - ; fi #include "../test_utils.h" - float __enzyme_autodiff(void*, float, int); float foo(float inp, int n) { diff --git a/enzyme/test/Integration/ReverseMode/frexp.c b/enzyme/test/Integration/ReverseMode/frexp.c index 4b917e7e282f..08858d51f796 100644 --- a/enzyme/test/Integration/ReverseMode/frexp.c +++ b/enzyme/test/Integration/ReverseMode/frexp.c @@ -7,12 +7,10 @@ // RUN: %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme --enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme --enzyme-inline=1 -S | %lli - -#include -#include -#include - #include "../test_utils.h" +extern double frexp ( double num, int* exp ); + double f(double x) { int exp; return frexp(x, &exp); diff --git a/enzyme/test/Integration/ReverseMode/fwdsolve.c b/enzyme/test/Integration/ReverseMode/fwdsolve.c index c7a6ad63040e..5a7ab4cc67e8 100644 --- a/enzyme/test/Integration/ReverseMode/fwdsolve.c +++ b/enzyme/test/Integration/ReverseMode/fwdsolve.c @@ -7,11 +7,6 @@ // RUN: %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - -#include -#include -#include -#include - #include "../test_utils.h" @@ -25,10 +20,10 @@ void forward_sub(int N, double* __restrict__ L, double * __restrict__ b, double b must be a vector of the same leading dimension as L """ */ - for (size_t i=0; i 1) - for (size_t j=0; j #include "../test_utils.h" typedef struct { diff --git a/enzyme/test/Integration/ReverseMode/headerremat.c b/enzyme/test/Integration/ReverseMode/headerremat.c index b3397ad8bead..a324a519557f 100644 --- a/enzyme/test/Integration/ReverseMode/headerremat.c +++ b/enzyme/test/Integration/ReverseMode/headerremat.c @@ -7,15 +7,8 @@ // RUN: %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme --enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme --enzyme-inline=1 -S | %lli - -#include -#include -#include - #include "../test_utils.h" - -#include - __attribute__((noinline)) int evaluate_integrand(const int nr, const int dtheta) diff --git a/enzyme/test/Integration/ReverseMode/inactivefn.c b/enzyme/test/Integration/ReverseMode/inactivefn.c index e48edd3486ed..bbd6696a04d9 100644 --- a/enzyme/test/Integration/ReverseMode/inactivefn.c +++ b/enzyme/test/Integration/ReverseMode/inactivefn.c @@ -17,10 +17,6 @@ // RUN: if [ %llvmver -ge 12 ]; then %clang -std=c11 -O2 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-inline=1 -S | %lli - ; fi // RUN: if [ %llvmver -ge 12 ]; then %clang -std=c11 -O3 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-inline=1 -S | %lli - ; fi -#include -#include -#include - #include "../test_utils.h" double __enzyme_autodiff(void*, ...); diff --git a/enzyme/test/Integration/ReverseMode/insertsort_sum.c b/enzyme/test/Integration/ReverseMode/insertsort_sum.c index 7e7d6b0a3ae4..7b42f7613890 100644 --- a/enzyme/test/Integration/ReverseMode/insertsort_sum.c +++ b/enzyme/test/Integration/ReverseMode/insertsort_sum.c @@ -6,10 +6,6 @@ // RUN: %clang -std=c11 -O1 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - -#include -#include -#include -#include #include "../test_utils.h" diff --git a/enzyme/test/Integration/ReverseMode/insertsort_sum_alt.c b/enzyme/test/Integration/ReverseMode/insertsort_sum_alt.c index 3a643492d30f..321438af8b9a 100644 --- a/enzyme/test/Integration/ReverseMode/insertsort_sum_alt.c +++ b/enzyme/test/Integration/ReverseMode/insertsort_sum_alt.c @@ -7,11 +7,6 @@ // RUN: %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - -#include -#include -#include -#include - #include "../test_utils.h" #define __builtin_autodiff __enzyme_autodiff diff --git a/enzyme/test/Integration/ReverseMode/insertsort_sum_min.c b/enzyme/test/Integration/ReverseMode/insertsort_sum_min.c index b62150353581..f6bfd1a4c5c0 100644 --- a/enzyme/test/Integration/ReverseMode/insertsort_sum_min.c +++ b/enzyme/test/Integration/ReverseMode/insertsort_sum_min.c @@ -8,10 +8,6 @@ // RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - #include "../test_utils.h" -#include -#include -#include -#include #define __builtin_autodiff __enzyme_autodiff diff --git a/enzyme/test/Integration/ReverseMode/integrateconst.cpp b/enzyme/test/Integration/ReverseMode/integrateconst.cpp index 2086ab8b525d..c55247d65e55 100644 --- a/enzyme/test/Integration/ReverseMode/integrateconst.cpp +++ b/enzyme/test/Integration/ReverseMode/integrateconst.cpp @@ -14,7 +14,6 @@ #define BOOST_MATH_NO_LONG_DOUBLE_MATH_FUNCTIONS #define BOOST_NO_EXCEPTIONS -#include #include #include diff --git a/enzyme/test/Integration/ReverseMode/invsqrt.c b/enzyme/test/Integration/ReverseMode/invsqrt.c index c14ce79ab88d..09ada6757e45 100644 --- a/enzyme/test/Integration/ReverseMode/invsqrt.c +++ b/enzyme/test/Integration/ReverseMode/invsqrt.c @@ -17,13 +17,8 @@ // RUN: if [ %llvmver -ge 12 ]; then %clang -std=c11 -O2 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-inline=1 -S | %lli - ; fi // RUN: if [ %llvmver -ge 12 ]; then %clang -std=c11 -O3 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-inline=1 -S | %lli - ; fi -#include -#include -#include -#include -#include - #include "../test_utils.h" +#include // Fast inverse sqrt // Code taken from https://en.wikipedia.org/wiki/Fast_inverse_square_root @@ -74,10 +69,8 @@ int main(int argc, char *argv[]) { double *A = (double*)malloc(sizeof(double) * n); for(int i=0; i -#include -#include - #include "../test_utils.h" #define __builtin_autodiff __enzyme_autodiff diff --git a/enzyme/test/Integration/ReverseMode/loopsdouble.c b/enzyme/test/Integration/ReverseMode/loopsdouble.c index 8108a9813c63..c9ffc8b6bde6 100644 --- a/enzyme/test/Integration/ReverseMode/loopsdouble.c +++ b/enzyme/test/Integration/ReverseMode/loopsdouble.c @@ -7,10 +7,6 @@ // RUN: %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - -#include -#include -#include - #include "../test_utils.h" #define __builtin_autodiff __enzyme_autodiff diff --git a/enzyme/test/Integration/ReverseMode/loopstriple.c b/enzyme/test/Integration/ReverseMode/loopstriple.c index 490502145810..a1aa76444893 100644 --- a/enzyme/test/Integration/ReverseMode/loopstriple.c +++ b/enzyme/test/Integration/ReverseMode/loopstriple.c @@ -7,10 +7,6 @@ // RUN: %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - -#include -#include -#include - #include "../test_utils.h" #define __builtin_autodiff __enzyme_autodiff diff --git a/enzyme/test/Integration/ReverseMode/manydiv.c b/enzyme/test/Integration/ReverseMode/manydiv.c index a0da1f6499b3..e3937c6bb956 100644 --- a/enzyme/test/Integration/ReverseMode/manydiv.c +++ b/enzyme/test/Integration/ReverseMode/manydiv.c @@ -7,10 +7,6 @@ // RUN: %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme --enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme --enzyme-inline=1 -S | %lli - -#include -#include -#include - #include "../test_utils.h" double __enzyme_autodiff(void*, ...); diff --git a/enzyme/test/Integration/ReverseMode/manymax.c b/enzyme/test/Integration/ReverseMode/manymax.c index 90b61041ec1e..4cf31bf1cfa7 100644 --- a/enzyme/test/Integration/ReverseMode/manymax.c +++ b/enzyme/test/Integration/ReverseMode/manymax.c @@ -7,10 +7,6 @@ // RUN: %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme --enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme --enzyme-inline=1 -S | %lli - -#include -#include -#include - #include "../test_utils.h" double __enzyme_autodiff(void*, ...); diff --git a/enzyme/test/Integration/ReverseMode/metamalloc.c b/enzyme/test/Integration/ReverseMode/metamalloc.c index 96a848fd9698..3765a0c803fd 100644 --- a/enzyme/test/Integration/ReverseMode/metamalloc.c +++ b/enzyme/test/Integration/ReverseMode/metamalloc.c @@ -7,20 +7,16 @@ // RUN: %clang -std=c11 -Xclang -new-struct-path-tbaa -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme --enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -Xclang -new-struct-path-tbaa -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme --enzyme-inline=1 -S | %lli - -#include -#include -#include - #include "../test_utils.h" double __enzyme_autodiff(void*, ...); struct { int count; -void* (*allocfn)(size_t); +void* (*allocfn)(unsigned long); } tup = {0, malloc}; __attribute__((noinline)) -void* metamalloc(size_t size) { +void* metamalloc(unsigned long size) { void* ret = tup.allocfn(size); //if (ret != 0) // tup.count++; @@ -38,7 +34,7 @@ double alldiv(double x) { } -static void* (*sallocfn)(size_t) = malloc; +static void* (*sallocfn)(unsigned long) = malloc; __attribute__((noinline)) void* smetamalloc(int size) { return sallocfn(size); diff --git a/enzyme/test/Integration/ReverseMode/metarwr.c b/enzyme/test/Integration/ReverseMode/metarwr.c index 4efcd475f68f..3e3310c634bd 100644 --- a/enzyme/test/Integration/ReverseMode/metarwr.c +++ b/enzyme/test/Integration/ReverseMode/metarwr.c @@ -7,10 +7,6 @@ // RUN: %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme --enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme --enzyme-inline=1 -S | %lli - -#include -#include -#include - #include "../test_utils.h" void __enzyme_autodiff(void*, ...); @@ -19,7 +15,7 @@ void call(double* __restrict__ a, long** data) { long* segment = data[0]; long size = segment[1] - segment[0]; printf("seg[1]=%d seg[0]=%d\n", segment[1], segment[0]); - for (size_t i=0; i -#include -#include -#include - #include "../test_utils.h" #define __builtin_autodiff __enzyme_autodiff diff --git a/enzyme/test/Integration/ReverseMode/mixedstruct1-simple.c b/enzyme/test/Integration/ReverseMode/mixedstruct1-simple.c index 5dea68e793da..40444dd8306f 100644 --- a/enzyme/test/Integration/ReverseMode/mixedstruct1-simple.c +++ b/enzyme/test/Integration/ReverseMode/mixedstruct1-simple.c @@ -5,12 +5,7 @@ // RUN: %clang -std=c11 %O0TBAA %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -O1 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - -// RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - - -#include -#include -#include -#include +// RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - q #include "../test_utils.h" diff --git a/enzyme/test/Integration/ReverseMode/mixedstruct1-simplefda.c b/enzyme/test/Integration/ReverseMode/mixedstruct1-simplefda.c index e2b3f9788fca..22e6fc4dedc7 100644 --- a/enzyme/test/Integration/ReverseMode/mixedstruct1-simplefda.c +++ b/enzyme/test/Integration/ReverseMode/mixedstruct1-simplefda.c @@ -7,11 +7,6 @@ // RUN: %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - -#include -#include -#include -#include - #include "../test_utils.h" #define __builtin_autodiff __enzyme_autodiff diff --git a/enzyme/test/Integration/ReverseMode/mixedstruct1-simpleps.c b/enzyme/test/Integration/ReverseMode/mixedstruct1-simpleps.c index b4857429b6fe..7de76117e765 100644 --- a/enzyme/test/Integration/ReverseMode/mixedstruct1-simpleps.c +++ b/enzyme/test/Integration/ReverseMode/mixedstruct1-simpleps.c @@ -7,11 +7,6 @@ // RUN: %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - -#include -#include -#include -#include - #include "../test_utils.h" #define __builtin_autodiff __enzyme_autodiff diff --git a/enzyme/test/Integration/ReverseMode/mixedstruct1-simpler.c b/enzyme/test/Integration/ReverseMode/mixedstruct1-simpler.c index 61b81d0a57bf..6c602b32da3a 100644 --- a/enzyme/test/Integration/ReverseMode/mixedstruct1-simpler.c +++ b/enzyme/test/Integration/ReverseMode/mixedstruct1-simpler.c @@ -7,11 +7,6 @@ // RUN: %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - -#include -#include -#include -#include - #include "../test_utils.h" #define __builtin_autodiff __enzyme_autodiff diff --git a/enzyme/test/Integration/ReverseMode/mixedstruct1-simplest.c b/enzyme/test/Integration/ReverseMode/mixedstruct1-simplest.c index bcf3ea958485..5417fce6b8ea 100644 --- a/enzyme/test/Integration/ReverseMode/mixedstruct1-simplest.c +++ b/enzyme/test/Integration/ReverseMode/mixedstruct1-simplest.c @@ -7,11 +7,6 @@ // RUN: %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - -#include -#include -#include -#include - #include "../test_utils.h" #define __builtin_autodiff __enzyme_autodiff diff --git a/enzyme/test/Integration/ReverseMode/mixedstruct1-sp.c b/enzyme/test/Integration/ReverseMode/mixedstruct1-sp.c index eb783098cf02..d64e9a70da9a 100644 --- a/enzyme/test/Integration/ReverseMode/mixedstruct1-sp.c +++ b/enzyme/test/Integration/ReverseMode/mixedstruct1-sp.c @@ -7,11 +7,6 @@ // RUN: %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - -#include -#include -#include -#include - #include "../test_utils.h" #define __builtin_autodiff __enzyme_autodiff diff --git a/enzyme/test/Integration/ReverseMode/mixedstruct1.c b/enzyme/test/Integration/ReverseMode/mixedstruct1.c index 3ab79b584ef9..e1218dbc0541 100644 --- a/enzyme/test/Integration/ReverseMode/mixedstruct1.c +++ b/enzyme/test/Integration/ReverseMode/mixedstruct1.c @@ -7,11 +7,6 @@ // RUN: %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - -#include -#include -#include -#include - #include "../test_utils.h" #define __builtin_autodiff __enzyme_autodiff diff --git a/enzyme/test/Integration/ReverseMode/multivecmaxC.c b/enzyme/test/Integration/ReverseMode/multivecmaxC.c index c39559adb2de..460036659ffd 100644 --- a/enzyme/test/Integration/ReverseMode/multivecmaxC.c +++ b/enzyme/test/Integration/ReverseMode/multivecmaxC.c @@ -10,9 +10,6 @@ // RUN: %clang++ -ffast-math -O2 -fno-vectorize -fno-slp-vectorize -fno-unroll-loops -fno-exceptions %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - // RUN: %clang++ -ffast-math -O3 -fno-vectorize -fno-slp-vectorize -fno-unroll-loops -fno-exceptions %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - -#include -#include - #include "../test_utils.h" extern void __enzyme_autodiff(void*, double*, double*, int); @@ -21,7 +18,7 @@ extern void __enzyme_autodiff(void*, double*, double*, int); }*/ double reduce_max(double* vec, int size) { - double ret = -INFINITY; + double ret = -10000000; double *maxes = (double*)malloc(sizeof(double)*size); int count = 0; for (int i = 0; i < size; i++) { diff --git a/enzyme/test/Integration/ReverseMode/mycos.c b/enzyme/test/Integration/ReverseMode/mycos.c index 1d7d4c8e85f3..fab9f17085eb 100644 --- a/enzyme/test/Integration/ReverseMode/mycos.c +++ b/enzyme/test/Integration/ReverseMode/mycos.c @@ -15,22 +15,20 @@ // RUN: if [ %llvmver -ge 12 ]; then %clang -std=c11 -O2 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-inline=1 -S | %lli - ; fi // RUN: if [ %llvmver -ge 12 ]; then %clang -std=c11 -O3 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-inline=1 -S | %lli - ; fi -#include -#include -#include - #include "../test_utils.h" +double pow(double, double); + __attribute__((noinline)) -uint64_t factorial(uint64_t x) { +unsigned long long factorial(unsigned long long x) { if (x == 0) return 1; return x * factorial(x-1); } double my_sin(double x) { double result = 0; - uint64_t N = 12; - for(uint64_t i=0; i<=N; i++) { + unsigned long long N = 12; + for(unsigned long long i=0; i<=N; i++) { if (i % 2 == 0) continue; result += pow(x, i) / factorial(i) * (i % 4 == 1 ? 1 : -1); } @@ -38,14 +36,14 @@ double my_sin(double x) { } -uint64_t __enzyme_iter(uint64_t, uint64_t); +unsigned long long __enzyme_iter(unsigned long long, unsigned long long); double __enzyme_autodiff(void*, double); double my_sin2(double x) { double result = 0; - uint64_t N = __enzyme_iter(12, 1); - for(uint64_t i=0; i<=N; i++) { + unsigned long long N = __enzyme_iter(12, 1); + for(unsigned long long i=0; i<=N; i++) { if (i % 2 == 0) continue; result += pow(x, i) / factorial(i) * (i % 4 == 1 ? 1 : -1); } diff --git a/enzyme/test/Integration/ReverseMode/omp.c b/enzyme/test/Integration/ReverseMode/omp.c index 4d6f0bd164c5..5d4f8ae09970 100644 --- a/enzyme/test/Integration/ReverseMode/omp.c +++ b/enzyme/test/Integration/ReverseMode/omp.c @@ -9,10 +9,6 @@ // RUN: %clang -fopenmp -std=c11 -O2 -fno-vectorize -fno-unroll-loops %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %clang -fopenmp -x ir - -o %s.out && %s.out // RUN: %clang -fopenmp -std=c11 -O3 -fno-vectorize -fno-unroll-loops %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %clang -fopenmp -x ir - -o %s.out && %s.out -#include -#include -#include - #include "../test_utils.h" double __enzyme_autodiff(void*, ...); diff --git a/enzyme/test/Integration/ReverseMode/omp2.c b/enzyme/test/Integration/ReverseMode/omp2.c index 8cdd06545fd1..44f315fe28f9 100644 --- a/enzyme/test/Integration/ReverseMode/omp2.c +++ b/enzyme/test/Integration/ReverseMode/omp2.c @@ -8,10 +8,6 @@ // RUN: %clang -fopenmp -std=c11 -O2 -fno-vectorize -fno-unroll-loops %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %clang -fopenmp -x ir - -o %s.out && %s.out // RUN: %clang -fopenmp -std=c11 -O3 -fno-vectorize -fno-unroll-loops %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %clang -fopenmp -x ir - -o %s.out && %s.out -#include -#include -#include - #include "../test_utils.h" double __enzyme_autodiff(void*, ...); diff --git a/enzyme/test/Integration/ReverseMode/omp3.c b/enzyme/test/Integration/ReverseMode/omp3.c index 59bdca34cf9a..970348a7cedd 100644 --- a/enzyme/test/Integration/ReverseMode/omp3.c +++ b/enzyme/test/Integration/ReverseMode/omp3.c @@ -9,8 +9,6 @@ // RUN: %clang -fopenmp -std=c11 -fno-vectorize -fno-unroll-loops -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %clang -fopenmp -x ir - -o %s.out && %s.out // RUN: %clang -fopenmp -std=c11 -fno-vectorize -fno-unroll-loops -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %clang -fopenmp -x ir - -o %s.out && %s.out -# include -# include #include "../test_utils.h" void msg(double* in, int *len, unsigned int slen) { diff --git a/enzyme/test/Integration/ReverseMode/omp6.c b/enzyme/test/Integration/ReverseMode/omp6.c index 8a6f34fe235f..abd170d07dfd 100644 --- a/enzyme/test/Integration/ReverseMode/omp6.c +++ b/enzyme/test/Integration/ReverseMode/omp6.c @@ -9,10 +9,6 @@ // RUN: %clang -fopenmp -std=c11 -fno-vectorize -fno-unroll-loops -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %clang -fopenmp -x ir - -o %s.out && %s.out // RUN: %clang -fopenmp -std=c11 -fno-vectorize -fno-unroll-loops -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %clang -fopenmp -x ir - -o %s.out && %s.out -# include -# include -#include - #include "../test_utils.h" __attribute__((noinline)) diff --git a/enzyme/test/Integration/ReverseMode/omp_two.c b/enzyme/test/Integration/ReverseMode/omp_two.c index 1f95e93f2520..8d3dac92ca75 100644 --- a/enzyme/test/Integration/ReverseMode/omp_two.c +++ b/enzyme/test/Integration/ReverseMode/omp_two.c @@ -9,10 +9,6 @@ // RUN: %clang -fopenmp -std=c11 -O2 -fno-vectorize -fno-unroll-loops %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %clang -fopenmp -x ir - -o %s.out && %s.out // RUN: %clang -fopenmp -std=c11 -O3 -fno-vectorize -fno-unroll-loops %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %clang -fopenmp -x ir - -o %s.out && %s.out -#include -#include -#include - #include "../test_utils.h" void __enzyme_autodiff(void*, ...); diff --git a/enzyme/test/Integration/ReverseMode/ompbound.c b/enzyme/test/Integration/ReverseMode/ompbound.c index 4d6f0bd164c5..5d4f8ae09970 100644 --- a/enzyme/test/Integration/ReverseMode/ompbound.c +++ b/enzyme/test/Integration/ReverseMode/ompbound.c @@ -9,10 +9,6 @@ // RUN: %clang -fopenmp -std=c11 -O2 -fno-vectorize -fno-unroll-loops %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %clang -fopenmp -x ir - -o %s.out && %s.out // RUN: %clang -fopenmp -std=c11 -O3 -fno-vectorize -fno-unroll-loops %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %clang -fopenmp -x ir - -o %s.out && %s.out -#include -#include -#include - #include "../test_utils.h" double __enzyme_autodiff(void*, ...); diff --git a/enzyme/test/Integration/ReverseMode/posix_memalign.c b/enzyme/test/Integration/ReverseMode/posix_memalign.c index 48aab315b95a..b50a405dd70c 100644 --- a/enzyme/test/Integration/ReverseMode/posix_memalign.c +++ b/enzyme/test/Integration/ReverseMode/posix_memalign.c @@ -7,15 +7,9 @@ // RUN: %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - -#include -#include -#include -#include -#include - #include "../test_utils.h" -int posix_memalign(void **memptr, size_t alignment, size_t size); +int posix_memalign(void **memptr, unsigned long alignment, unsigned long size); float __enzyme_autodiff(void*, float, int); diff --git a/enzyme/test/Integration/ReverseMode/posix_memalignfor.c b/enzyme/test/Integration/ReverseMode/posix_memalignfor.c index 0a01ef095142..7336444421a8 100644 --- a/enzyme/test/Integration/ReverseMode/posix_memalignfor.c +++ b/enzyme/test/Integration/ReverseMode/posix_memalignfor.c @@ -7,15 +7,9 @@ // RUN: %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - -#include -#include -#include -#include -#include - #include "../test_utils.h" -int posix_memalign(void **memptr, size_t alignment, size_t size); +int posix_memalign(void **memptr, unsigned long alignment, unsigned long size); float __enzyme_autodiff(void*, float, int); diff --git a/enzyme/test/Integration/ReverseMode/readwriteread.c b/enzyme/test/Integration/ReverseMode/readwriteread.c index ecfdce54d27e..adb5afca594a 100644 --- a/enzyme/test/Integration/ReverseMode/readwriteread.c +++ b/enzyme/test/Integration/ReverseMode/readwriteread.c @@ -7,11 +7,6 @@ // RUN: %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - -#include -#include -#include -#include - #include "../test_utils.h" #define __builtin_autodiff __enzyme_autodiff diff --git a/enzyme/test/Integration/ReverseMode/recurse.c b/enzyme/test/Integration/ReverseMode/recurse.c index a53041e77188..ca06c745411f 100644 --- a/enzyme/test/Integration/ReverseMode/recurse.c +++ b/enzyme/test/Integration/ReverseMode/recurse.c @@ -9,10 +9,6 @@ // RUN: %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - -#include -#include -#include - #include "../test_utils.h" #define __builtin_autodiff __enzyme_autodiff diff --git a/enzyme/test/Integration/ReverseMode/remat.c b/enzyme/test/Integration/ReverseMode/remat.c index b228bcc13c99..c9b771b17d1b 100644 --- a/enzyme/test/Integration/ReverseMode/remat.c +++ b/enzyme/test/Integration/ReverseMode/remat.c @@ -7,10 +7,6 @@ // RUN: %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme --enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme --enzyme-inline=1 -S | %lli - -// test.c -#include -#include - #include "../test_utils.h" extern void __enzyme_autodiff(void*, ...); diff --git a/enzyme/test/Integration/ReverseMode/rematSimple.c b/enzyme/test/Integration/ReverseMode/rematSimple.c index 7b95f9cb9506..788e0c258c8e 100644 --- a/enzyme/test/Integration/ReverseMode/rematSimple.c +++ b/enzyme/test/Integration/ReverseMode/rematSimple.c @@ -3,10 +3,6 @@ // RUN: if [ %llvmver -ge 12 ]; then %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -S | %lli - ; fi // RUN: if [ %llvmver -ge 12 ]; then %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -S | %lli - ; fi -// test.c -#include -#include - #include "../test_utils.h" extern void __enzyme_autodiff(void*, ...); diff --git a/enzyme/test/Integration/ReverseMode/rwrloop.c b/enzyme/test/Integration/ReverseMode/rwrloop.c index cdf9e3774553..74b9acb7b897 100644 --- a/enzyme/test/Integration/ReverseMode/rwrloop.c +++ b/enzyme/test/Integration/ReverseMode/rwrloop.c @@ -7,11 +7,6 @@ // RUN: %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme --enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme --enzyme-inline=1 -S | %lli - -#include -#include -#include -#include - #include "../test_utils.h" double __enzyme_autodiff(void*, ...); diff --git a/enzyme/test/Integration/ReverseMode/rwrmeta.c b/enzyme/test/Integration/ReverseMode/rwrmeta.c index 34a15d5c93d4..985ccdd44f1a 100644 --- a/enzyme/test/Integration/ReverseMode/rwrmeta.c +++ b/enzyme/test/Integration/ReverseMode/rwrmeta.c @@ -7,10 +7,6 @@ // RUN: %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme --enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme --enzyme-inline=1 -S | %lli - -#include -#include -#include - #include "../test_utils.h" double __enzyme_autodiff(void*, ...); diff --git a/enzyme/test/Integration/ReverseMode/smallrealloc.c b/enzyme/test/Integration/ReverseMode/smallrealloc.c index 51e29ccdfce5..1244a2f12cf1 100644 --- a/enzyme/test/Integration/ReverseMode/smallrealloc.c +++ b/enzyme/test/Integration/ReverseMode/smallrealloc.c @@ -7,11 +7,6 @@ // RUN: %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - -#include -#include -#include -#include - #include "../test_utils.h" diff --git a/enzyme/test/Integration/ReverseMode/sret.cpp b/enzyme/test/Integration/ReverseMode/sret.cpp index 2b5315d5d35c..16751fe3e318 100644 --- a/enzyme/test/Integration/ReverseMode/sret.cpp +++ b/enzyme/test/Integration/ReverseMode/sret.cpp @@ -8,9 +8,6 @@ // RUN: %clang++ -std=c++11 -fno-exceptions -ffast-math -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S #include "../test_utils.h" -#include -#include -#include typedef struct { double df[3]; diff --git a/enzyme/test/Integration/ReverseMode/subdoublestore.c b/enzyme/test/Integration/ReverseMode/subdoublestore.c index f411d98f3e7f..130f79c99979 100644 --- a/enzyme/test/Integration/ReverseMode/subdoublestore.c +++ b/enzyme/test/Integration/ReverseMode/subdoublestore.c @@ -7,12 +7,6 @@ // RUN: %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - -#include -#include -#include -#include -#include - #include "../test_utils.h" #define __builtin_autodiff __enzyme_autodiff diff --git a/enzyme/test/Integration/ReverseMode/sugar.cpp b/enzyme/test/Integration/ReverseMode/sugar.cpp new file mode 100644 index 000000000000..8524342e1bfc --- /dev/null +++ b/enzyme/test/Integration/ReverseMode/sugar.cpp @@ -0,0 +1,91 @@ +// RUN: if [ %llvmver -ge 11 ]; then %clang++ -std=c++17 -O0 %s -S -emit-llvm -o - %loadClangEnzyme | %lli - ; fi +// RUN: if [ %llvmver -ge 11 ]; then %clang++ -std=c++17 -O1 %s -S -emit-llvm -o - %loadClangEnzyme | %lli - ; fi +// RUN: if [ %llvmver -ge 11 ]; then %clang++ -std=c++17 -O2 %s -S -emit-llvm -o - %loadClangEnzyme | %lli - ; fi +// RUN: if [ %llvmver -ge 11 ]; then %clang++ -std=c++17 -O3 %s -S -emit-llvm -o - %loadClangEnzyme | %lli - ; fi +// RUN: if [ %llvmver -ge 12 ]; then %clang++ -std=c++17 -O0 %s -mllvm -print-before-all -mllvm -print-after-all -mllvm -print-module-scope -S -emit-llvm -o - %newLoadClangEnzyme | %lli - ; fi +// RUN: if [ %llvmver -ge 12 ]; then %clang++ -std=c++17 -O1 %s -S -emit-llvm -o - %newLoadClangEnzyme | %lli - ; fi +// RUN: if [ %llvmver -ge 12 ]; then %clang++ -std=c++17 -O2 %s -S -emit-llvm -o - %newLoadClangEnzyme | %lli - ; fi +// RUN: if [ %llvmver -ge 12 ]; then %clang++ -std=c++17 -O3 %s -S -emit-llvm -o - %newLoadClangEnzyme | %lli - ; fi + +#include "../test_utils.h" + +#include + +double foo(double x, double y) { return x * y; } + +double square(double x) { return x * x; } + +struct pair { + double x; + double y; +}; + +int main() { + + { + enzyme::Active x1{3.1}; + enzyme::tuple< enzyme::tuple > dsq = enzyme::autodiff>(square, x1); + double dd = enzyme::get<0>(enzyme::get<0>(dsq)); + printf("dsq = %f\n", dd); + APPROX_EQ(dd, 3.1*2, 1e-10); + } + + { + enzyme::Active x1{3.1}; + enzyme::tuple< enzyme::tuple > dsq = enzyme::autodiff(square, x1); + double dd = enzyme::get<0>(enzyme::get<0>(dsq)); + printf("dsq2 = %f\n", dd); + APPROX_EQ(dd, 3.1*2, 1e-10); + } + + { + enzyme::Active x1{3.1}; + enzyme::tuple< enzyme::tuple, double > dsq = enzyme::autodiff>(square, x1); + double dd = enzyme::get<0>(enzyme::get<0>(dsq)); + printf("dsq3 = %f\n", dd); + APPROX_EQ(dd, 3.1*2, 1e-10); + double prim = enzyme::get<1>(dsq); + printf("dsq3_prim = %f\n", prim); + APPROX_EQ(prim, 3.1*3.1, 1e-10); + } + + { + enzyme::Active x1{3.1}; + enzyme::tuple< enzyme::tuple, double > dsq = enzyme::autodiff(square, x1); + double dd = enzyme::get<0>(enzyme::get<0>(dsq)); + printf("dsq4 = %f\n", dd); + APPROX_EQ(dd, 3.1*2, 1e-10); + double prim = enzyme::get<1>(dsq); + printf("dsq4_prim = %f\n", prim); + APPROX_EQ(prim, 3.1*3.1, 1e-10); + } + + { + auto y = enzyme::autodiff(foo, enzyme::Active(3.1), enzyme::Active(2.7)); + auto y1 = enzyme::get<0>(enzyme::get<0>(y)); + auto y2 = enzyme::get<1>(enzyme::get<0>(y)); + printf("dmul %f %f\n", y1, y2); + APPROX_EQ(y1, 2.7, 1e-10); + APPROX_EQ(y2, 3.1, 1e-10); + } + + { + auto y = enzyme::autodiff(foo, enzyme::Active(3.1), enzyme::Active(2.7)); + auto y1 = enzyme::get<0>(enzyme::get<0>(y)); + auto y2 = enzyme::get<1>(enzyme::get<0>(y)); + auto prim = enzyme::get<1>(y); + printf("dmul2 %f %f\n", y1, y2); + printf("dmul_prim %f\n", prim); + APPROX_EQ(y1, 2.7, 1e-10); + APPROX_EQ(y2, 3.1, 1e-10); + APPROX_EQ(prim, 2.7*3.1, 1e-10); + } + + { + auto &&[z1, z2] = __enzyme_autodiff((void*)foo, enzyme_out, 3.1, enzyme_out, 2.7); + printf("dmul2 %f %f\n", z1, z2); + APPROX_EQ(z1, 2.7, 1e-10); + APPROX_EQ(z2, 3.1, 1e-10); + } + +} diff --git a/enzyme/test/Integration/ReverseMode/sumtil.c b/enzyme/test/Integration/ReverseMode/sumtil.c index 0a2b0502c2bc..f2a34b228afe 100644 --- a/enzyme/test/Integration/ReverseMode/sumtil.c +++ b/enzyme/test/Integration/ReverseMode/sumtil.c @@ -7,10 +7,6 @@ // RUN: %clang -std=c11 -ffast-math -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -ffast-math -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - -#include -#include -#include - #include "../test_utils.h" extern void __enzyme_autodiff(void*, double*, double*, int); diff --git a/enzyme/test/Integration/ReverseMode/sumtil2.c b/enzyme/test/Integration/ReverseMode/sumtil2.c index aac316c7c4ea..85cea6c94095 100644 --- a/enzyme/test/Integration/ReverseMode/sumtil2.c +++ b/enzyme/test/Integration/ReverseMode/sumtil2.c @@ -7,10 +7,6 @@ // RUN: %clang -std=c11 -ffast-math -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -ffast-math -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - -#include -#include -#include - #include "../test_utils.h" extern void __enzyme_autodiff(void*, double*, double*, int); diff --git a/enzyme/test/Integration/ReverseMode/taylorlog.c b/enzyme/test/Integration/ReverseMode/taylorlog.c index 649dd4fff243..522928e3f60f 100644 --- a/enzyme/test/Integration/ReverseMode/taylorlog.c +++ b/enzyme/test/Integration/ReverseMode/taylorlog.c @@ -7,14 +7,27 @@ // RUN: %clang -std=c11 -ffast-math -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - // RUN: %clang -std=c11 -ffast-math -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -enzyme-inline=1 -S | %lli - -//#include - #include "../test_utils.h" #define __builtin_autodiff __enzyme_autodiff double __enzyme_autodiff(void*, double, unsigned); +// May be needed if not built with compiler-rt +double __powidf2(double a, int b) { + const int recip = b < 0; + double r = 1; + while (1) { + if (b & 1) + r *= a; + b /= 2; + if (b == 0) + break; + a *= a; + } + return recip ? 1 / r : r; +} + static double taylorlog(double x, unsigned SINCOSN) { double sum = 0; for(int i=1; i<=SINCOSN; i++) { diff --git a/enzyme/test/Integration/Truncate/CMakeLists.txt b/enzyme/test/Integration/Truncate/CMakeLists.txt new file mode 100644 index 000000000000..65187869f1ae --- /dev/null +++ b/enzyme/test/Integration/Truncate/CMakeLists.txt @@ -0,0 +1,8 @@ +add_lit_testsuite(check-enzyme-integration-truncate "Running enzyme fp truncation integration tests" + ${CMAKE_CURRENT_BINARY_DIR} + DEPENDS ${ENZYME_TEST_DEPS} + ARGS -v +) + +set_target_properties(check-enzyme-integration-truncate PROPERTIES FOLDER "Tests") + diff --git a/enzyme/test/Integration/Truncate/simple.cpp b/enzyme/test/Integration/Truncate/simple.cpp new file mode 100644 index 000000000000..53cf859c84da --- /dev/null +++ b/enzyme/test/Integration/Truncate/simple.cpp @@ -0,0 +1,118 @@ +// RUN: %clang -DTRUNC_OP -O0 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -S | %lli - +// RUN: %clang -DTRUNC_MEM -DTRUNC_OP -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -S | %lli - +// RUN: %clang -DTRUNC_OP -O2 -ffast-math %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -S | %lli - +// RUN: %clang -O1 -g %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -S | %lli - + +#include + +#include "../test_utils.h" + +#define N 10 + +double simple_add(double a, double b) { + return a + b; +} +double intrinsics(double a, double b) { + return sqrt(a) * pow(b, 2); +} +// TODO trunc mem mode +double constt(double a, double b) { + return 2; +} +double compute(double *A, double *B, double *C, int n) { + for (int i = 0; i < n; i++) { + C[i] = A[i] * 2 + B[i] * sqrt(A[i]); + } + return C[0]; +} + +typedef double (*fty)(double *, double *, double *, int); + +typedef double (*fty2)(double, double); + +extern fty __enzyme_truncate_mem_func_2(...); +extern fty2 __enzyme_truncate_mem_func(...); +extern fty __enzyme_truncate_op_func_2(...); +extern fty2 __enzyme_truncate_op_func(...); +extern double __enzyme_truncate_mem_value(...); +extern double __enzyme_expand_mem_value(...); + +#define FROM 64 +#define TO 32 + +#define TEST(F) do { + + +int main() { + + #ifdef TRUNC_MEM + { + double a = 1; + APPROX_EQ( + __enzyme_expand_mem_value( + __enzyme_truncate_mem_value(a, FROM, TO) , FROM, TO), + a, 1e-10); + } + + { + double a = 2; + double b = 3; + double truth = simple_add(a, b); + a = __enzyme_truncate_mem_value(a, FROM, TO); + b = __enzyme_truncate_mem_value(b, FROM, TO); + double trunc = __enzyme_expand_mem_value(__enzyme_truncate_mem_func(simple_add, FROM, TO)(a, b), FROM, TO); + APPROX_EQ(trunc, truth, 1e-5); + } + { + double a = 2; + double b = 3; + double truth = intrinsics(a, b); + a = __enzyme_truncate_mem_value(a, FROM, TO); + b = __enzyme_truncate_mem_value(b, FROM, TO); + double trunc = __enzyme_expand_mem_value(__enzyme_truncate_mem_func(intrinsics, FROM, TO)(a, b), FROM, TO); + APPROX_EQ(trunc, truth, 1e-5); + } + #endif + // { + // double a = 2; + // double b = 3; + // double truth = intrinsics(a, b); + // a = __enzyme_truncate_mem_value(a, FROM, TO); + // b = __enzyme_truncate_mem_value(b, FROM, TO); + // double trunc = __enzyme_expand_mem_value(__enzyme_truncate_mem_func(constt, FROM, TO)(a, b), FROM, TO); + // APPROX_EQ(trunc, truth, 1e-5); + // } + + #ifdef TRUNC_OP + { + double A[N]; + double B[N]; + double C[N]; + double D[N]; + + + for (int i = 0; i < N; i++) { + A[i] = 1 + i % 5; + B[i] = 1 + i % 3; + } + + compute(A, B, D, N); + + // for (int i = 0; i < N; i++) { + // A[i] = __enzyme_truncate_mem_value(A[i], 64, 32); + // B[i] = __enzyme_truncate_mem_value(B[i], 64, 32); + // } + + __enzyme_truncate_op_func_2(compute, 64, 32)(A, B, C, N); + + // for (int i = 0; i < N; i++) { + // C[i] = __enzyme_expand_mem_value(C[i], 64, 32); + // } + + for (int i = 0; i < N; i++) { + APPROX_EQ(D[i], C[i], 1e-5); + } + } + #endif + +} diff --git a/enzyme/test/Integration/Truncate/truncate-all.cpp b/enzyme/test/Integration/Truncate/truncate-all.cpp new file mode 100644 index 000000000000..39e5965bda0d --- /dev/null +++ b/enzyme/test/Integration/Truncate/truncate-all.cpp @@ -0,0 +1,59 @@ +// Baseline + +// RUN: if [ %llvmver -ge 12 ]; then %clang -O3 %s -S -emit-llvm -o - %newLoadClangEnzyme -S -mllvm --enzyme-truncate-all="" | %lli - | FileCheck --check-prefix BASELINE %s; fi +// BASELINE: 900000000.560000 + + +// Truncated + +// RUN: if [ %llvmver -ge 12 ]; then %clang -O3 %s -S -emit-llvm -o - %newLoadClangEnzyme -S -mllvm --enzyme-truncate-all="64to32" | %lli - | FileCheck --check-prefix TO_32 %s; fi +// TO_32: 900000000.000000 + +// RUN: if [ %llvmver -ge 12 ]; then %clang -O3 %s -S -emit-llvm -o - %newLoadClangEnzyme -S -mllvm --enzyme-truncate-all="11-52to8-23" | %lli - | FileCheck --check-prefix TO_28_23 %s; fi +// TO_28_23: 900000000.000000 + +// RUN: if [ %llvmver -ge 12 ] && [ %hasMPFR == "yes" ] ; then %clang -DENZYME_TEST_TO_MPFR -O3 %s -o %s.a.out %newLoadClangEnzyme -mllvm --enzyme-truncate-all="11-52to3-7" -lmpfr; %s.a.out | FileCheck --check-prefix TO_3_7 %s; fi +// TO_3_7: 897581056.000000 + +#include + +#ifdef ENZYME_TEST_TO_MPFR +#include +#endif + +#include "../test_utils.h" + +#define N 10 + +#define floatty double + + +__attribute__((noinline)) +floatty simple_add(floatty a, floatty b) { + return a + b; +} +__attribute__((noinline)) +floatty intrinsics(floatty a, floatty b) { + return sqrt(a) * pow(b, 2); +} +__attribute__((noinline)) +floatty compute(floatty *A, floatty *B, floatty *C, int n) { + for (int i = 0; i < n; i++) { + C[i] = A[i] / 2 + intrinsics(A[i], simple_add(B[i] * 10000, 0.000001)); + } + return C[0]; +} + +int main() { + floatty A[N]; + floatty B[N]; + floatty C[N]; + + for (int i = 0; i < N; i++) { + A[i] = 1 + i % 5; + B[i] = 1 + i % 3; + } + + compute(A, B, C, N); + printf("%f\n", C[5]); +} diff --git a/enzyme/test/Integration/blas_inline.h b/enzyme/test/Integration/blas_inline.h index 8ac0c16f51b3..b1b988190a5e 100644 --- a/enzyme/test/Integration/blas_inline.h +++ b/enzyme/test/Integration/blas_inline.h @@ -1,5 +1,8 @@ #include #include +#include +#include +#include typedef int32_t integer; typedef double doublereal; diff --git a/enzyme/test/Integration/blasinfra.h b/enzyme/test/Integration/blasinfra.h index 07d540a9b326..cc9d2cc01dd9 100644 --- a/enzyme/test/Integration/blasinfra.h +++ b/enzyme/test/Integration/blasinfra.h @@ -1,5 +1,5 @@ -#include +#include #include #include #include diff --git a/enzyme/test/Integration/test_utils.h b/enzyme/test/Integration/test_utils.h index afcf87d4471b..268922598f54 100644 --- a/enzyme/test/Integration/test_utils.h +++ b/enzyme/test/Integration/test_utils.h @@ -1,7 +1,24 @@ -#include -#include -#include -#include + + +#if defined(__cplusplus) || defined(__APPLE__) +#include +#include +#include +#include +#include +#else +struct _IO_FILE; +extern struct _IO_FILE* stderr; +extern int fprintf(struct _IO_FILE *, const char*, ...); +extern int fflush(struct _IO_FILE *stream); +extern int printf(const char*, ...); +extern void abort(); +extern void free(void *); +extern void* malloc(unsigned long); +extern void *realloc( void *ptr, unsigned long new_size ); +extern void* memcpy( void* dest, const void* src, unsigned long count ); +extern void* memset( void* dest, int, unsigned long count ); +#endif extern #ifdef __cplusplus diff --git a/enzyme/test/MLIR/ActivityAnalysis/region.mlir b/enzyme/test/MLIR/ActivityAnalysis/region.mlir new file mode 100644 index 000000000000..4526e1325661 --- /dev/null +++ b/enzyme/test/MLIR/ActivityAnalysis/region.mlir @@ -0,0 +1,27 @@ +// RUN: %eopt --pass-pipeline="builtin.module(print-activity-analysis{dataflow=false annotate=true})" %s --split-input-file 2>&1 | FileCheck %s + +// A function that contains active and inactive region dataflow + +func.func @region(%x: f64) -> (f64, f64) { + %f0 = arith.constant 0.0 : f64 + %c0 = arith.constant 0 : index + %c10 = arith.constant 10 : index + %c1 = arith.constant 1 : index + %r0:2 = scf.for %arg12 = %c0 to %c10 step %c1 iter_args(%arg13 = %f0, %arg14 = %f0) -> (f64, f64) { + %m = arith.addf %arg13, %x : f64 + scf.yield %m, %arg14 : f64, f64 + } + return %r0#0, %r0#1 : f64, f64 +} + +// CHECK: func.func @region(%arg0: f64) -> (f64, f64) attributes {enzyme.arg_icv0 = false, enzyme.ici = false} { +// CHECK-NEXT: %cst = arith.constant {enzyme.ici = true, enzyme.res_icv0 = true} 0.000000e+00 : f64 +// CHECK-NEXT: %c0 = arith.constant {enzyme.ici = true, enzyme.res_icv0 = true} 0 : index +// CHECK-NEXT: %c10 = arith.constant {enzyme.ici = true, enzyme.res_icv0 = true} 10 : index +// CHECK-NEXT: %c1 = arith.constant {enzyme.ici = true, enzyme.res_icv0 = true} 1 : index +// CHECK-NEXT: %0:2 = scf.for %arg1 = %c0 to %c10 step %c1 iter_args(%arg2 = %cst, %arg3 = %cst) -> (f64, f64) { +// CHECK-NEXT: %1 = arith.addf %arg2, %arg0 {enzyme.ici = false, enzyme.res_icv0 = false} : f64 +// CHECK-NEXT: scf.yield {enzyme.ici = true} %1, %arg3 : f64, f64 +// CHECK-NEXT: } {enzyme.arg_icv0 = true, enzyme.arg_icv1 = false, enzyme.arg_icv2 = true, enzyme.ici = false, enzyme.res_icv0 = false, enzyme.res_icv1 = true} +// CHECK-NEXT: return {enzyme.ici = true} %0#0, %0#1 : f64, f64 +// CHECK-NEXT: } diff --git a/enzyme/test/MLIR/BUILD b/enzyme/test/MLIR/BUILD new file mode 100644 index 000000000000..0af1496b3a5c --- /dev/null +++ b/enzyme/test/MLIR/BUILD @@ -0,0 +1,25 @@ +# MLIR-specific tests for Enzyme. + +load("@llvm-project//llvm:lit_test.bzl", "lit_test") + +[ + lit_test( + name = "%s.test" % src, + srcs = [src], + data = [ + "//:enzymemlir-opt", + "//test:lit.cfg.py", + "//test:lit.site.cfg.py", + "@llvm-project//clang:builtin_headers_gen", + "@llvm-project//llvm:FileCheck", + "@llvm-project//llvm:count", + "@llvm-project//llvm:lli", + "@llvm-project//llvm:not", + ] + glob(["**/*.h"]), + ) + for src in glob( + [ + "**/*.mlir", + ], + ) +] diff --git a/enzyme/test/MLIR/ForwardMode/affine.mlir b/enzyme/test/MLIR/ForwardMode/affine.mlir new file mode 100644 index 000000000000..d0e587409713 --- /dev/null +++ b/enzyme/test/MLIR/ForwardMode/affine.mlir @@ -0,0 +1,102 @@ +// RUN: %eopt --enzyme %s | FileCheck %s + +module { + func.func @loop(%x : f64) -> f64 { + %cst = arith.constant 10.000000e+00 : f64 + %r = affine.for %arg1 = 0 to 10 step 1 iter_args(%arg2 = %cst) -> (f64) { + %n = arith.addf %arg2, %x : f64 + affine.yield %n : f64 + } + return %r : f64 + } + func.func @dloop(%x : f64, %dx : f64) -> f64 { + %r = enzyme.fwddiff @loop(%x, %dx) { activity=[#enzyme] } : (f64, f64) -> (f64) + return %r : f64 + } + // CHECK: @fwddiffeloop + // CHECK: (%[[arg0:.+]]: f64, %[[arg1:.+]]: f64) + // CHECK: %[[cst:.+]] = arith.constant 0.000000e+00 : f64 + // CHECK: %[[cst_0:.+]] = arith.constant 1.000000e+01 : f64 + // CHECK: %[[r0:.+]]:2 = affine.for %{{.*}} = 0 to 10 iter_args(%[[arg3:.+]] = %[[cst_0]], %[[arg4:.+]] = %[[cst]]) -> (f64, f64) { + // CHECK: %[[v1:.+]] = arith.addf %[[arg4]], %[[arg1]] : f64 + // CHECK: %[[v2:.+]] = arith.addf %[[arg3]], %[[arg0]] : f64 + // CHECK: affine.yield %[[v2]], %[[v1]] : f64, f64 + // CHECK: } + // CHECK: return %[[r0]]#1 : f64 + + func.func @if_then_else(%x : f64, %c : i1) -> f64 { + %c2 = arith.constant 2.000000e+00 : f64 + %c10 = arith.constant 10.000000e+00 : f64 + %r:2 = scf.if %c -> (f64, f64) { + %mul = arith.mulf %x, %x : f64 + scf.yield %mul, %c2 : f64, f64 + } else { + %add = arith.addf %x, %x : f64 + scf.yield %add, %c10 : f64, f64 + } + %res = arith.mulf %r#0, %r#1 : f64 + return %res : f64 + } + func.func @dif_then_else(%x : f64, %dx : f64, %c : i1) -> f64 { + %r = enzyme.fwddiff @if_then_else(%x, %dx, %c) { activity=[#enzyme, #enzyme] } : (f64, f64, i1) -> (f64) + return %r : f64 + } + // CHECK: @fwddiffeif_then_else + // CHECK: (%[[arg0:.+]]: f64, %[[arg1:.+]]: f64, %[[arg2:.+]]: i1) + // CHECK: %[[cst:.+]] = arith.constant 2.000000e+00 : f64 + // CHECK: %[[cst_0:.+]] = arith.constant 1.000000e+01 : f64 + // CHECK: %[[r0:.+]]:3 = scf.if %[[arg2]] -> (f64, f64, f64) { + // CHECK: %[[v3:.+]] = arith.mulf %[[arg1]], %[[arg0]] : f64 + // CHECK: %[[v4:.+]] = arith.mulf %[[arg1]], %[[arg0]] : f64 + // CHECK: %[[v5:.+]] = arith.addf %[[v3]], %[[v4]] : f64 + // CHECK: %[[v6:.+]] = arith.mulf %[[arg0]], %[[arg0]] : f64 + // CHECK: scf.yield %[[v6]], %[[v5]], %[[cst]] : f64, f64, f64 + // CHECK: } else { + // CHECK: %[[v3:.+]] = arith.addf %[[arg1]], %[[arg1]] : f64 + // CHECK: %[[v4:.+]] = arith.addf %[[arg0]], %[[arg0]] : f64 + // CHECK: scf.yield %[[v4]], %[[v3]], %[[cst_0]] : f64, f64, f64 + // CHECK: } + // CHECK: %[[v1:.+]] = arith.mulf %[[r0]]#1, %[[r0]]#2 : f64 + // CHECK: %[[v2:.+]] = arith.mulf %[[r0]]#0, %[[r0]]#2 : f64 + // CHECK: return %[[v1]] : f64 + + func.func @if_then(%x : f64, %c : i1) -> f64 { + %c2 = arith.constant 2.000000e+00 : f64 + %c10 = arith.constant 10.000000e+00 : f64 + %mem = memref.alloc() : memref<1xf64> + affine.store %c2, %mem[0] : memref<1xf64> + scf.if %c { + %mul = arith.mulf %x, %x : f64 + affine.store %mul, %mem[0] : memref<1xf64> + } + %r = affine.load %mem[0] : memref<1xf64> + %res = arith.mulf %c2, %r : f64 + return %res : f64 + } + func.func @dif_then(%x : f64, %dx : f64, %c : i1) -> f64 { + %r = enzyme.fwddiff @if_then(%x, %dx, %c) { activity=[#enzyme, #enzyme] } : (f64, f64, i1) -> (f64) + return %r : f64 + } + // CHECK: @fwddiffeif_then + // CHECK: (%[[arg0:.+]]: f64, %[[arg1:.+]]: f64, %[[arg2:.+]]: i1) -> f64 { + // CHECK-DAG: %[[cst2:.+]] = arith.constant 2.000000e+00 : f64 + // CHECK-DAG: %[[cst1:.+]] = arith.constant 1.000000e+01 : f64 + // CHECK: %[[alloc:.+]] = memref.alloc() : memref<1xf64> + // CHECK: %[[alloc_2:.+]] = memref.alloc() : memref<1xf64> + // CHECK-DAG: %[[cst0:.+]] = arith.constant 0.000000e+00 : f64 + // CHECK: affine.store %[[cst0]], %[[alloc]][0] : memref<1xf64> + // CHECK: affine.store %[[cst2]], %[[alloc_2]][0] : memref<1xf64> + // CHECK: scf.if %[[arg2]] { + // CHECK: %[[v4:.+]] = arith.mulf %[[arg1]], %[[arg0]] : f64 + // CHECK: %[[v5:.+]] = arith.mulf %[[arg1]], %[[arg0]] : f64 + // CHECK: %[[v6:.+]] = arith.addf %[[v4]], %[[v5]] : f64 + // CHECK: %[[v7:.+]] = arith.mulf %[[arg0]], %[[arg0]] : f64 + // CHECK: affine.store %[[v6]], %[[alloc]][0] : memref<1xf64> + // CHECK: affine.store %[[v7]], %[[alloc_2]][0] : memref<1xf64> + // CHECK: } + // CHECK: %[[v0:.+]] = affine.load %[[alloc]][0] : memref<1xf64> + // CHECK: %[[v1:.+]] = affine.load %[[alloc_2]][0] : memref<1xf64> + // CHECK: %[[v2:.+]] = arith.mulf %[[v0]], %[[cst2]] : f64 + // CHECK: %[[v3:.+]] = arith.mulf %[[cst2]], %[[v1]] : f64 + // CHECK: return %[[v2]] : f64 +} diff --git a/enzyme/test/MLIR/branch-self-recursive.mlir b/enzyme/test/MLIR/ForwardMode/branch-self-recursive.mlir similarity index 100% rename from enzyme/test/MLIR/branch-self-recursive.mlir rename to enzyme/test/MLIR/ForwardMode/branch-self-recursive.mlir diff --git a/enzyme/test/MLIR/branch.mlir b/enzyme/test/MLIR/ForwardMode/branch.mlir similarity index 100% rename from enzyme/test/MLIR/branch.mlir rename to enzyme/test/MLIR/ForwardMode/branch.mlir diff --git a/enzyme/test/MLIR/ForwardMode/executeop.mlir b/enzyme/test/MLIR/ForwardMode/executeop.mlir new file mode 100644 index 000000000000..696b07b5f7ea --- /dev/null +++ b/enzyme/test/MLIR/ForwardMode/executeop.mlir @@ -0,0 +1,60 @@ +// RUN: %eopt --enzyme %s | FileCheck %s + +module { + func.func @square(%x : f64, %c : i32) -> f64 { + %c2 = arith.constant 2.000000e+00 : f64 + %c10 = arith.constant 10.000000e+00 : f64 + %r:2 = scf.execute_region -> (f64, f64) { + cf.switch %c : i32, [ + default: ^bb5, + 104: ^bb3, + 113: ^bb4(%c10 : f64) + ] + ^bb4(%z : f64): // pred: ^bb2 + %x2 = arith.mulf %x, %x : f64 + scf.yield %x2, %z : f64, f64 + ^bb3: + %x3 = arith.addf %x, %x : f64 + scf.yield %x3, %c2 : f64, f64 + ^bb5: + cf.br ^bb4(%x : f64) + } + %res = arith.mulf %r#0, %r#1 : f64 + return %res : f64 + } + func.func @dsq(%x : f64, %dx : f64, %c : i32) -> f64 { + %r = enzyme.fwddiff @square(%x, %dx, %c) { activity=[#enzyme, #enzyme] } : (f64, f64, i32) -> (f64) + return %r : f64 + } +} + +// CHECK: func.func private @fwddiffesquare(%[[x:.+]]: f64, %[[dx:.+]]: f64, %[[c:.+]]: i32) -> f64 { +// CHECK-DAG: %[[cst2:.+]] = arith.constant 2.000000e+00 : f64 +// CHECK-DAG: %[[cst10:.+]] = arith.constant 1.000000e+01 : f64 +// CHECK-DAG: %[[cst0:.+]] = arith.constant 0.000000e+00 : f64 +// CHECK-NEXT: %[[r0:.+]]:4 = scf.execute_region -> (f64, f64, f64, f64) { +// CHECK-NEXT: %[[cst02:.+]] = arith.constant 0.000000e+00 : f64 +// CHECK-NEXT: cf.switch %[[c]] : i32, [ +// CHECK-NEXT: default: ^bb3, +// CHECK-NEXT: 104: ^bb2, +// CHECK-NEXT: 113: ^bb1(%[[cst10]], %[[cst02]] : f64, f64) +// CHECK-NEXT: ] +// CHECK-NEXT: ^bb1(%[[a3:.+]]: f64, %[[da3:.+]]: f64): // 2 preds: ^bb0, ^bb3 +// CHECK-NEXT: %[[a4:.+]] = arith.mulf %[[dx]], %[[x]] : f64 +// CHECK-NEXT: %[[a5:.+]] = arith.mulf %[[dx]], %[[x]] : f64 +// CHECK-NEXT: %[[a6:.+]] = arith.addf %[[a4]], %[[a5]] : f64 +// CHECK-NEXT: %[[a7:.+]] = arith.mulf %[[x]], %[[x]] : f64 +// CHECK-NEXT: scf.yield %[[a7]], %[[a6]], %[[a3]], %[[da3]] : f64, f64, f64, f64 +// CHECK-NEXT: ^bb2: // pred: ^bb0 +// CHECK-NEXT: %[[b8:.+]] = arith.addf %[[dx]], %[[dx]] : f64 +// CHECK-NEXT: %[[b9:.+]] = arith.addf %[[x]], %[[x]] : f64 +// CHECK-NEXT: scf.yield %[[b9]], %[[b8]], %[[cst2]], %[[cst0]] : f64, f64, f64, f64 +// CHECK-NEXT: ^bb3: // pred: ^bb0 +// CHECK-NEXT: cf.br ^bb1(%[[x]], %[[dx]] : f64, f64) +// CHECK-NEXT: } +// CHECK-NEXT: %[[r1:.+]] = arith.mulf %[[r0]]#1, %[[r0]]#2 : f64 +// CHECK-NEXT: %[[r2:.+]] = arith.mulf %[[r0]]#3, %[[r0]]#0 : f64 +// CHECK-NEXT: %[[r3:.+]] = arith.addf %[[r1]], %[[r2]] : f64 +// CHECK-NEXT: %[[r4:.+]] = arith.mulf %[[r0]]#0, %[[r0]]#2 : f64 +// CHECK-NEXT: return %[[r3]] : f64 +// CHECK-NEXT: } diff --git a/enzyme/test/MLIR/for.mlir b/enzyme/test/MLIR/ForwardMode/for.mlir similarity index 100% rename from enzyme/test/MLIR/for.mlir rename to enzyme/test/MLIR/ForwardMode/for.mlir diff --git a/enzyme/test/MLIR/ForwardMode/for2.mlir b/enzyme/test/MLIR/ForwardMode/for2.mlir new file mode 100644 index 000000000000..7d4e7608b98e --- /dev/null +++ b/enzyme/test/MLIR/ForwardMode/for2.mlir @@ -0,0 +1,30 @@ +// RUN: %eopt --enzyme %s | FileCheck %s + +module { + func.func @square(%x : f64) -> f64 { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c10 = arith.constant 10 : index + %r = scf.for %arg1 = %c0 to %c10 step %c1 iter_args(%arg2 = %x) -> (f64) { + %n = arith.addf %arg2, %x : f64 + scf.yield %n : f64 + } + return %r : f64 + } + func.func @dsq(%x : f64, %dx : f64) -> f64 { + %r = enzyme.fwddiff @square(%x, %dx) { activity=[#enzyme] } : (f64, f64) -> (f64) + return %r : f64 + } +} + +// CHECK: func.func private @fwddiffesquare(%[[arg0:.+]]: f64, %[[arg1:.+]]: f64) -> f64 { +// CHECK-DAG: %[[c0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[c1:.+]] = arith.constant 1 : index +// CHECK-DAG: %[[c10:.+]] = arith.constant 10 : index +// CHECK-NEXT: %[[i0:.+]]:2 = scf.for %[[arg2:.+]] = %[[c0]] to %[[c10]] step %[[c1]] iter_args(%[[arg3:.+]] = %[[arg0]], %[[arg4:.+]] = %[[arg1]]) -> (f64, f64) { +// CHECK-NEXT: %[[i1:.+]] = arith.addf %[[arg4]], %[[arg1]] : f64 +// CHECK-NEXT: %[[i2:.+]] = arith.addf %[[arg3]], %[[arg0]] : f64 +// CHECK-NEXT: scf.yield %[[i2]], %[[i1]] : f64, f64 +// CHECK-NEXT: } +// CHECK-NEXT: return %[[i0]]#1 : f64 +// CHECK-NEXT: } diff --git a/enzyme/test/MLIR/ForwardMode/if1.mlir b/enzyme/test/MLIR/ForwardMode/if1.mlir new file mode 100644 index 000000000000..3187bac56fa7 --- /dev/null +++ b/enzyme/test/MLIR/ForwardMode/if1.mlir @@ -0,0 +1,41 @@ +// RUN: %eopt --enzyme %s | FileCheck %s + +module { + func.func @square(%x : f64, %c : i1) -> f64 { + %c2 = arith.constant 2.000000e+00 : f64 + %c10 = arith.constant 10.000000e+00 : f64 + %r:2 = scf.if %c -> (f64, f64) { + %mul = arith.mulf %x, %x : f64 + scf.yield %mul, %c2 : f64, f64 + } else { + %add = arith.addf %x, %x : f64 + scf.yield %add, %c10 : f64, f64 + } + %res = arith.mulf %r#0, %r#1 : f64 + return %res : f64 + } + func.func @dsq(%x : f64, %dx : f64, %c : i1) -> f64 { + %r = enzyme.fwddiff @square(%x, %dx, %c) { activity=[#enzyme, #enzyme] } : (f64, f64, i1) -> (f64) + return %r : f64 + } +} + + +// CHECK: func.func private @fwddiffesquare(%[[arg0:.+]]: f64, %[[arg1:.+]]: f64, %[[arg2:.+]]: i1) -> f64 { +// CHECK-DAG: %[[cst2:.+]] = arith.constant 2.000000e+00 : f64 +// CHECK-DAG: %[[cst10:.+]] = arith.constant 1.000000e+01 : f64 +// CHECK-NEXT: %[[r0:.+]]:3 = scf.if %[[arg2]] -> (f64, f64, f64) { +// CHECK-NEXT: %[[t3:.+]] = arith.mulf %[[arg1]], %[[arg0]] : f64 +// CHECK-NEXT: %[[t4:.+]] = arith.mulf %[[arg1]], %[[arg0]] : f64 +// CHECK-NEXT: %[[t5:.+]] = arith.addf %[[t3]], %[[t4]] : f64 +// CHECK-NEXT: %[[t6:.+]] = arith.mulf %[[arg0]], %[[arg0]] : f64 +// CHECK-NEXT: scf.yield %[[t6]], %[[t5]], %[[cst2]] : f64, f64, f64 +// CHECK-NEXT: } else { +// CHECK-NEXT: %[[e3:.+]] = arith.addf %arg1, %arg1 : f64 +// CHECK-NEXT: %[[e4:.+]] = arith.addf %arg0, %arg0 : f64 +// CHECK-NEXT: scf.yield %[[e4]], %[[e3]], %[[cst10]] : f64, f64, f64 +// CHECK-NEXT: } +// CHECK-NEXT: %[[r1:.+]] = arith.mulf %[[r0]]#1, %[[r0]]#2 : f64 +// CHECK-NEXT: %[[r2:.+]] = arith.mulf %[[r0]]#0, %[[r0]]#2 : f64 +// CHECK-NEXT: return %[[r1]] : f64 +// CHECK-NEXT: } diff --git a/enzyme/test/MLIR/inactive.mlir b/enzyme/test/MLIR/ForwardMode/inactive.mlir similarity index 72% rename from enzyme/test/MLIR/inactive.mlir rename to enzyme/test/MLIR/ForwardMode/inactive.mlir index 6ccb8d7033ee..10b7fcd61b27 100644 --- a/enzyme/test/MLIR/inactive.mlir +++ b/enzyme/test/MLIR/ForwardMode/inactive.mlir @@ -1,11 +1,11 @@ -// RUN: %eopt --enzyme %s | FileCheck %s +// RUN: %eopt --enzyme %s -allow-unregistered-dialect | FileCheck %s module { func.func @inactive(%x : f64) -> f64 { - // We don't have an interface implementation for "func", + // We don't have an interface implementation for "foo", // but we can see it's inactive from its lack of operands // and results. - func.func private @foo() + "test.foo"() : () -> () return %x : f64 } func.func @diff(%x : f64, %dx : f64) -> f64 { @@ -17,4 +17,4 @@ module { // Just check that we didn't trigger the error on there not being an interface // implementation. // CHECK-LABEL: func private @fwddiffeinactive -// CHECK: func private @foo +// CHECK: "test.foo"() diff --git a/enzyme/test/MLIR/invalid.mlir b/enzyme/test/MLIR/ForwardMode/invalid.mlir similarity index 100% rename from enzyme/test/MLIR/invalid.mlir rename to enzyme/test/MLIR/ForwardMode/invalid.mlir diff --git a/enzyme/test/MLIR/llvm.mlir b/enzyme/test/MLIR/ForwardMode/llvm.mlir similarity index 100% rename from enzyme/test/MLIR/llvm.mlir rename to enzyme/test/MLIR/ForwardMode/llvm.mlir diff --git a/enzyme/test/MLIR/memref.mlir b/enzyme/test/MLIR/ForwardMode/memref.mlir similarity index 100% rename from enzyme/test/MLIR/memref.mlir rename to enzyme/test/MLIR/ForwardMode/memref.mlir diff --git a/enzyme/test/MLIR/ForwardMode/multiout.mlir b/enzyme/test/MLIR/ForwardMode/multiout.mlir new file mode 100644 index 000000000000..e42322a7f74f --- /dev/null +++ b/enzyme/test/MLIR/ForwardMode/multiout.mlir @@ -0,0 +1,24 @@ +// RUN: %eopt --enzyme %s | FileCheck %s + +module { + func.func @square(%x : f64) -> (f64, f64) { + %y = arith.mulf %x, %x : f64 + return %y, %x : f64, f64 + } + func.func @dsq(%x : f64, %dx : f64) -> (f64, f64) { + %r:2 = enzyme.fwddiff @square(%x, %dx) { activity=[#enzyme] } : (f64, f64) -> (f64, f64) + return %r#0, %r#1 : f64, f64 + } +} + +// CHECK: func.func @dsq(%[[arg0:.+]]: f64, %[[arg1:.+]]: f64) -> (f64, f64) { +// CHECK-NEXT: %[[i0:[0-9]+]]:2 = call @fwddiffesquare(%[[arg0]], %[[arg1]]) : (f64, f64) -> (f64, f64) +// CHECK-NEXT: return %[[i0]]#0, %[[i0]]#1 : f64, f64 +// CHECK-NEXT: } +// CHECK: func.func private @fwddiffesquare(%[[arg0:.+]]: f64, %[[arg1:.+]]: f64) -> (f64, f64) { +// CHECK-NEXT: %[[i0:.+]] = arith.mulf %[[arg1]], %[[arg0]] : f64 +// CHECK-NEXT: %[[i1:.+]] = arith.mulf %[[arg1]], %[[arg0]] : f64 +// CHECK-NEXT: %[[i2:.+]] = arith.addf %[[i0]], %[[i1]] : f64 +// CHECK-NEXT: %[[i3:.+]] = arith.mulf %[[arg0]], %[[arg0]] : f64 +// CHECK-NEXT: return %[[i2]], %[[arg1]] : f64, f64 +// CHECK-NEXT: } diff --git a/enzyme/test/MLIR/ForwardMode/tensorsin.mlir b/enzyme/test/MLIR/ForwardMode/tensorsin.mlir new file mode 100644 index 000000000000..5ab208712b73 --- /dev/null +++ b/enzyme/test/MLIR/ForwardMode/tensorsin.mlir @@ -0,0 +1,19 @@ +// RUN: %eopt --enzyme %s | FileCheck %s + +module { + func.func @square(%x : tensor<2xf64>) -> tensor<2xf64> { + %y = math.sin %x : tensor<2xf64> + return %y : tensor<2xf64> + } + func.func @dsq(%x : tensor<2xf64>, %dx : tensor<2xf64>) -> tensor<2xf64> { + %r = enzyme.fwddiff @square(%x, %dx) { activity=[#enzyme] } : (tensor<2xf64>, tensor<2xf64>) -> (tensor<2xf64>) + return %r : tensor<2xf64> + } +} + +// CHECK: func.func private @fwddiffesquare(%arg0: tensor<2xf64>, %arg1: tensor<2xf64>) -> tensor<2xf64> { +// CHECK-NEXT: %[[a0:.+]] = math.cos %arg0 : tensor<2xf64> +// CHECK-NEXT: %[[a1:.+]] = arith.mulf %arg1, %[[a0]] : tensor<2xf64> +// CHECK-NEXT: %[[a2:.+]] = math.sin %arg0 : tensor<2xf64> +// CHECK-NEXT: return %[[a1]] : tensor<2xf64> +// CHECK-NEXT: } diff --git a/enzyme/test/MLIR/test.mlir b/enzyme/test/MLIR/ForwardMode/test.mlir similarity index 100% rename from enzyme/test/MLIR/test.mlir rename to enzyme/test/MLIR/ForwardMode/test.mlir diff --git a/enzyme/test/MLIR/ForwardMode/trunc.mlir b/enzyme/test/MLIR/ForwardMode/trunc.mlir new file mode 100644 index 000000000000..8f3918add4ac --- /dev/null +++ b/enzyme/test/MLIR/ForwardMode/trunc.mlir @@ -0,0 +1,18 @@ +// RUN: %eopt --enzyme %s | FileCheck %s + +module { + func.func @f(%x : f64) -> f32 { + %y = arith.truncf %x : f64 to f32 + return %y : f32 + } + func.func @dsq(%x : f64, %dx : f64) -> f32 { + %r = enzyme.fwddiff @f(%x, %dx) { activity=[#enzyme] } : (f64, f64) -> (f32) + return %r : f32 + } +} + +// CHECK: func.func private @fwddiffef(%[[arg0:.+]]: f64, %[[arg1:.+]]: f64) -> f32 { +// CHECK-NEXT: %[[dy:.+]] = arith.truncf %[[arg1]] : f64 to f32 +// CHECK-NEXT: %[[y:.+]] = arith.truncf %[[arg0]] : f64 to f32 +// CHECK-NEXT: return %[[dy]] : f32 +// CHECK-NEXT: } diff --git a/enzyme/test/MLIR/ForwardMode/while.mlir b/enzyme/test/MLIR/ForwardMode/while.mlir new file mode 100644 index 000000000000..cbe7da34769b --- /dev/null +++ b/enzyme/test/MLIR/ForwardMode/while.mlir @@ -0,0 +1,44 @@ +// RUN: %eopt --enzyme %s | FileCheck %s + +module { + func.func @while(%x : f64) -> f64 { + %cst = arith.constant 10.000000e+00 : f64 + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c10 = arith.constant 10 : index + + %r:2 = scf.while (%arg1 = %c0, %arg2 = %cst) : (index, f64) -> (index, f64) { + %1 = arith.cmpi slt, %arg1, %c10 : index + scf.condition(%1) %arg1, %arg2 : index, f64 + } do { + ^bb0(%arg1: index, %arg2: f64): + %1 = arith.addi %arg1, %c1 : index + %2 = arith.addf %arg2, %x : f64 + scf.yield %1, %2 : index, f64 + } + return %r#1 : f64 + } + func.func @dwhile(%x : f64, %dx : f64) -> f64 { + %r = enzyme.fwddiff @while(%x, %dx) { activity=[#enzyme] } : (f64, f64) -> (f64) + return %r : f64 + } + // CHECK: @fwddiffewhile + // CHECK: (%[[arg0:.+]]: f64, %[[arg1:.+]]: f64) -> f64 { + // CHECK: %[[cst:.+]] = arith.constant 0.000000e+00 : f64 + // CHECK: %[[cst_0:.+]] = arith.constant 1.000000e+01 : f64 + // CHECK: %[[c0:.+]] = arith.constant 0 : index + // CHECK: %[[c1:.+]] = arith.constant 1 : index + // CHECK: %[[c10:.+]] = arith.constant 10 : index + // CHECK: %[[r0:.+]]:3 = scf.while (%[[arg2:.+]] = %[[c0]], %[[arg3:.+]] = %[[cst_0]], %[[arg4:.+]] = %[[cst]]) : (index, f64, f64) -> (index, f64, f64) { + // CHECK: %[[v1:.+]] = arith.cmpi slt, %[[arg2]], %[[c10]] : index + // CHECK: scf.condition(%[[v1]]) %[[arg2]], %[[arg3]], %[[arg4]] : index, f64, f64 + // CHECK: } do { + // CHECK: ^bb0(%[[arg2:.+]]: index, %[[arg3:.+]]: f64, %[[arg4:.+]]: f64): + // CHECK: %[[v1:.+]] = arith.addi %[[arg2]], %[[c1]] : index + // CHECK: %[[v2:.+]] = arith.addf %[[arg4]], %[[arg1]] : f64 + // CHECK: %[[v3:.+]] = arith.addf %[[arg3]], %[[arg0]] : f64 + // CHECK: scf.yield %[[v1]], %[[v3]], %[[v2]] : index, f64, f64 + // CHECK: } + // CHECK: return %[[r0]]#2 : f64 + // CHECK: } +} diff --git a/enzyme/test/MLIR/ForwardMode/wrap.mlir b/enzyme/test/MLIR/ForwardMode/wrap.mlir new file mode 100644 index 000000000000..5ff0e1540d13 --- /dev/null +++ b/enzyme/test/MLIR/ForwardMode/wrap.mlir @@ -0,0 +1,16 @@ +// RUN: %eopt --enzyme-wrap="infn=square outfn=dsq retTy=enzyme_dup argTys=enzyme_dup mode=ForwardMode" %s | FileCheck %s + +module { + func.func @square(%x : f64) -> f64{ + %y = arith.mulf %x, %x : f64 + return %y : f64 + } +} + +// CHECK: func.func private @dsq(%[[arg0:.+]]: f64, %[[arg1:.+]]: f64) -> (f64, f64) { +// CHECK-NEXT: %[[i0:.+]] = arith.mulf %[[arg1]], %[[arg0]] : f64 +// CHECK-NEXT: %[[i1:.+]] = arith.mulf %[[arg1]], %[[arg0]] : f64 +// CHECK-NEXT: %[[i2:.+]] = arith.addf %[[i0]], %[[i1]] : f64 +// CHECK-NEXT: %[[i3:.+]] = arith.mulf %[[arg0]], %[[arg0]] : f64 +// CHECK-NEXT: return %[[i3]], %[[i2]] : f64, f64 +// CHECK-NEXT: } diff --git a/enzyme/test/MLIR/Passes/dualpush.mlir b/enzyme/test/MLIR/Passes/dualpush.mlir new file mode 100644 index 000000000000..582feddeabff --- /dev/null +++ b/enzyme/test/MLIR/Passes/dualpush.mlir @@ -0,0 +1,48 @@ +// RUN: %eopt -remove-unnecessary-enzyme-ops %s | FileCheck %s + +// This pop cannot be removed even though we know the first popped value with be -1 +// the other pops will be conditional + +module { + func.func private @diffebbargs(%arg0: f64) { + %c0_i32 = arith.constant 0 : i32 + %c-1_i32 = arith.constant -1 : i32 + %cst = arith.constant 0.000000e+00 : f64 + %3 = "enzyme.init"() : () -> !enzyme.Cache + "enzyme.push"(%3, %c0_i32) : (!enzyme.Cache, i32) -> () + cf.br ^bb1(%arg0 : f64) + ^bb1(%7: f64): // 2 preds: ^bb0, ^bb1 + %8 = arith.cmpf ult, %7, %cst : f64 + "enzyme.push"(%3, %c-1_i32) : (!enzyme.Cache, i32) -> () + cf.cond_br %8, ^bb1(%7 : f64), ^bb4 + ^bb4: // 2 preds: ^bb3, ^bb4 + %18 = "enzyme.pop"(%3) : (!enzyme.Cache) -> i32 + cf.switch %18 : i32, [ + default: ^bb4, + 0: ^bb5 + ] + ^bb5: // pred: ^bb4 + return + } +} + +// CHECK: func.func private @diffebbargs(%arg0: f64) { +// CHECK-NEXT: %c0_i32 = arith.constant 0 : i32 +// CHECK-NEXT: %c-1_i32 = arith.constant -1 : i32 +// CHECK-NEXT: %cst = arith.constant 0.000000e+00 : f64 +// CHECK-NEXT: %0 = "enzyme.init"() : () -> !enzyme.Cache +// CHECK-NEXT: "enzyme.push"(%0, %c0_i32) : (!enzyme.Cache, i32) -> () +// CHECK-NEXT: cf.br ^bb1(%arg0 : f64) +// CHECK-NEXT: ^bb1(%1: f64): // 2 preds: ^bb0, ^bb1 +// CHECK-NEXT: %2 = arith.cmpf ult, %1, %cst : f64 +// CHECK-NEXT: "enzyme.push"(%0, %c-1_i32) : (!enzyme.Cache, i32) -> () +// CHECK-NEXT: cf.cond_br %2, ^bb1(%1 : f64), ^bb2 +// CHECK-NEXT: ^bb2: // 2 preds: ^bb1, ^bb2 +// CHECK-NEXT: %3 = "enzyme.pop"(%0) : (!enzyme.Cache) -> i32 +// CHECK-NEXT: cf.switch %3 : i32, [ +// CHECK-NEXT: default: ^bb2, +// CHECK-NEXT: 0: ^bb3 +// CHECK-NEXT: ] +// CHECK-NEXT: ^bb3: // pred: ^bb2 +// CHECK-NEXT: return +// CHECK-NEXT: } \ No newline at end of file diff --git a/enzyme/test/MLIR/ReverseMode/bbarg-order.mlir b/enzyme/test/MLIR/ReverseMode/bbarg-order.mlir index e5bb39eea040..141ff46aaade 100644 --- a/enzyme/test/MLIR/ReverseMode/bbarg-order.mlir +++ b/enzyme/test/MLIR/ReverseMode/bbarg-order.mlir @@ -1,4 +1,4 @@ -// RUN: %eopt --enzyme %s | FileCheck %s +// RUN: %eopt --enzyme -canonicalize --remove-unnecessary-enzyme-ops -canonicalize %s | FileCheck %s module { func.func @bbargs(%x: f64) -> f64 { @@ -19,15 +19,38 @@ module { } } -// CHECK: func.func @diff(%[[arg0:.+]]: f64, %[[arg1:.+]]: f64) -> f64 { -// CHECK-NEXT: %[[i0:.+]] = call @diffebbargs(%[[arg0]], %[[arg1]]) : (f64, f64) -> f64 -// CHECK-NEXT: return %[[i0:.+]] -// CHECK: func.func private @diffebbargs(%[[arg0:.+]]: f64, %[[arg1:.+]]: f64) -> f64 { - -// There should be exactly one block with two f64 args, and their values should be accumulated -// in the shadow. -// CHECK: ^[[BBMULTI:.+]](%[[fst:.+]]: f64, %[[snd:.+]]: f64): -// CHECK-NEXT: "enzyme.set"(%[[shadow:.+]], %[[fst]]) -// CHECK-NEXT: %[[before:.+]] = "enzyme.get"(%[[shadow]]) -// CHECK-NEXT: %[[after:.+]] = arith.addf %[[snd]], %[[before]] -// CHECK-NEXT: "enzyme.set"(%[[shadow]], %[[after]]) +// CHECK: func.func private @diffebbargs(%arg0: f64, %arg1: f64) -> f64 { +// CHECK-NEXT: %c-1_i32 = arith.constant -1 : i32 +// CHECK-NEXT: %c0_i32 = arith.constant 0 : i32 +// CHECK-NEXT: %cst = arith.constant 1.000000e+00 : f64 +// CHECK-NEXT: %cst_0 = arith.constant 0.000000e+00 : f64 +// CHECK-NEXT: %0 = "enzyme.init"() : () -> !enzyme.Cache +// CHECK-NEXT: %1 = "enzyme.init"() : () -> !enzyme.Cache +// CHECK-NEXT: %2 = arith.addf %arg0, %cst : f64 +// CHECK-NEXT: "enzyme.push"(%1, %c0_i32) : (!enzyme.Cache, i32) -> () +// CHECK-NEXT: cf.br ^bb1(%2 : f64) +// CHECK-NEXT: ^bb1(%3: f64): // 2 preds: ^bb0, ^bb1 +// CHECK-NEXT: %4 = arith.cmpf ult, %3, %cst_0 : f64 +// CHECK-NEXT: "enzyme.push"(%1, %c-1_i32) : (!enzyme.Cache, i32) -> () +// CHECK-NEXT: "enzyme.push"(%0, %c-1_i32) : (!enzyme.Cache, i32) -> () +// CHECK-NEXT: cf.cond_br %4, ^bb1(%3 : f64), ^bb2 +// CHECK-NEXT: ^bb2: // pred: ^bb1 +// CHECK-NEXT: %5 = arith.addf %arg1, %cst_0 : f64 +// CHECK-NEXT: %6 = "enzyme.pop"(%0) : (!enzyme.Cache) -> i32 +// CHECK-NEXT: %7 = arith.cmpi eq, %6, %c-1_i32 : i32 +// CHECK-NEXT: %8 = arith.select %7, %5, %cst_0 : f64 +// CHECK-NEXT: %9 = arith.addf %8, %cst_0 : f64 +// CHECK-NEXT: cf.br ^bb3 +// CHECK-NEXT: ^bb3: // 2 preds: ^bb2, ^bb3 +// CHECK-NEXT: %10 = "enzyme.pop"(%1) : (!enzyme.Cache) -> i32 +// CHECK-NEXT: %11 = arith.cmpi eq, %10, %c-1_i32 : i32 +// CHECK-NEXT: %12 = arith.select %11, %9, %cst_0 : f64 +// CHECK-NEXT: %13 = arith.addf %12, %cst_0 : f64 +// CHECK-NEXT: cf.switch %10 : i32, [ +// CHECK-NEXT: default: ^bb3, +// CHECK-NEXT: 0: ^bb4 +// CHECK-NEXT: ] +// CHECK-NEXT: ^bb4: // pred: ^bb3 +// CHECK-NEXT: %14 = arith.addf %13, %cst_0 : f64 +// CHECK-NEXT: return %14 : f64 +// CHECK-NEXT: } diff --git a/enzyme/test/MLIR/ReverseMode/pow.mlir b/enzyme/test/MLIR/ReverseMode/pow.mlir index 5c5596ec389e..9934152def61 100644 --- a/enzyme/test/MLIR/ReverseMode/pow.mlir +++ b/enzyme/test/MLIR/ReverseMode/pow.mlir @@ -1,4 +1,4 @@ -// RUN: %eopt --enzyme %s | FileCheck %s +// RUN: %eopt --enzyme -canonicalize --remove-unnecessary-enzyme-ops -enzyme-simplify-math -canonicalize %s | FileCheck %s module { func.func @ppow(%x: f64) -> f64 { @@ -19,29 +19,46 @@ module { } } -// CHECK: func.func private @diffeppow(%[[x:.+]]: f64, %[[dr:.+]]: f64) -> f64 +// CHECK: func.func private @diffeppow(%[[x:.+]]: f64, %[[dr:.+]]: f64) -> f64 { +// CHECK-NEXT: %c10 = arith.constant 10 : index +// CHECK-NEXT: %c1 = arith.constant 1 : index +// CHECK-NEXT: %c0 = arith.constant 0 : index +// CHECK-NEXT: %[[one:.+]] = arith.constant 1.0 +// CHECK-NEXT: %[[zero:.+]] = arith.constant 0.000000e+00 : f64 +// CHECK-NEXT: %[[xshadow:.+]] = "enzyme.init"() : () -> !enzyme.Gradient +// CHECK-NEXT: "enzyme.set"(%[[xshadow]], %[[zero]]) : (!enzyme.Gradient, f64) -> () +// CHECK-NEXT: %[[itshadow:.+]] = "enzyme.init"() : () -> !enzyme.Gradient +// CHECK-NEXT: "enzyme.set"(%[[itshadow]], %[[zero]]) : (!enzyme.Gradient, f64) -> () +// CHECK-NEXT: %[[xcache:.+]] = "enzyme.init"() : () -> !enzyme.Cache +// CHECK-NEXT: %[[rcache:.+]] = "enzyme.init"() : () -> !enzyme.Cache +// CHECK-NEXT: %[[rshadow:.+]] = "enzyme.init"() : () -> !enzyme.Gradient +// CHECK-NEXT: "enzyme.set"(%[[rshadow]], %[[zero]]) : (!enzyme.Gradient, f64) -> () -// Make sure the right values are being cached in the primal -// CHECK: %[[one:.+]] = arith.constant 1.0 -// CHECK: scf.for %[[iv:.+]] = %c0 to %c10 step %c1 iter_args(%[[r_it:.+]] = %[[one]]) -// CHECK-NEXT: "enzyme.push"(%[[rcache:.+]], %[[r_it]]) -// CHECK-NEXT: "enzyme.push"(%[[xcache:.+]], %[[x]]) - -// Ensure the right value is yielded in the adjoint -// CHECK: "enzyme.set"(%[[rshadow:.+]], %[[dr]]) -// CHECK: %[[dr:.+]] = "enzyme.get"(%[[rshadow]]) -// CHECK: scf.for %[[iv:.+]] = %[[lb:.+]] to %[[ub:.+]] step %[[step:.+]] iter_args(%[[dr_it:.+]] = %[[dr]]) -// CHECK-NEXT: "enzyme.set"(%[[rshadow:.+]], %[[dr_it]]) -// CHECK-NEXT: %[[dr_it:.+]] = "enzyme.get"(%[[rshadow]]) -// CHECK-NEXT: %[[r_cached:.+]] = "enzyme.pop"(%[[rcache]]) -// CHECK-NEXT: %[[x:.+]] = "enzyme.pop"(%[[xcache]]) -// CHECK-NEXT: %[[dr_next:.+]] = arith.mulf %[[dr_it]], %[[x]] -// CHECK-NEXT: "enzyme.set"(%[[rshadow:.+]], %[[dr_next]]) -// CHECK-NEXT: %[[dx_next:.+]] = arith.mulf %[[dr_it]], %[[r_cached]] -// CHECK-NEXT: %[[dx0:.+]] = "enzyme.get"(%[[xshadow:.+]]) : -// CHECK-NEXT: %[[dx1:.+]] = arith.addf %[[dx0]], %[[dx_next]] -// CHECK-NEXT: "enzyme.set"(%[[xshadow]], %[[dx1]]) -// CHECK-NEXT: %[[dr_next:.+]] = "enzyme.get"(%[[rshadow]]) -// CHECK-NEXT: scf.yield %[[dr_next]] -// CHECK: %[[final:.+]] = "enzyme.get"(%[[xshadow]]) -// CHECK-NEXT: return %[[final]] +// CHECK-NEXT: %{{.+}} = scf.for %[[iv:.+]] = %c0 to %c10 step %c1 iter_args(%[[r_it:.+]] = %[[one]]) -> (f64) { +// CHECK-NEXT: "enzyme.push"(%[[rcache]], %[[r_it]]) : (!enzyme.Cache, f64) -> () +// CHECK-NEXT: "enzyme.push"(%[[xcache]], %[[x]]) : (!enzyme.Cache, f64) -> () +// CHECK-NEXT: %[[fwd:.+]] = arith.mulf %[[r_it]], %[[x]] : f64 +// CHECK-NEXT: scf.yield %[[fwd]] : f64 +// CHECK-NEXT: } +// CHECK-NEXT: "enzyme.set"(%[[rshadow]], %[[dr]]) : (!enzyme.Gradient, f64) -> () +// CHECK-NEXT: scf.for %[[div:.+]] = %c0 to %c10 step %c1 { +// CHECK-NEXT: %[[dr_it:.+]] = "enzyme.get"(%[[rshadow]]) : (!enzyme.Gradient) -> f64 +// CHECK-NEXT: "enzyme.set"(%[[rshadow]], %[[zero]]) : (!enzyme.Gradient, f64) -> () +// CHECK-NEXT: %[[r_cached:.+]] = "enzyme.pop"(%[[rcache]]) : (!enzyme.Cache) -> f64 +// CHECK-NEXT: %[[x_cached:.+]] = "enzyme.pop"(%[[xcache]]) : (!enzyme.Cache) -> f64 +// CHECK-NEXT: %[[dr_next:.+]] = arith.mulf %[[dr_it]], %[[x_cached]] +// CHECK-NEXT: %[[previts:.+]] = "enzyme.get"(%[[itshadow]]) : (!enzyme.Gradient) -> f64 +// CHECK-NEXT: %[[postits:.+]] = arith.addf %[[previts]], %[[dr_next]] : f64 +// CHECK-NEXT: "enzyme.set"(%[[itshadow]], %[[postits]]) : (!enzyme.Gradient, f64) -> () +// CHECK-NEXT: %[[dx_next:.+]] = arith.mulf %[[dr_it]], %[[r_cached]] : f64 +// CHECK-NEXT: %[[dx0:.+]] = "enzyme.get"(%[[xshadow]]) : +// CHECK-NEXT: %[[dx1:.+]] = arith.addf %[[dx0]], %[[dx_next]] +// CHECK-NEXT: "enzyme.set"(%[[xshadow]], %[[dx1]]) : (!enzyme.Gradient, f64) -> () +// CHECK-NEXT: %[[divp1:.+]] = arith.addi %[[div]], %c1 : index +// CHECK-NEXT: %[[last:.+]] = arith.cmpi sge, %[[divp1]], %c10 : index +// CHECK-NEXT: "enzyme.set"(%[[itshadow]], %[[zero]]) : (!enzyme.Gradient, f64) -> () +// CHECK-NEXT: %[[sel:.+]] = arith.select %[[last]], %[[zero]], %12 : f64 +// CHECK-NEXT: "enzyme.set"(%[[itshadow]], %[[sel]]) : (!enzyme.Gradient, f64) -> () +// CHECK-NEXT: } +// CHECK-NEXT: %[[final:.+]] = "enzyme.get"(%[[xshadow]]) +// CHECK-NEXT: return %[[final]] \ No newline at end of file diff --git a/enzyme/test/MLIR/ReverseMode/square.mlir b/enzyme/test/MLIR/ReverseMode/square.mlir new file mode 100644 index 000000000000..4bcae3bb8000 --- /dev/null +++ b/enzyme/test/MLIR/ReverseMode/square.mlir @@ -0,0 +1,75 @@ +// RUN: %eopt --enzyme %s | FileCheck %s +// RUN: %eopt --enzyme --canonicalize --remove-unnecessary-enzyme-ops %s | FileCheck %s --check-prefix=REM +// RUN: %eopt --enzyme --canonicalize --remove-unnecessary-enzyme-ops --canonicalize --enzyme-simplify-math --cse %s | FileCheck %s --check-prefix=FIN + +module { + func.func @square(%x: f64) -> f64 { + %next = arith.mulf %x, %x : f64 + return %next : f64 + } + + func.func @dsquare(%x: f64, %dr: f64) -> f64 { + %r = enzyme.autodiff @square(%x, %dr) { activity=[#enzyme] } : (f64, f64) -> f64 + return %r : f64 + } +} + + +// CHECK: func.func @dsquare(%arg0: f64, %arg1: f64) -> f64 { +// CHECK-NEXT: %0 = call @diffesquare(%arg0, %arg1) : (f64, f64) -> f64 +// CHECK-NEXT: return %0 : f64 +// CHECK-NEXT: } + +// CHECK: func.func private @diffesquare(%arg0: f64, %arg1: f64) -> f64 { +// CHECK-NEXT: %[[dx:.+]] = "enzyme.init"() : () -> !enzyme.Gradient +// CHECK-NEXT: %[[c0:.+]] = arith.constant 0.000000e+00 : f64 +// CHECK-NEXT: "enzyme.set"(%[[dx]], %[[c0]]) : (!enzyme.Gradient, f64) -> () + +// CHECK-NEXT: %[[lhscache:.+]] = "enzyme.init"() : () -> !enzyme.Cache +// CHECK-NEXT: %[[rhscache:.+]] = "enzyme.init"() : () -> !enzyme.Cache + +// CHECK-NEXT: %[[dy:.+]] = "enzyme.init"() : () -> !enzyme.Gradient +// CHECK-NEXT: %[[c1:.+]] = arith.constant 0.000000e+00 : f64 +// CHECK-NEXT: "enzyme.set"(%[[dy]], %[[c1]]) : (!enzyme.Gradient, f64) -> () +// CHECK-NEXT: "enzyme.push"(%[[rhscache]], %arg0) : (!enzyme.Cache, f64) -> () +// CHECK-NEXT: "enzyme.push"(%[[lhscache]], %arg0) : (!enzyme.Cache, f64) -> () +// CHECK-NEXT: %[[mul:.+]] = arith.mulf %arg0, %arg0 : f64 +// CHECK-NEXT: cf.br ^bb1 + +// CHECK: ^bb1: // pred: ^bb0 +// CHECK-NEXT: %[[prevdret0:.+]] = "enzyme.get"(%[[dy]]) : (!enzyme.Gradient) -> f64 +// CHECK-NEXT: %[[postdret0:.+]] = arith.addf %[[prevdret0]], %arg1 : f64 +// CHECK-NEXT: "enzyme.set"(%[[dy]], %[[postdret0]]) : (!enzyme.Gradient, f64) -> () +// CHECK-NEXT: %[[prevdret:.+]] = "enzyme.get"(%[[dy]]) : (!enzyme.Gradient) -> f64 +// CHECK-NEXT: %[[c2:.+]] = arith.constant 0.000000e+00 : f64 +// CHECK-NEXT: "enzyme.set"(%[[dy]], %[[c2]]) : (!enzyme.Gradient, f64) -> () +// CHECK-NEXT: %[[postlhs:.+]] = "enzyme.pop"(%[[rhscache]]) : (!enzyme.Cache) -> f64 +// CHECK-NEXT: %[[postrhs:.+]] = "enzyme.pop"(%[[lhscache]]) : (!enzyme.Cache) -> f64 +// CHECK-NEXT: %[[dlhs:.+]] = arith.mulf %[[prevdret]], %[[postrhs]] : f64 +// CHECK-NEXT: %[[prevdx1:.+]] = "enzyme.get"(%[[dx]]) : (!enzyme.Gradient) -> f64 +// CHECK-NEXT: %[[postdx1:.+]] = arith.addf %[[prevdx1]], %[[dlhs]] : f64 +// CHECK-NEXT: "enzyme.set"(%[[dx]], %[[postdx1]]) : (!enzyme.Gradient, f64) -> () +// CHECK-NEXT: %[[drhs:.+]] = arith.mulf %[[prevdret]], %[[postlhs]] : f64 +// CHECK-NEXT: %[[prevdx2:.+]] = "enzyme.get"(%[[dx]]) : (!enzyme.Gradient) -> f64 +// CHECK-NEXT: %[[postdx2:.+]] = arith.addf %[[prevdx2]], %[[drhs]] : f64 +// CHECK-NEXT: "enzyme.set"(%[[dx]], %[[postdx2]]) : (!enzyme.Gradient, f64) -> () +// CHECK-NEXT: %[[res:.+]] = "enzyme.get"(%[[dx]]) : (!enzyme.Gradient) -> f64 +// CHECK-NEXT: return %[[res]] : f64 +// CHECK-NEXT: } + + +// REM: func.func private @diffesquare(%arg0: f64, %arg1: f64) -> f64 { +// REM-NEXT: %[[cst:.+]] = arith.constant 0.000000e+00 : f64 +// REM-NEXT: %[[a1:.+]] = arith.addf %arg1, %[[cst]] : f64 +// REM-NEXT: %[[a2:.+]] = arith.mulf %[[a1]], %arg0 : f64 +// REM-NEXT: %[[a3:.+]] = arith.addf %[[a2]], %[[cst]] : f64 +// REM-NEXT: %[[a4:.+]] = arith.mulf %[[a1]], %arg0 : f64 +// REM-NEXT: %[[a5:.+]] = arith.addf %[[a3]], %[[a4]] : f64 +// REM-NEXT: return %[[a5]] : f64 +// REM-NEXT: } + +// FIN: func.func private @diffesquare(%arg0: f64, %arg1: f64) -> f64 { +// FIN-NEXT: %0 = arith.mulf %arg1, %arg0 : f64 +// FIN-NEXT: %1 = arith.addf %0, %0 : f64 +// FIN-NEXT: return %1 : f64 +// FIN-NEXT: } \ No newline at end of file diff --git a/enzyme/test/MLIR/ReverseMode/trunc.mlir b/enzyme/test/MLIR/ReverseMode/trunc.mlir new file mode 100644 index 000000000000..c078cb0634ae --- /dev/null +++ b/enzyme/test/MLIR/ReverseMode/trunc.mlir @@ -0,0 +1,18 @@ +// RUN: %eopt --enzyme --canonicalize --remove-unnecessary-enzyme-ops --canonicalize --enzyme-simplify-math --cse %s | FileCheck %s --check-prefix=FIN + +module { + func.func @f(%x: f64) -> f32 { + %next = arith.truncf %x : f64 to f32 + return %next : f32 + } + + func.func @dsquare(%x: f64, %dr: f32) -> f64 { + %r = enzyme.autodiff @f(%x, %dr) { activity=[#enzyme] } : (f64, f32) -> f64 + return %r : f64 + } +} + +// FIN: func.func private @diffef(%[[x:.+]]: f64, %[[dx:.+]]: f32) -> f64 { +// FIN-NEXT: %[[res:.+]] = arith.extf %[[dx]] : f32 to f64 +// FIN-NEXT: return %[[res]] : f64 +// FIN-NEXT: } diff --git a/enzyme/test/TypeAnalysis/smax.ll b/enzyme/test/TypeAnalysis/smax.ll new file mode 100644 index 000000000000..b050ccc9ef18 --- /dev/null +++ b/enzyme/test/TypeAnalysis/smax.ll @@ -0,0 +1,26 @@ +; RUN: if [ %llvmver -lt 16 ] && [ %llvmver -gt 11 ]; then %opt < %s %loadEnzyme -print-type-analysis -type-analysis-func=smax -o /dev/null | FileCheck %s; fi +; RUN: if [ %llvmver -gt 11 ]; then %opt < %s %newLoadEnzyme -passes="print-type-analysis" -type-analysis-func=smax -S -o /dev/null | FileCheck %s; fi + +define i32 @smax(i32 %a, i32 %b) { +entry: + %0 = call i32 @llvm.smax.i32(i32 %a, i32 %b) + %1 = call i32 @getint() + %2 = call i32 @getint() + %3 = call i32 @llvm.smax.i32(i32 %1, i32 %2) + ret i32 %3 +} + +declare i32 @llvm.smax.i32(i32, i32) + +declare i32 @getint() + + +; CHECK: smax - {[-1]:Integer} |{[-1]:Integer}:{} {[-1]:Integer}:{} +; CHECK-NEXT: i32 %a: {[-1]:Integer} +; CHECK-NEXT: i32 %b: {[-1]:Integer} +; CHECK-NEXT: entry +; CHECK-NEXT: %0 = call i32 @llvm.smax.i32(i32 %a, i32 %b): {[-1]:Integer} +; CHECK-NEXT: %1 = call i32 @getint(): {[-1]:Integer} +; CHECK-NEXT: %2 = call i32 @getint(): {[-1]:Integer} +; CHECK-NEXT: %3 = call i32 @llvm.smax.i32(i32 %1, i32 %2): {[-1]:Integer} +; CHECK-NEXT: ret i32 %3: {} diff --git a/enzyme/test/TypeAnalysis/smax0.ll b/enzyme/test/TypeAnalysis/smax0.ll new file mode 100644 index 000000000000..de79bcde70bb --- /dev/null +++ b/enzyme/test/TypeAnalysis/smax0.ll @@ -0,0 +1,23 @@ +; RUN: if [ %llvmver -lt 16 ] && [ %llvmver -gt 11 ]; then %opt < %s %loadEnzyme -print-type-analysis -type-analysis-func=smax0 -o /dev/null | FileCheck %s; fi +; RUN: if [ %llvmver -gt 11 ]; then %opt < %s %newLoadEnzyme -passes="print-type-analysis" -type-analysis-func=smax0 -S -o /dev/null | FileCheck %s; fi + +define i32 @smax0(i32 %a, i32 %b) { +entry: + %0 = call i32 @llvm.smax.i32(i32 %a, i32 0) + %1 = call i32 @getint() + %2 = call i32 @llvm.smax.i32(i32 %1, i32 0) + ret i32 %2 +} + +declare i32 @llvm.smax.i32(i32, i32) + +declare i32 @getint() + +; CHECK: smax0 - {[-1]:Integer} |{[-1]:Integer}:{} {[-1]:Integer}:{} +; CHECK-NEXT: i32 %a: {[-1]:Integer} +; CHECK-NEXT: i32 %b: {[-1]:Integer} +; CHECK-NEXT: entry +; CHECK-NEXT: %0 = call i32 @llvm.smax.i32(i32 %a, i32 0): {[-1]:Integer} +; CHECK-NEXT: %1 = call i32 @getint(): {[-1]:Integer} +; CHECK-NEXT: %2 = call i32 @llvm.smax.i32(i32 %1, i32 0): {[-1]:Integer} +; CHECK-NEXT: ret i32 %2: {} diff --git a/enzyme/test/lit.site.cfg.py.in b/enzyme/test/lit.site.cfg.py.in index 481fce924bcc..0cc5e6f28f38 100644 --- a/enzyme/test/lit.site.cfg.py.in +++ b/enzyme/test/lit.site.cfg.py.in @@ -16,6 +16,10 @@ config.llvm_shlib_ext = "@LLVM_SHLIBEXT@" config.targets_to_build = "@TARGETS_TO_BUILD@" +has_mpfr_h = "@HAS_MPFR_H@" +mpfr_lib_path = "@MPFR_LIB_PATH@" +has_mpfr = "yes" if mpfr_lib_path != "MPFR_LIB_PATH-NOTFOUND" and has_mpfr_h == "1" else "no" + ## Check the current platform with regex import re EAT_ERR_ON_X86 = ' ' @@ -82,7 +86,7 @@ if len("@ENZYME_BINARY_DIR@") == 0: oldPMOP = oldPM newPMOP = newPM -if int(config.llvm_ver) >= 16: +if int(config.llvm_ver) == 16: newPM += " -opaque-pointers=0" oldPM += " -opaque-pointers=0" @@ -112,6 +116,8 @@ if len("@ENZYME_BINARY_DIR@") == 0: config.substitutions.append(('%loadClangEnzyme', oldPM if int(config.llvm_ver) < 15 else newPM)) config.substitutions.append(('%newLoadClangEnzyme', newPM)) +config.substitutions.append(('%hasMPFR', has_mpfr)) + # Let the main config do the real work. cfgfile = "@ENZYME_SOURCE_DIR@/test/lit.cfg.py" if len("@ENZYME_SOURCE_DIR@") == 0: diff --git a/enzyme/tools/enzyme-tblgen/blas-tblgen.cpp b/enzyme/tools/enzyme-tblgen/blas-tblgen.cpp index c0e725eaf948..3689c77d4d74 100644 --- a/enzyme/tools/enzyme-tblgen/blas-tblgen.cpp +++ b/enzyme/tools/enzyme-tblgen/blas-tblgen.cpp @@ -146,14 +146,22 @@ void emit_handleBLAS(ArrayRef blasPatterns, raw_ostream &os) { << " llvm::errs() << \" fallback?\\n\"; \n" << " return false; \n" << " } \n" - << " } \n" - << " \n" - << " if (Mode == DerivativeMode::ReverseModeGradient) { \n" - << " eraseIfUnused(call, /*erase*/ true, /*check*/ false); \n" << " } else { \n" - << " eraseIfUnused(call); \n" - << " } \n" - << " \n" + << " if (Mode == DerivativeMode::ReverseModeGradient) { \n" + << " eraseIfUnused(call, /*erase*/ true, /*check*/ false); \n" + << " } else { \n" + << " eraseIfUnused(call); \n" + << " } \n" + << " if (gutils->knownRecomputeHeuristic.find(&call) !=\n" + << " gutils->knownRecomputeHeuristic.end()) {\n" + << " if (!gutils->knownRecomputeHeuristic[&call]) {\n" + << " auto newCall = gutils->getNewFromOriginal(&call);\n" + << " llvm::IRBuilder<> BuilderZ(newCall);\n" + << " gutils->cacheForReverse(BuilderZ, newCall,\n" + << " getIndex(&call, CacheType::Self, BuilderZ));\n" + << " }\n" + << " }\n" + << " }\n" << " return result; \n" << "} \n"; } @@ -229,10 +237,16 @@ void emit_free_and_ending(const TGPattern &pattern, raw_ostream &os) { os << " }\n" << " }\n" + << " \n" + << " if (Mode == DerivativeMode::ReverseModeGradient) { \n" + << " eraseIfUnused(call, /*erase*/ true, /*check*/ false); \n" + << " } else { \n" + << " eraseIfUnused(call); \n" + << " } \n" << " if (gutils->knownRecomputeHeuristic.find(&call) !=\n" << " gutils->knownRecomputeHeuristic.end()) {\n" << " if (!gutils->knownRecomputeHeuristic[&call]) {\n" - << " gutils->cacheForReverse(BuilderZ, newCall,\n" + << " auto cv = gutils->cacheForReverse(BuilderZ, newCall,\n" << " getIndex(&call, CacheType::Self, BuilderZ));\n" << " }\n" << " }\n" @@ -430,6 +444,7 @@ void emit_helper(const TGPattern &pattern, raw_ostream &os) { break; } } + (void)hasInt; assert(hasInt); os << " Type* cublas_retty = nullptr;\n" @@ -468,6 +483,7 @@ void emit_scalar_types(const TGPattern &pattern, raw_ostream &os) { break; } } + (void)foundInt; assert(foundInt && "no int type found in blas call"); os << " // fpType already given by blas type (s, d, c, z) \n" @@ -855,7 +871,7 @@ void emit_fwd_rewrite_rules(const TGPattern &pattern, raw_ostream &os) { if (ty == ArgType::fp) { const auto name = nameVec[inputType.first]; os << " Value *d_" << name - << " = llvm::ConstantFP::get(fpType, 0.0);\n"; + << " = Constant::getNullValue(gutils->getShadowType(fpType));\n"; } } diff --git a/enzyme/tools/enzyme-tblgen/blasDeclUpdater.h b/enzyme/tools/enzyme-tblgen/blasDeclUpdater.h index aef73e479f2f..aebf6d7aec0d 100644 --- a/enzyme/tools/enzyme-tblgen/blasDeclUpdater.h +++ b/enzyme/tools/enzyme-tblgen/blasDeclUpdater.h @@ -195,17 +195,10 @@ void emitBlasDeclUpdater(const RecordKeeper &RK, raw_ostream &os) { os << " auto name = getFuncName(&F);\n"; os << " auto changed = false;\n"; os << " auto blasMetaData = extractBLAS(name);\n"; - os << " #if LLVM_VERSION_MAJOR >= 16\n"; - os << " if (F.empty() && blasMetaData.has_value()) {\n"; - os << " attributeBLAS(blasMetaData.value(), &F);\n"; - os << " changed = true;\n"; - os << " }\n"; - os << " #else\n"; - os << " if (F.empty() && blasMetaData.hasValue()) {\n"; - os << " attributeBLAS(blasMetaData.getValue(), &F);\n"; - os << " changed = true;\n"; - os << " }\n"; - os << " #endif\n"; + os << " if (F.empty() && blasMetaData) {\n"; + os << " attributeBLAS(*blasMetaData, &F);\n"; + os << " changed = true;\n"; + os << " }\n"; { const auto &patterns = RK.getAllDerivedDefinitions("CallPattern"); for (Record *pattern : patterns) { diff --git a/enzyme/tools/enzyme-tblgen/blasDiffUseUpdater.h b/enzyme/tools/enzyme-tblgen/blasDiffUseUpdater.h index b3d64f2c7ab9..51e883e3b804 100644 --- a/enzyme/tools/enzyme-tblgen/blasDiffUseUpdater.h +++ b/enzyme/tools/enzyme-tblgen/blasDiffUseUpdater.h @@ -182,11 +182,7 @@ void emitBlasDiffUse(const RecordKeeper &RK, llvm::raw_ostream &os) { emitDiffUse(RK, os, CallDerivatives); os << " auto blasMetaData = extractBLAS(funcName);\n"; - os << " #if LLVM_VERSION_MAJOR >= 16\n"; - os << " if (blasMetaData.has_value())\n"; - os << " #else\n"; - os << " if (blasMetaData.hasValue())\n"; - os << " #endif\n"; + os << " if (blasMetaData)\n"; os << " {\n"; os << " auto Mode = gutils->mode;\n"; os << " const bool cacheMode = (Mode != DerivativeMode::ForwardMode);\n"; @@ -198,11 +194,7 @@ void emitBlasDiffUse(const RecordKeeper &RK, llvm::raw_ostream &os) { os << " assert(found != gutils->overwritten_args_map_ptr->end());\n"; os << " overwritten_args_ptr = &found->second;\n"; os << " }\n"; - os << " #if LLVM_VERSION_MAJOR >= 16\n"; - os << " BlasInfo blas = blasMetaData.value();\n"; - os << " #else\n"; - os << " BlasInfo blas = blasMetaData.getValue();\n"; - os << " #endif\n"; + os << " BlasInfo blas = *blasMetaData;\n"; for (auto &&newPattern : newBlasPatterns) { emit_BLASDiffUse(newPattern, os); } diff --git a/enzyme/tools/enzyme-tblgen/caching.cpp b/enzyme/tools/enzyme-tblgen/caching.cpp index cdb7b3a92910..7a98d944b8ab 100644 --- a/enzyme/tools/enzyme-tblgen/caching.cpp +++ b/enzyme/tools/enzyme-tblgen/caching.cpp @@ -293,7 +293,7 @@ os << " if (byRef) valueTypes[" << len_pos << "] = ValueType::Primal;\n"; os << " if (EnzymeLapackCopy) {\n" << " Value *uplo = llvm::ConstantInt::get(charTy, 0);\n" // garbage data, just should not match U or L << " uplo = to_blas_callconv(BuilderZ, uplo, byRef, cublas, nullptr, allocationBuilder, \"copy.garbage\");\n" -<< " SmallVector args = {uplo, M, N, arg_" << matName << ", arg_" << ldName << ", malins, N};\n" +<< " SmallVector args = {uplo, M, N, arg_" << matName << ", arg_" << ldName << ", malins, M};\n" << " if (!byRef) {\n" << " args.insert(args.begin(), arg_layout); valueTypes.insert(valueTypes.begin(), ValueType::Primal); }\n" << " callMemcpyStridedLapack(BuilderZ, *gutils->oldFunc->getParent(), blas, args, gutils->getInvertedBundles(&call, valueTypes, BuilderZ, /*lookup*/false));\n" diff --git a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp index 11384a2284e8..6eff540b6622 100644 --- a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp +++ b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp @@ -65,7 +65,9 @@ static cl::opt cl::values(clEnumValN(MLIRDerivatives, "gen-mlir-derivatives", "Generate MLIR derivative")), cl::values(clEnumValN(CallDerivatives, "gen-call-derivatives", - "Generate call derivative"))); + "Generate call derivative")), + cl::values(clEnumValN(GenHeaderVariables, "gen-header-strings", + "Generate header strings"))); void getFunction(const Twine &curIndent, raw_ostream &os, StringRef callval, StringRef FT, StringRef cconv, Init *func, @@ -119,6 +121,21 @@ void getFunction(const Twine &curIndent, raw_ostream &os, StringRef callval, << ")->getCallingConv();\n"; return; } + if (opName == "ArgAsRetTypesFunc" || + Def->isSubClassOf("ArgAsRetTypesFunc")) { + os << curIndent << "auto " << FT << "_old = cast(&" << origName + << ")->getFunctionType();\n"; + os << curIndent << "auto " << FT << " = FunctionType::get(" << FT + << "_old->params()[0], " << FT << "_old->params(), " << FT + << "_old->isVarArg());\n"; + os << curIndent << "auto " << callval + << " = gutils->oldFunc->getParent()->getOrInsertFunction("; + os << Def->getValueInit("name")->getAsString(); + os << ", " << FT << ", called->getAttributes()).getCallee();\n"; + os << curIndent << "auto " << cconv << " = cast(&" << origName + << ")->getCallingConv();\n"; + return; + } } assert(0 && "Unhandled function"); } @@ -337,7 +354,10 @@ bool handle(const Twine &curIndent, const Twine &argPattern, raw_ostream &os, PrintFatalError(pattern->getLoc(), Twine("unknown named operand in typeof") + resultTree->getAsString()); - os << "->getType()"; + if (intrinsic == MLIRDerivatives) + os << ".getType()"; + else + os << "->getType()"; return false; } else if (opName == "VectorSize" || Def->isSubClassOf("VectorSize")) { if (resultRoot->getNumArgs() != 1) @@ -370,7 +390,10 @@ bool handle(const Twine &curIndent, const Twine &argPattern, raw_ostream &os, os << "({\n"; os << curIndent << INDENT << "// Computing SelectIfActive\n"; - os << curIndent << INDENT << "Value *imVal = nullptr;\n"; + if (intrinsic == MLIRDerivatives) + os << curIndent << INDENT << "mlir::Value imVal = nullptr;\n"; + else + os << curIndent << INDENT << "llvm::Value *imVal = nullptr;\n"; os << curIndent << INDENT << "if (!gutils->isConstantValue("; @@ -400,7 +423,7 @@ bool handle(const Twine &curIndent, const Twine &argPattern, raw_ostream &os, retidx, origName, newFromOriginal, intrinsic); os << ";\n"; - if (!vector) { + if (!vector && intrinsic != MLIRDerivatives) { os << curIndent << INDENT << INDENT << "llvm::Value* vec_imVal = gutils->getWidth() == 1 ? imVal : " "UndefValue::get(gutils->getShadowType(imVal" @@ -425,26 +448,52 @@ bool handle(const Twine &curIndent, const Twine &argPattern, raw_ostream &os, os << curIndent << "})"; return true; } else if (opName == "ConstantFP" || Def->isSubClassOf("ConstantFP")) { - if (resultRoot->getNumArgs() != 1) - PrintFatalError(pattern->getLoc(), - "only single op constantfp supported"); - auto value = dyn_cast(Def->getValueInit("value")); if (!value) PrintFatalError(pattern->getLoc(), Twine("'value' not defined in ") + resultTree->getAsString()); - os << "ConstantFP::get("; - if (resultRoot->getArgName(0)) { - auto name = resultRoot->getArgName(0)->getAsUnquotedString(); - auto [ord, isVec] = nameToOrdinal.lookup(name, pattern, resultTree); - assert(!isVec); - os << ord; - } else - PrintFatalError(pattern->getLoc(), - Twine("unknown named operand in constantfp") + - resultTree->getAsString()); - os << "->getType(), \"" << value->getValue() << "\")"; + if (intrinsic == MLIRDerivatives) { + if (resultRoot->getNumArgs() > 1) + PrintFatalError(pattern->getLoc(), + "only zero or single op constantfp supported"); + os << builder << ".create<" + << cast(Def->getValueInit("dialect"))->getValue() + << "::" << cast(Def->getValueInit("opName"))->getValue() + << ">(op.getLoc(), "; + std::string ord; + if (resultRoot->getNumArgs() == 0) { + ord = "op->getResult(0)"; + } else { + auto name = resultRoot->getArgName(0)->getAsUnquotedString(); + auto [ord1, isVec] = nameToOrdinal.lookup(name, pattern, resultTree); + assert(!isVec); + ord = ord1; + } + os << ord << ".getType(), "; + auto typeCast = + dyn_cast(Def->getValueInit("type"))->getValue(); + if (typeCast != "") + os << "(" << typeCast << ")"; + os << "mlir::enzyme::getConstantAttr(" << ord << ".getType(), "; + os << "\"" << value->getValue() << "\"))"; + } else { + if (resultRoot->getNumArgs() != 1) + PrintFatalError(pattern->getLoc(), + "only single op constantfp supported"); + + os << "ConstantFP::get("; + if (resultRoot->getArgName(0)) { + auto name = resultRoot->getArgName(0)->getAsUnquotedString(); + auto [ord, isVec] = nameToOrdinal.lookup(name, pattern, resultTree); + assert(!isVec); + os << ord; + } else + PrintFatalError(pattern->getLoc(), + Twine("unknown named operand in constantfp") + + resultTree->getAsString()); + os << "->getType(), \"" << value->getValue() << "\")"; + } return false; } else if (opName == "Zero" || Def->isSubClassOf("Zero")) { if (resultRoot->getNumArgs() != 1) @@ -495,6 +544,13 @@ bool handle(const Twine &curIndent, const Twine &argPattern, raw_ostream &os, "{(llvm::Constant*)ConstantFP::get(ST->getElementType(0), \"" << rvalue->getValue() << "\"), (llvm::Constant*)ConstantFP::get(ST->getElementType(1), \"" + << ivalue->getValue() << "\")});\n" + << "} else if (auto AT = dyn_cast(ty)) {\n" + << curIndent << INDENT << INDENT + << "ret = ConstantArray::get(AT, " + "{(llvm::Constant*)ConstantFP::get(AT->getElementType(), \"" + << rvalue->getValue() + << "\"), (llvm::Constant*)ConstantFP::get(AT->getElementType(), \"" << ivalue->getValue() << "\")});\n"; os << curIndent << INDENT << "} else assert(0 && \"unhandled cfp\");\n"; os << curIndent << INDENT << "ret;\n"; @@ -833,7 +889,9 @@ bool handle(const Twine &curIndent, const Twine &argPattern, raw_ostream &os, } else if (opName == "CheckedDiv") { os << "checkedDiv(" << builder << ", "; } else if (intrinsic == MLIRDerivatives) { - os << builder << ".create<" << opName << ">(op.getLoc(), "; + auto dialect = Def->getValueAsString("dialect"); + os << builder << ".create<" << dialect << "::" << opName + << ">(op.getLoc(), "; } else { os << builder << ".Create" << opName << "("; } @@ -954,6 +1012,13 @@ void handleUse( foundDiffRet = true; return; } + if (opName == "InactiveArgSpec" || Def->isSubClassOf("InactiveArgSpec")) { + return; + } + if (!Def->isSubClassOf("Operation")) { + errs() << *resultTree << "\n"; + errs() << opName << " " << *Def << "\n"; + } assert(Def->isSubClassOf("Operation")); bool usesPrimal = Def->getValueAsBit("usesPrimal"); bool usesShadow = Def->getValueAsBit("usesShadow"); @@ -1188,6 +1253,316 @@ void printDiffUse( } } +static void emitHeaderIncludes(const RecordKeeper &recordKeeper, + raw_ostream &os) { + const auto &patterns = recordKeeper.getAllDerivedDefinitions("Headers"); + os << "const char* include_headers[][2] = {\n"; + bool seen = false; + for (Record *pattern : patterns) { + if (seen) + os << ",\n"; + auto filename = pattern->getValueAsString("filename"); + auto contents = pattern->getValueAsString("contents"); + os << "{\"" << filename << "\"\n,"; + os << "R\"(" << contents << ")\"\n"; + os << "}"; + seen = true; + } + os << "};\n"; +} + +static void emitMLIRReverse(raw_ostream &os, Record *pattern, DagInit *tree, + ActionType intrinsic, StringRef origName, + ListInit *argOps) { + + auto opName = pattern->getValueAsString("opName"); + auto dialect = pattern->getValueAsString("dialect"); + os << "struct " << opName << "RevDerivative : \n"; + os << " public " + "ReverseAutoDiffOpInterface::ExternalModel<" + << opName << "RevDerivative, " << dialect << "::" << opName << "> {\n"; + os << " SmallVector cachedArguments(Operation *op,\n"; + os << " MGradientUtilsReverse *gutils) " + "const {\n"; + os << " SmallVector toret(op->getNumOperands(), false);\n"; + StringMap> varNameToCondition; + + std::function)> insert = + [&](DagInit *ptree, ArrayRef prev) { + for (auto treeEn : llvm::enumerate(ptree->getArgs())) { + auto tree = treeEn.value(); + auto name = ptree->getArgNameStr(treeEn.index()); + SmallVector next(prev.begin(), prev.end()); + next.push_back(treeEn.index()); + if (auto dg = dyn_cast(tree)) + insert(dg, next); + + if (name.size()) { + varNameToCondition[name] = std::make_tuple( + "idx == " + std::to_string(treeEn.index()), "", false); + } + } + }; + + insert(tree, {}); + + if (tree->getNameStr().size()) + varNameToCondition[tree->getNameStr()] = + std::make_tuple("ILLEGAL", "ILLEGAL", false); + + os << " for (size_t idx=0; idxgetNumOperands(); idx++) {\n"; + os << " bool used = false;\n"; + printDiffUse(os, " ", argOps, origName, intrinsic, tree, + varNameToCondition); + os << " toret[idx] = used;\n"; + os << " }\n"; + os << " return toret;\n"; + os << " }\n"; + + os << " SmallVector cacheValues(Operation *op,\n"; + os << " MGradientUtilsReverse *gutils) " + "const {\n"; + os << " if (gutils->isConstantInstruction(op) || " + "gutils->isConstantValue(op->getResult(0))) return {};\n"; + os << " auto neededArgs = cachedArguments(op, gutils);\n"; + os << " SmallVector toret;\n"; + os << " OpBuilder builder(gutils->getNewFromOriginal(op));\n"; + os << " for (auto en : llvm::enumerate(neededArgs))\n"; + os << " if (en.value()) {\n"; + os << " Value cache = " + "gutils->initAndPushCache(gutils->getNewFromOriginal(op->" + "getOperand(en.index())), builder);\n"; + os << " toret.push_back(cache);\n"; + os << " }\n"; + os << " return toret;\n"; + os << " }\n"; + os << "\n"; + os << " void createShadowValues(Operation *op, OpBuilder &builder,\n"; + os << " MGradientUtilsReverse *gutils) const " + "{}\n"; + + os << " void createReverseModeAdjoint(Operation *op0, OpBuilder " + "&builder,\n"; + os << " MGradientUtilsReverse *gutils,\n"; + os << " SmallVector caches) const {\n"; + os << " auto op = cast<" << dialect << "::" << opName << ">(op0);\n"; + os << " mlir::Value dif = nullptr;\n"; +} + +static VariableSetting parseVariables(DagInit *tree, ActionType intrinsic, + StringRef origName) { + VariableSetting nameToOrdinal; + std::function)> insert = + [&](DagInit *ptree, ArrayRef prev) { + unsigned i = 0; + for (auto tree : ptree->getArgs()) { + SmallVector next(prev.begin(), prev.end()); + next.push_back(i); + if (auto dg = dyn_cast(tree)) + insert(dg, next); + + if (ptree->getArgNameStr(i).size()) { + std::string op; + if (intrinsic != MLIRDerivatives) + op = (origName + ".getOperand(" + Twine(next[0]) + ")").str(); + else + op = (origName + "->getOperand(" + Twine(next[0]) + ")").str(); + if (prev.size() > 0) { + op = "gutils->extractMeta(Builder2, " + op + + ", ArrayRef({"; + bool first = true; + for (unsigned i = 1; i < next.size(); i++) { + if (!first) + op += ", "; + op += std::to_string(next[i]); + } + op += "}))"; + } + nameToOrdinal.insert(ptree->getArgNameStr(i), op, false); + } + i++; + } + }; + + insert(tree, {}); + + if (tree->getNameStr().size()) + nameToOrdinal.insert(tree->getNameStr(), + (Twine("(&") + origName + ")").str(), false); + return nameToOrdinal; +} + +static void emitReverseCommon(raw_ostream &os, Record *pattern, DagInit *tree, + ActionType intrinsic, StringRef origName, + ListInit *argOps) { + auto nameToOrdinal = parseVariables(tree, intrinsic, origName); + + bool seen = false; + for (auto argOpEn : enumerate(*argOps)) { + size_t argIdx = argOpEn.index(); + if (DagInit *resultRoot = dyn_cast(argOpEn.value())) { + auto opName = resultRoot->getOperator()->getAsString(); + auto Def = cast(resultRoot->getOperator())->getDef(); + if (opName == "InactiveArgSpec" || Def->isSubClassOf("InactiveArgSpec")) { + if (Def->getValueAsBit("asserting")) + os << " assert(gutils->isConstantValue(" << origName << ".getOperand(" + << argIdx << ")));\n"; + continue; + } + } + + os << " "; + if (seen) + os << "} else "; + seen = true; + if (intrinsic == MLIRDerivatives) { + os << "if (!dif && !gutils->isConstantValue(" << origName + << "->getOperand(" << argIdx << "))) {\n"; + } else { + os << "if (!dif && !gutils->isConstantValue(" << origName + << ".getOperand(" << argIdx << "))) {\n"; + } + DagInit *resultTree = cast(argOpEn.value()); + if (hasDiffeRet(resultTree)) { + if (intrinsic == MLIRDerivatives) { + os << " dif = gutils->diffe(" << origName << ", builder);\n"; + os << " gutils->zeroDiffe(" << origName << ", builder);\n"; + } else { + os << " dif = diffe(&" << origName << ", Builder2);\n"; + os << " setDiffe(&" << origName + << ", " + "Constant::getNullValue(gutils->getShadowType(" + << origName + << ".getType())), " + "Builder2);\n"; + } + } + } + if (seen) + os << " }\n"; + + if (intrinsic == MLIRDerivatives) { + os << " SmallVector operands(op->getNumOperands(), nullptr);\n"; + os << " auto neededArgs = cachedArguments(op, gutils);\n"; + os << " size_t count = 0;\n"; + os << " for (auto en : llvm::enumerate(neededArgs))\n"; + os << " if (en.value()) {\n"; + os << " operands[en.index()] = " + "gutils->popCache(caches[count], builder);\n"; + os << " count++;\n"; + os << " }\n"; + } + + std::function, Init *)> revres = + [&](size_t argIdx, ArrayRef idx, Init *ival) { + if (DagInit *resultTree = dyn_cast(ival)) { + auto Def = cast(resultTree->getOperator())->getDef(); + if (Def->isSubClassOf("MultiReturn")) { + unsigned i = 0; + for (auto r : resultTree->getArgs()) { + SmallVector next(idx.begin(), idx.end()); + next.push_back(i); + revres(argIdx, next, r); + i++; + } + return; + } + if (Def->isSubClassOf("InactiveArgSpec")) { + return; + } + const char *curIndent = " "; + os << curIndent << "{\n"; + if (intrinsic == MLIRDerivatives) + os << curIndent << INDENT << "mlir::Value tmp = "; + else + os << curIndent << INDENT << "Value *tmp = "; + bool vectorValued = handle( + Twine(curIndent) + INDENT, "revarg", os, pattern, resultTree, + (intrinsic == MLIRDerivatives) ? "builder" : "Builder2", + nameToOrdinal, /*lookup*/ true, idx, origName, + /*newFromOriginal*/ true, intrinsic); + os << ";\n"; + + if (intrinsic == MLIRDerivatives) { + os << "assert(toadd == nullptr); toadd = tmp;\n"; + } else { + os << curIndent << INDENT + << "Value *out = " + "UndefValue::get(gutils->getShadowType(" + << origName << ".getOperand(" << argIdx << ")->getType()));\n"; + + os << curIndent << INDENT + << "for(unsigned int idx=0, W=gutils->getWidth(); " + "idxgetWidth() == " + "1 ? toadd : gutils->extractMeta(Builder2, toadd, idx)) : " + "nullptr;\n"; + os << curIndent << INDENT << INDENT << "Value *next = tmp;\n"; + if (vectorValued) + os << curIndent << INDENT << INDENT + << "if (gutils->getWidth() > 1) next = " + "gutils->extractMeta(Builder2, next, idx);\n"; + os << curIndent << INDENT << INDENT + << "if (prev) next = Builder2.CreateFAdd(prev, " + "next);\n"; + os << curIndent << INDENT << INDENT + << "out = (gutils->getWidth() > 1) ? " + "Builder2.CreateInsertValue(out, next, idx) : next;\n"; + os << curIndent << INDENT << "}\n"; + os << curIndent << INDENT << "toadd = out;\n"; + } + os << curIndent << "}\n"; + + } else if (ListInit *lst = dyn_cast(ival)) { + unsigned i = 0; + for (auto elem : *lst) { + SmallVector next(idx.begin(), idx.end()); + next.push_back(i); + revres(argIdx, next, elem); + i++; + } + } else + assert(0); + }; + + for (auto argOpEn : enumerate(*argOps)) { + size_t argIdx = argOpEn.index(); + if (DagInit *resultRoot = dyn_cast(argOpEn.value())) { + auto opName = resultRoot->getOperator()->getAsString(); + auto Def = cast(resultRoot->getOperator())->getDef(); + if (opName == "InactiveArgSpec" || Def->isSubClassOf("InactiveArgSpec")) { + continue; + } + } + + const char *curIndent = " "; + if (intrinsic == MLIRDerivatives) + os << curIndent << "if (!gutils->isConstantValue(" << origName + << "->getOperand(" << argIdx << "))) {\n"; + else + os << curIndent << "if (!gutils->isConstantValue(" << origName + << ".getOperand(" << argIdx << "))) {\n"; + initializeNames(Twine(curIndent) + INDENT, os, argOpEn.value(), "local"); + if (intrinsic == MLIRDerivatives) + os << curIndent << INDENT << "mlir::Value toadd = nullptr;\n"; + else + os << curIndent << INDENT << "Value *toadd = nullptr;\n"; + revres(argIdx, {}, argOpEn.value()); + + if (intrinsic == MLIRDerivatives) { + os << curIndent << INDENT << "if (toadd) gutils->addToDiffe(" << origName + << "->getOperand(" << argIdx << "), toadd, builder);\n"; + } else { + os << curIndent << INDENT << "if (toadd) addToDiffe(" << origName + << ".getOperand(" << argIdx << "), toadd"; + os << ", Builder2, " << origName << ".getOperand(" << argIdx + << ")->getType());\n"; + } + os << curIndent << "}\n"; + } +} static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os, ActionType intrinsic) { emitSourceFileHeader("Rewriters", os); @@ -1208,6 +1583,7 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os, case BinopDerivatives: patternNames = "BinopPattern"; break; + case GenHeaderVariables: case GenBlasDerivatives: case UpdateBlasDecl: case UpdateBlasTA: @@ -1219,9 +1595,8 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os, for (Record *pattern : patterns) { DagInit *tree = pattern->getValueAsDag("PatternToMatch"); - DagInit *duals = nullptr; - if (intrinsic != MLIRDerivatives) - duals = pattern->getValueAsDag("ArgDuals"); + DagInit *duals = pattern->getValueAsDag("ArgDuals"); + assert(duals); // Emit RewritePattern for Pattern. ListInit *argOps = pattern->getValueAsListInit("ArgDerivatives"); @@ -1240,6 +1615,7 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os, case UpdateBlasDecl: case UpdateBlasTA: case GenBlasDiffUse: + case GenHeaderVariables: llvm_unreachable("Cannot use blas updaters inside emitDerivatives"); case MLIRDerivatives: { auto opName = pattern->getValueAsString("opName"); @@ -1281,13 +1657,13 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os, StringRef name = cast(lst->getValues()[0])->getValue(); if (lst->size() >= 2) { auto min = cast(lst->getValues()[1])->getValue(); - int min_int; + int min_int = 100000; min.getAsInteger(10, min_int); if (min.size() != 0 && LLVM_VERSION_MAJOR < min_int) continue; if (lst->size() >= 3) { auto max = cast(lst->getValues()[2])->getValue(); - int max_int; + int max_int = 0; max.getAsInteger(10, max_int); if (max.size() != 0 && LLVM_VERSION_MAJOR > max_int) continue; @@ -1386,42 +1762,7 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os, } } - VariableSetting nameToOrdinal; - - std::function)> insert = - [&](DagInit *ptree, ArrayRef prev) { - unsigned i = 0; - for (auto tree : ptree->getArgs()) { - SmallVector next(prev.begin(), prev.end()); - next.push_back(i); - if (auto dg = dyn_cast(tree)) - insert(dg, next); - - if (ptree->getArgNameStr(i).size()) { - auto op = - (origName + ".getOperand(" + Twine(next[0]) + ")").str(); - if (prev.size() > 0) { - op = "gutils->extractMeta(Builder2, " + op + - ", ArrayRef({"; - bool first = true; - for (unsigned i = 1; i < next.size(); i++) { - if (!first) - op += ", "; - op += std::to_string(next[i]); - } - op += "}))"; - } - nameToOrdinal.insert(ptree->getArgNameStr(i), op, false); - } - i++; - } - }; - - insert(tree, {}); - - if (tree->getNameStr().size()) - nameToOrdinal.insert(tree->getNameStr(), - (Twine("(&") + origName + ")").str(), false); + VariableSetting nameToOrdinal = parseVariables(tree, intrinsic, origName); if (intrinsic != BinopDerivatives && intrinsic != InstDerivatives && intrinsic != MLIRDerivatives) { @@ -1461,8 +1802,7 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os, } // TODO - if (!duals || - duals->getOperator()->getAsString() == + if (duals->getOperator()->getAsString() == "ForwardFromSummedReverseInternal" || cast(duals->getOperator()) ->getDef() @@ -1524,6 +1864,9 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os, } return; } + if (Def->isSubClassOf("InactiveArgSpec")) { + return; + } os << curIndent << INDENT << "{\n"; if (intrinsic == MLIRDerivatives) os << curIndent << INDENT << INDENT << "mlir::Value itmp = "; @@ -1586,10 +1929,15 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os, } } else { - os << " Value *res = "; + if (intrinsic == MLIRDerivatives) { + os << " mlir::Value res = "; + } else { + os << " Value *res = "; + } ArrayRef retidx{}; bool vectorValued = - handle(" ", "fwdnsrarg", os, pattern, duals, "Builder2", + handle(" ", "fwdnsrarg", os, pattern, duals, + (intrinsic == MLIRDerivatives) ? "builder" : "Builder2", nameToOrdinal, /*lookup*/ false, retidx, origName, /*newFromOriginal*/ true, intrinsic); (void)vectorValued; @@ -1615,251 +1963,30 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os, os << " Value *dif = nullptr;\n"; } else { os << "};\n"; - auto opName = pattern->getValueAsString("opName"); - auto dialect = pattern->getValueAsString("dialect"); - os << "struct " << opName << "RevDerivative : \n"; - os << " public " - "ReverseAutoDiffOpInterface::ExternalModel<" - << opName << "RevDerivative, " << dialect << "::" << opName << "> {\n"; - os << " SmallVector cachedArguments(Operation *op,\n"; - os << " MGradientUtilsReverse *gutils) " - "const {\n"; - os << " SmallVector toret(op->getNumOperands(), false);\n"; - StringMap> varNameToCondition; - - std::function)> insert = - [&](DagInit *ptree, ArrayRef prev) { - for (auto treeEn : llvm::enumerate(ptree->getArgs())) { - auto tree = treeEn.value(); - auto name = ptree->getArgNameStr(treeEn.index()); - SmallVector next(prev.begin(), prev.end()); - next.push_back(treeEn.index()); - if (auto dg = dyn_cast(tree)) - insert(dg, next); - - if (name.size()) { - varNameToCondition[name] = std::make_tuple( - "idx == " + std::to_string(treeEn.index()), "", false); - } - } - }; - - insert(tree, {}); - - if (tree->getNameStr().size()) - varNameToCondition[tree->getNameStr()] = - std::make_tuple("ILLEGAL", "ILLEGAL", false); - - os << " for (size_t idx=0; idxgetNumOperands(); idx++) {\n"; - os << " bool used = false;\n"; - printDiffUse(os, " ", argOps, origName, intrinsic, tree, - varNameToCondition); - os << " toret[idx] = used;\n"; - os << " }\n"; - os << " return toret;\n"; - os << " }\n"; - - os << " SmallVector cacheValues(Operation *op,\n"; - os << " MGradientUtilsReverse *gutils) " - "const {\n"; - os << " if (gutils->isConstantInstruction(op) || " - "gutils->isConstantValue(op->getResult(0))) return {};\n"; - os << " auto neededArgs = cachedArguments(op, gutils);\n"; - os << " SmallVector toret;\n"; - os << " OpBuilder builder(gutils->getNewFromOriginal(op));\n"; - os << " for (auto en : llvm::enumerate(neededArgs))\n"; - os << " if (en.value()) {\n"; - os << " Value cache = " - "gutils->initAndPushCache(gutils->getNewFromOriginal(op->" - "getOperand(en.index())), builder);\n"; - os << " toret.push_back(cache);\n"; - os << " }\n"; - os << " return toret;\n"; - os << " }\n"; - os << "\n"; - os << " void createShadowValues(Operation *op, OpBuilder &builder,\n"; - os << " MGradientUtilsReverse *gutils) const " - "{}\n"; - - os << " void createReverseModeAdjoint(Operation *op0, OpBuilder " - "&builder,\n"; - os << " MGradientUtilsReverse *gutils,\n"; - os << " SmallVector caches) const {\n"; - os << " auto op = cast<" << dialect << "::" << opName << ">(op0);\n"; - os << " mlir::Value dif = nullptr;\n"; + emitMLIRReverse(os, pattern, tree, intrinsic, origName, argOps); } - // TODO vector - - bool seen = false; - for (auto argOpEn : enumerate(*argOps)) { - size_t argIdx = argOpEn.index(); - if (DagInit *resultRoot = dyn_cast(argOpEn.value())) { - auto opName = resultRoot->getOperator()->getAsString(); - auto Def = cast(resultRoot->getOperator())->getDef(); - if (opName == "InactiveArgSpec" || - Def->isSubClassOf("InactiveArgSpec")) { - if (Def->getValueAsBit("asserting")) - os << " assert(gutils->isConstantValue(" << origName - << ".getOperand(" << argIdx << ")));\n"; - continue; - } - } - - os << " "; - if (seen) - os << "} else "; - seen = true; - if (intrinsic == MLIRDerivatives) { - os << "if (!dif && !gutils->isConstantValue(" << origName - << "->getOperand(" << argIdx << "))) {\n"; - } else { - os << "if (!dif && !gutils->isConstantValue(" << origName - << ".getOperand(" << argIdx << "))) {\n"; - } - DagInit *resultTree = cast(argOpEn.value()); - if (hasDiffeRet(resultTree)) { - if (intrinsic == MLIRDerivatives) { - os << " dif = gutils->diffe(" << origName << ", builder);\n"; - os << " gutils->clearValue(" << origName << ", builder);\n"; - } else { - os << " dif = diffe(&" << origName << ", Builder2);\n"; - os << " setDiffe(&" << origName - << ", " - "Constant::getNullValue(gutils->getShadowType(" - << origName - << ".getType())), " - "Builder2);\n"; - } - } - } - if (seen) - os << " }\n"; - - if (intrinsic == MLIRDerivatives) { - os << " SmallVector operands(op->getNumOperands(), nullptr);\n"; - os << " auto neededArgs = cachedArguments(op, gutils);\n"; - os << " size_t count = 0;\n"; - os << " for (auto en : llvm::enumerate(neededArgs))\n"; - os << " if (en.value()) {\n"; - os << " operands[en.index()] = " - "gutils->popCache(caches[count], builder);\n"; - os << " count++;\n"; - os << " }\n"; - } - - std::function, Init *)> revres = - [&](size_t argIdx, ArrayRef idx, Init *ival) { - if (DagInit *resultTree = dyn_cast(ival)) { - auto Def = cast(resultTree->getOperator())->getDef(); - if (Def->isSubClassOf("MultiReturn")) { - unsigned i = 0; - for (auto r : resultTree->getArgs()) { - SmallVector next(idx.begin(), idx.end()); - next.push_back(i); - revres(argIdx, next, r); - i++; - } - return; - } - const char *curIndent = " "; - os << curIndent << "{\n"; - if (intrinsic == MLIRDerivatives) - os << curIndent << INDENT << "mlir::Value tmp = "; - else - os << curIndent << INDENT << "Value *tmp = "; - bool vectorValued = handle( - Twine(curIndent) + INDENT, "revarg", os, pattern, resultTree, - (intrinsic == MLIRDerivatives) ? "builder" : "Builder2", - nameToOrdinal, /*lookup*/ true, idx, origName, - /*newFromOriginal*/ true, intrinsic); - os << ";\n"; - - if (intrinsic == MLIRDerivatives) { - os << "assert(toadd == nullptr); toadd = tmp;\n"; - } else { - os << curIndent << INDENT - << "Value *out = " - "UndefValue::get(gutils->getShadowType(" - << origName << ".getOperand(" << argIdx << ")->getType()));\n"; - - os << curIndent << INDENT - << "for(unsigned int idx=0, W=gutils->getWidth(); " - "idxgetWidth() == " - "1 ? toadd : gutils->extractMeta(Builder2, toadd, idx)) : " - "nullptr;\n"; - os << curIndent << INDENT << INDENT << "Value *next = tmp;\n"; - if (vectorValued) - os << curIndent << INDENT << INDENT - << "if (gutils->getWidth() > 1) next = " - "gutils->extractMeta(Builder2, next, idx);\n"; - os << curIndent << INDENT << INDENT - << "if (prev) next = Builder2.CreateFAdd(prev, " - "next);\n"; - os << curIndent << INDENT << INDENT - << "out = (gutils->getWidth() > 1) ? " - "Builder2.CreateInsertValue(out, next, idx) : next;\n"; - os << curIndent << INDENT << "}\n"; - os << curIndent << INDENT << "toadd = out;\n"; - } - os << curIndent << "}\n"; - - } else if (ListInit *lst = dyn_cast(ival)) { - unsigned i = 0; - for (auto elem : *lst) { - SmallVector next(idx.begin(), idx.end()); - next.push_back(i); - revres(argIdx, next, elem); - i++; - } - } else - assert(0); - }; - - for (auto argOpEn : enumerate(*argOps)) { - size_t argIdx = argOpEn.index(); - if (DagInit *resultRoot = dyn_cast(argOpEn.value())) { - auto opName = resultRoot->getOperator()->getAsString(); - auto Def = cast(resultRoot->getOperator())->getDef(); - if (opName == "InactiveArgSpec" || - Def->isSubClassOf("InactiveArgSpec")) { - continue; - } - } - const char *curIndent = " "; - if (intrinsic == MLIRDerivatives) - os << curIndent << "if (!gutils->isConstantValue(" << origName - << "->getOperand(" << argIdx << "))) {\n"; - else - os << curIndent << "if (!gutils->isConstantValue(" << origName - << ".getOperand(" << argIdx << "))) {\n"; - initializeNames(Twine(curIndent) + INDENT, os, argOpEn.value(), "local"); - if (intrinsic == MLIRDerivatives) - os << curIndent << INDENT << "mlir::Value toadd = nullptr;\n"; - else - os << curIndent << INDENT << "Value *toadd = nullptr;\n"; - revres(argIdx, {}, argOpEn.value()); - - if (intrinsic == MLIRDerivatives) { - os << curIndent << INDENT << "if (toadd) gutils->addToDiffe(" - << origName << "->getOperand(" << argIdx << "), toadd, builder);\n"; - } else { - os << curIndent << INDENT << "if (toadd) addToDiffe(" << origName - << ".getOperand(" << argIdx << "), toadd"; - os << ", Builder2, " << origName << ".getOperand(" << argIdx - << ")->getType());\n"; - } - os << curIndent << "}\n"; - } + emitReverseCommon(os, pattern, tree, intrinsic, origName, argOps); if (intrinsic != MLIRDerivatives) { + os << " auto found = gutils->invertedPointers.find(&(" << origName + << "));\n"; + os << " if (found != gutils->invertedPointers.end()) {\n"; + os << " PHINode* PN = cast(&*found->second);\n"; + os << " gutils->invertedPointers.erase(found);\n"; + os << " gutils->erase(PN);\n"; + os << " }\n"; os << " break;\n"; os << " }\n"; os << " case DerivativeMode::ReverseModePrimal:{\n"; + os << " auto found = gutils->invertedPointers.find(&(" << origName + << "));\n"; + os << " if (found != gutils->invertedPointers.end()) {\n"; + os << " PHINode* PN = cast(&*found->second);\n"; + os << " gutils->invertedPointers.erase(found);\n"; + os << " gutils->erase(PN);\n"; + os << " }\n"; // TODO os << " break;\n"; os << " }\n"; @@ -1875,6 +2002,83 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os, } if (intrinsic == MLIRDerivatives) { + const auto &actpatterns = + recordKeeper.getAllDerivedDefinitions("InactiveOp"); + for (auto &pattern : actpatterns) { + auto opName = pattern->getValueAsString("opName"); + auto dialect = pattern->getValueAsString("dialect"); + os << "struct " << opName << "Activity : \n"; + os << " public ActivityOpInterface::ExternalModel<" + << opName << "Activity, " << dialect << "::" << opName << "> {\n"; + os << " bool isInactive(mlir::Operation*) const { return true; }\n"; + os << " bool isArgInactive(mlir::Operation*, size_t) const { " + "return true; }\n"; + os << "};\n"; + } + const auto &cfpatterns = + recordKeeper.getAllDerivedDefinitions("ControlFlowOp"); + + const auto &mempatterns = + recordKeeper.getAllDerivedDefinitions("MemoryIdentityOp"); + + for (auto &pattern : cfpatterns) { + auto opName = pattern->getValueAsString("opName"); + auto dialect = pattern->getValueAsString("dialect"); + auto impl = pattern->getValueAsString("impl"); + os << "struct " << opName << "CF : \n"; + os << " public " + "ControlFlowAutoDiffOpInterface::ExternalModel<" + << opName << "CF, " << dialect << "::" << opName << "> {\n"; + os << impl << "\n"; + os << "};\n"; + } + + for (auto &pattern : mempatterns) { + auto opName = pattern->getValueAsString("opName"); + auto dialect = pattern->getValueAsString("dialect"); + auto diffargs = pattern->getValueAsListOfInts("ptrargs"); + auto storedargs = pattern->getValueAsListOfInts("storedargs"); + os << "struct " << opName << "MemActivity : \n"; + os << " public ActivityOpInterface::ExternalModel<" << opName + << "MemActivity, " << dialect << "::" << opName << "> {\n"; + os << " bool isInactive(mlir::Operation* op) const {\n"; + os << " for (size_t i=0, len=op->getNumOperands(); igetValueAsDag("PatternToMatch"); + + if (tree->getOperator()->getAsString() != "Unimplemented") { + ListInit *argOps = pattern->getValueAsListInit("reverse"); + auto origName = "op"; + emitMLIRReverse(os, pattern, tree, intrinsic, origName, argOps); + emitReverseCommon(os, pattern, tree, intrinsic, origName, argOps); + os << " return;\n"; + os << " }\n"; + os << " };\n"; + } + } + + const auto &brpatterns = recordKeeper.getAllDerivedDefinitions("BranchOp"); + + const auto &retpatterns = recordKeeper.getAllDerivedDefinitions("ReturnOp"); + + const auto ®tpatterns = + recordKeeper.getAllDerivedDefinitions("RegionTerminatorOp"); + + const auto &allocpatterns = + recordKeeper.getAllDerivedDefinitions("AllocationOp"); + os << "void registerInterfaces(MLIRContext* context) {\n"; for (Record *pattern : patterns) { auto opName = pattern->getValueAsString("opName"); @@ -1884,6 +2088,60 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os, os << " " << dialect << "::" << opName << "::attachInterface<" << opName << "RevDerivative>(*context);\n"; } + for (Record *pattern : actpatterns) { + auto opName = pattern->getValueAsString("opName"); + auto dialect = pattern->getValueAsString("dialect"); + os << " " << dialect << "::" << opName << "::attachInterface<" << opName + << "Activity>(*context);\n"; + } + for (Record *pattern : cfpatterns) { + auto opName = pattern->getValueAsString("opName"); + auto dialect = pattern->getValueAsString("dialect"); + os << " " << dialect << "::" << opName << "::attachInterface<" << opName + << "CF>(*context);\n"; + os << " registerAutoDiffUsingControlFlowInterface<" << dialect + << "::" << opName << ">(*context);\n"; + } + for (Record *pattern : mempatterns) { + auto opName = pattern->getValueAsString("opName"); + auto dialect = pattern->getValueAsString("dialect"); + os << " " << dialect << "::" << opName << "::attachInterface<" << opName + << "MemActivity>(*context);\n"; + os << " registerAutoDiffUsingMemoryIdentityInterface<" << dialect + << "::" << opName; + for (auto storedarg : pattern->getValueAsListOfInts("storedargs")) + os << ", " << storedarg; + os << ">(*context);\n"; + DagInit *tree = pattern->getValueAsDag("PatternToMatch"); + if (tree->getOperator()->getAsString() != "Unimplemented") { + os << " " << dialect << "::" << opName << "::attachInterface<" + << opName << "RevDerivative>(*context);\n"; + } + } + for (Record *pattern : brpatterns) { + auto opName = pattern->getValueAsString("opName"); + auto dialect = pattern->getValueAsString("dialect"); + os << " registerAutoDiffUsingBranchInterface<" << dialect + << "::" << opName << ">(*context);\n"; + } + for (Record *pattern : regtpatterns) { + auto opName = pattern->getValueAsString("opName"); + auto dialect = pattern->getValueAsString("dialect"); + os << " registerAutoDiffUsingRegionTerminatorInterface<" << dialect + << "::" << opName << ">(*context);\n"; + } + for (Record *pattern : retpatterns) { + auto opName = pattern->getValueAsString("opName"); + auto dialect = pattern->getValueAsString("dialect"); + os << " registerAutoDiffUsingReturnInterface<" << dialect + << "::" << opName << ">(*context);\n"; + } + for (Record *pattern : allocpatterns) { + auto opName = pattern->getValueAsString("opName"); + auto dialect = pattern->getValueAsString("dialect"); + os << " registerAutoDiffUsingAllocationInterface<" << dialect + << "::" << opName << ">(*context);\n"; + } os << "}\n"; } } @@ -1897,6 +2155,7 @@ void emitDiffUse(const RecordKeeper &recordKeeper, raw_ostream &os, case UpdateBlasDecl: case UpdateBlasTA: case GenBlasDiffUse: + case GenHeaderVariables: llvm_unreachable("Cannot use blas updaters inside emitDiffUse"); case CallDerivatives: patternNames = "CallPattern"; @@ -1935,6 +2194,7 @@ void emitDiffUse(const RecordKeeper &recordKeeper, raw_ostream &os, case UpdateBlasDecl: case UpdateBlasTA: case GenBlasDiffUse: + case GenHeaderVariables: llvm_unreachable("Cannot use blas updaters inside emitDerivatives"); case CallDerivatives: { os << " if (("; @@ -1964,14 +2224,20 @@ void emitDiffUse(const RecordKeeper &recordKeeper, raw_ostream &os, StringRef name = cast(lst->getValues()[0])->getValue(); if (lst->size() >= 2) { auto min = cast(lst->getValues()[1])->getValue(); - int min_int; - min.getAsInteger(10, min_int); + int min_int = 0; + if (min.size() != 0 && min.getAsInteger(10, min_int)) { + PrintFatalError(pattern->getLoc(), + "Could not parse min llvm version as int"); + } if (min.size() != 0 && LLVM_VERSION_MAJOR < min_int) continue; if (lst->size() >= 3) { auto max = cast(lst->getValues()[2])->getValue(); - int max_int; - max.getAsInteger(10, max_int); + int max_int = 0; + if (max.size() != 0 && max.getAsInteger(10, max_int)) { + PrintFatalError(pattern->getLoc(), + "Could not parse max llvm version as int"); + } if (max.size() != 0 && LLVM_VERSION_MAJOR > max_int) continue; } @@ -2085,6 +2351,9 @@ static bool EnzymeTableGenMain(raw_ostream &os, RecordKeeper &records) { case UpdateBlasTA: emitBlasTAUpdater(records, os); return false; + case GenHeaderVariables: + emitHeaderIncludes(records, os); + return false; default: errs() << "unknown tablegen action!\n"; diff --git a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.h b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.h index 368644ba0b5d..742a96d023ae 100644 --- a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.h +++ b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.h @@ -24,6 +24,7 @@ enum ActionType { UpdateBlasDecl, UpdateBlasTA, GenBlasDiffUse, + GenHeaderVariables, }; void emitDiffUse(const llvm::RecordKeeper &recordKeeper, llvm::raw_ostream &os,