From 357ac253cff34f2e27fba7a9c19c1d64437858e0 Mon Sep 17 00:00:00 2001 From: "Eric G. Kratz" Date: Thu, 3 Oct 2024 10:39:48 -0400 Subject: [PATCH] Migrate code and build library (#4) * build commit * fixing build * change to has_idaklu * final idaklu working build * added uv backend to noxfile * refactored .gitignore * remove scikit * Remove excess files, change install dir * Fix build * Some cleanup * Get nox mostly working * setup tests * nox config * Move files * Attempt to fix install * Remove junk test * Rename test file * Fix workflow * Fix workflow * Try again * Replace pybind11 with submodule * Set pybind to a release * Add submodules to workflows * Move source files * Copy newer cpp code * Adding missed source files * Fix casadi path * Temporary fix * Adding wheel workflow * Simplify workflow for testing * Change title * Fix test command * Replace install path * Try another path * Restrict max python version to 3.12 * Attempt to add windows wheel build * Fix typo * Minor updates, probably still broken * Fix test error * Rename varible * Adding publishing step * Remove outdated workflows * Sync C++ code * Fix include path * Fix path * Hopefully fix windows * Updates for local IDEs --------- Co-authored-by: Santhosh Sundaram Co-authored-by: Santhosh <52504160+santacodes@users.noreply.github.com> --- .github/workflows/build_wheels.yml | 285 +++++ .github/workflows/release.yml | 39 - .github/workflows/unit_tests.yml | 33 +- .gitignore | 13 + .gitmodules | 3 + CMakeLists.txt | 199 ++++ FindSUNDIALS.cmake | 99 ++ FindSuiteSparse.cmake | 241 +++++ README.md | 2 +- install_KLU_Sundials.py | 338 ++++++ install_sundials.sh | 84 ++ noxfile.py | 134 +++ pybind11 | 1 + pyproject.toml | 27 +- setup.py | 292 ++++++ src/pybammsolvers/__init__.py | 6 + src/pybammsolvers/idaklu.cpp | 199 ++++ .../Expressions/Base/Expression.hpp | 69 ++ .../Expressions/Base/ExpressionSet.hpp | 86 ++ .../Expressions/Base/ExpressionTypes.hpp | 6 + .../Expressions/Casadi/CasadiFunctions.cpp | 76 ++ .../Expressions/Casadi/CasadiFunctions.hpp | 157 +++ .../idaklu_source/Expressions/Expressions.hpp | 6 + .../Expressions/IREE/IREEBaseFunction.hpp | 27 + .../Expressions/IREE/IREEFunction.hpp | 59 ++ .../Expressions/IREE/IREEFunctions.cpp | 230 +++++ .../Expressions/IREE/IREEFunctions.hpp | 146 +++ .../Expressions/IREE/ModuleParser.cpp | 91 ++ .../Expressions/IREE/ModuleParser.hpp | 55 + .../Expressions/IREE/iree_jit.cpp | 408 ++++++++ .../Expressions/IREE/iree_jit.hpp | 140 +++ .../idaklu_source/IDAKLUSolver.cpp | 1 + .../idaklu_source/IDAKLUSolver.hpp | 48 + .../idaklu_source/IDAKLUSolverGroup.cpp | 145 +++ .../idaklu_source/IDAKLUSolverGroup.hpp | 48 + .../idaklu_source/IDAKLUSolverOpenMP.hpp | 299 ++++++ .../idaklu_source/IDAKLUSolverOpenMP.inl | 977 ++++++++++++++++++ .../IDAKLUSolverOpenMP_solvers.cpp | 1 + .../IDAKLUSolverOpenMP_solvers.hpp | 131 +++ src/pybammsolvers/idaklu_source/IdakluJax.cpp | 264 +++++ src/pybammsolvers/idaklu_source/IdakluJax.hpp | 117 +++ src/pybammsolvers/idaklu_source/Options.cpp | 165 +++ src/pybammsolvers/idaklu_source/Options.hpp | 57 + src/pybammsolvers/idaklu_source/Solution.cpp | 1 + src/pybammsolvers/idaklu_source/Solution.hpp | 39 + .../idaklu_source/SolutionData.cpp | 99 ++ .../idaklu_source/SolutionData.hpp | 82 ++ src/pybammsolvers/idaklu_source/common.cpp | 19 + src/pybammsolvers/idaklu_source/common.hpp | 170 +++ .../idaklu_source/idaklu_solver.hpp | 246 +++++ src/pybammsolvers/idaklu_source/observe.cpp | 343 ++++++ src/pybammsolvers/idaklu_source/observe.hpp | 47 + .../idaklu_source/sundials_functions.hpp | 36 + .../idaklu_source/sundials_functions.inl | 352 +++++++ .../idaklu_source/sundials_legacy_wrapper.hpp | 94 ++ src/{pyidaklu => pybammsolvers}/version.py | 0 src/pyidaklu/__init__.py | 0 tests/test_dummy.py | 3 - tests/test_imports.py | 5 + vcpkg-configuration.json | 20 + vcpkg.json | 12 + 61 files changed, 7299 insertions(+), 73 deletions(-) create mode 100644 .github/workflows/build_wheels.yml delete mode 100644 .github/workflows/release.yml create mode 100644 .gitmodules create mode 100644 CMakeLists.txt create mode 100644 FindSUNDIALS.cmake create mode 100644 FindSuiteSparse.cmake create mode 100755 install_KLU_Sundials.py create mode 100644 install_sundials.sh create mode 100644 noxfile.py create mode 160000 pybind11 create mode 100644 setup.py create mode 100644 src/pybammsolvers/__init__.py create mode 100644 src/pybammsolvers/idaklu.cpp create mode 100644 src/pybammsolvers/idaklu_source/Expressions/Base/Expression.hpp create mode 100644 src/pybammsolvers/idaklu_source/Expressions/Base/ExpressionSet.hpp create mode 100644 src/pybammsolvers/idaklu_source/Expressions/Base/ExpressionTypes.hpp create mode 100644 src/pybammsolvers/idaklu_source/Expressions/Casadi/CasadiFunctions.cpp create mode 100644 src/pybammsolvers/idaklu_source/Expressions/Casadi/CasadiFunctions.hpp create mode 100644 src/pybammsolvers/idaklu_source/Expressions/Expressions.hpp create mode 100644 src/pybammsolvers/idaklu_source/Expressions/IREE/IREEBaseFunction.hpp create mode 100644 src/pybammsolvers/idaklu_source/Expressions/IREE/IREEFunction.hpp create mode 100644 src/pybammsolvers/idaklu_source/Expressions/IREE/IREEFunctions.cpp create mode 100644 src/pybammsolvers/idaklu_source/Expressions/IREE/IREEFunctions.hpp create mode 100644 src/pybammsolvers/idaklu_source/Expressions/IREE/ModuleParser.cpp create mode 100644 src/pybammsolvers/idaklu_source/Expressions/IREE/ModuleParser.hpp create mode 100644 src/pybammsolvers/idaklu_source/Expressions/IREE/iree_jit.cpp create mode 100644 src/pybammsolvers/idaklu_source/Expressions/IREE/iree_jit.hpp create mode 100644 src/pybammsolvers/idaklu_source/IDAKLUSolver.cpp create mode 100644 src/pybammsolvers/idaklu_source/IDAKLUSolver.hpp create mode 100644 src/pybammsolvers/idaklu_source/IDAKLUSolverGroup.cpp create mode 100644 src/pybammsolvers/idaklu_source/IDAKLUSolverGroup.hpp create mode 100644 src/pybammsolvers/idaklu_source/IDAKLUSolverOpenMP.hpp create mode 100644 src/pybammsolvers/idaklu_source/IDAKLUSolverOpenMP.inl create mode 100644 src/pybammsolvers/idaklu_source/IDAKLUSolverOpenMP_solvers.cpp create mode 100644 src/pybammsolvers/idaklu_source/IDAKLUSolverOpenMP_solvers.hpp create mode 100644 src/pybammsolvers/idaklu_source/IdakluJax.cpp create mode 100644 src/pybammsolvers/idaklu_source/IdakluJax.hpp create mode 100644 src/pybammsolvers/idaklu_source/Options.cpp create mode 100644 src/pybammsolvers/idaklu_source/Options.hpp create mode 100644 src/pybammsolvers/idaklu_source/Solution.cpp create mode 100644 src/pybammsolvers/idaklu_source/Solution.hpp create mode 100644 src/pybammsolvers/idaklu_source/SolutionData.cpp create mode 100644 src/pybammsolvers/idaklu_source/SolutionData.hpp create mode 100644 src/pybammsolvers/idaklu_source/common.cpp create mode 100644 src/pybammsolvers/idaklu_source/common.hpp create mode 100644 src/pybammsolvers/idaklu_source/idaklu_solver.hpp create mode 100644 src/pybammsolvers/idaklu_source/observe.cpp create mode 100644 src/pybammsolvers/idaklu_source/observe.hpp create mode 100644 src/pybammsolvers/idaklu_source/sundials_functions.hpp create mode 100644 src/pybammsolvers/idaklu_source/sundials_functions.inl create mode 100644 src/pybammsolvers/idaklu_source/sundials_legacy_wrapper.hpp rename src/{pyidaklu => pybammsolvers}/version.py (100%) delete mode 100644 src/pyidaklu/__init__.py delete mode 100644 tests/test_dummy.py create mode 100644 tests/test_imports.py create mode 100644 vcpkg-configuration.json create mode 100644 vcpkg.json diff --git a/.github/workflows/build_wheels.yml b/.github/workflows/build_wheels.yml new file mode 100644 index 0000000..19e6be7 --- /dev/null +++ b/.github/workflows/build_wheels.yml @@ -0,0 +1,285 @@ +name: Build wheels +on: + release: + types: [ published ] + workflow_dispatch: + pull_request: + branches: + - "main" + +env: + CIBW_BUILD_VERBOSITY: 2 + CIBW_BUILD_FRONTEND: "pip; args: --no-build-isolation" + # Skip PyPy and MUSL builds in any and all jobs + CIBW_SKIP: "pp* *musllinux*" + FORCE_COLOR: 3 + +jobs: + build_windows_wheels: + name: Wheels (windows-latest) + runs-on: windows-latest + steps: + - uses: actions/checkout@v4 + with: + submodules: 'true' + - uses: actions/setup-python@v5 + with: + python-version: 3.11 + + - name: Get number of cores on Windows + id: get_num_cores + shell: python + run: | + from os import environ, cpu_count + num_cpus = cpu_count() + output_file = environ['GITHUB_OUTPUT'] + with open(output_file, "a", encoding="utf-8") as output_stream: + output_stream.write(f"count={num_cpus}\n") + + - name: Install vcpkg on Windows + run: | + cd C:\ + rm -r -fo 'C:\vcpkg' + git clone https://github.com/microsoft/vcpkg + cd vcpkg + .\bootstrap-vcpkg.bat + + - name: Build wheels on Windows + run: pipx run cibuildwheel --output-dir wheelhouse + env: + CIBW_ENVIRONMENT: > + PYBAMMSOLVERS_USE_VCPKG=ON + VCPKG_ROOT_DIR=C:\vcpkg + VCPKG_DEFAULT_TRIPLET=x64-windows-static-md + VCPKG_FEATURE_FLAGS=manifests,registries + CMAKE_GENERATOR="Visual Studio 17 2022" + CMAKE_GENERATOR_PLATFORM=x64 + CMAKE_BUILD_PARALLEL_LEVEL=${{ steps.get_num_cores.outputs.count }} + CIBW_ARCHS: AMD64 + CIBW_BEFORE_BUILD: python -m pip install setuptools wheel delvewheel # skip CasADi and CMake + CIBW_REPAIR_WHEEL_COMMAND: delvewheel repair --add-path C:/Windows/System32 -w {dest_dir} {wheel} + CIBW_TEST_EXTRAS: "dev" + CIBW_TEST_COMMAND: | + python -m pytest {project}/tests + - name: Upload Windows wheels + uses: actions/upload-artifact@v4 + with: + name: wheels_windows + path: ./wheelhouse/*.whl + if-no-files-found: error + + build_manylinux_wheels: + name: Wheels (linux-amd64) + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + with: + submodules: 'true' + + - uses: actions/setup-python@v5 + with: + python-version: 3.11 + + - name: Build wheels on Linux + run: pipx run cibuildwheel --output-dir wheelhouse + env: + CIBW_ARCHS_LINUX: x86_64 + CIBW_BEFORE_ALL_LINUX: > + yum -y install openblas-devel lapack-devel && + bash install_sundials.sh + CIBW_BEFORE_BUILD_LINUX: python -m pip install cmake casadi setuptools wheel + CIBW_REPAIR_WHEEL_COMMAND_LINUX: auditwheel repair -w {dest_dir} {wheel} + CIBW_TEST_EXTRAS: "dev" + CIBW_TEST_COMMAND: | + set -e -x + python -m pytest {project}/tests + + - name: Upload wheels for Linux + uses: actions/upload-artifact@v4 + with: + name: wheels_manylinux + path: ./wheelhouse/*.whl + if-no-files-found: error + + build_macos_wheels: + name: Wheels (${{ matrix.os }}) + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [macos-13, macos-14] + fail-fast: false + steps: + - uses: actions/checkout@v4 + with: + submodules: 'true' + - uses: actions/setup-python@v5 + with: + python-version: '3.11' + + - name: Install cibuildwheel + run: python -m pip install cibuildwheel + + - name: Build wheels on macOS + shell: bash + run: | + set -e -x + + # Set LLVM-OpenMP URL + if [[ $(uname -m) == "x86_64" ]]; then + OPENMP_URL="https://anaconda.org/conda-forge/llvm-openmp/11.1.0/download/osx-64/llvm-openmp-11.1.0-hda6cdc1_1.tar.bz2" + elif [[ $(uname -m) == "arm64" ]]; then + OPENMP_URL="https://anaconda.org/conda-forge/llvm-openmp/11.1.0/download/osx-arm64/llvm-openmp-11.1.0-hf3c4609_1.tar.bz2" + fi + + # Download gfortran with proper macOS minimum version (11.0) + if [[ $(uname -m) == "x86_64" ]]; then + GFORTRAN_URL="https://github.com/isuruf/gcc/releases/download/gcc-11.3.0-2/gfortran-darwin-x86_64-native.tar.gz" + KNOWN_SHA256="981367dd0ad4335613e91bbee453d60b6669f5d7e976d18c7bdb7f1966f26ae4 gfortran.tar.gz" + elif [[ $(uname -m) == "arm64" ]]; then + GFORTRAN_URL="https://github.com/isuruf/gcc/releases/download/gcc-11.3.0-2/gfortran-darwin-arm64-native.tar.gz" + KNOWN_SHA256="84364eee32ba843d883fb8124867e2bf61a0cd73b6416d9897ceff7b85a24604 gfortran.tar.gz" + fi + + # Validate gfortran tarball + curl -L $GFORTRAN_URL -o gfortran.tar.gz + if ! echo "$KNOWN_SHA256" != "$(shasum --algorithm 256 gfortran.tar.gz)"; then + echo "Checksum failed" + exit 1 + fi + + mkdir -p gfortran_installed + tar -xv -C gfortran_installed/ -f gfortran.tar.gz + + if [[ $(uname -m) == "x86_64" ]]; then + export FC=$(pwd)/gfortran_installed/gfortran-darwin-x86_64-native/bin/gfortran + export PATH=$(pwd)/gfortran_installed/gfortran-darwin-x86_64-native/bin:$PATH + elif [[ $(uname -m) == "arm64" ]]; then + export FC=$(pwd)/gfortran_installed/gfortran-darwin-arm64-native/bin/gfortran + export PATH=$(pwd)/gfortran_installed/gfortran-darwin-arm64-native/bin:$PATH + fi + + # link libgfortran dylibs and place them in $SOLVER_LIB_PATH + # and then change rpath for each of them + # Note: libgcc_s.1.dylib not available on macOS arm64; skip for now + SOLVER_LIB_PATH=$(pwd)/.idaklu/lib + mkdir -p $SOLVER_LIB_PATH + if [[ $(uname -m) == "x86_64" ]]; then + lib_dir=$(pwd)/gfortran_installed/gfortran-darwin-x86_64-native/lib + for lib in libgfortran.5.dylib libgfortran.dylib libquadmath.0.dylib libquadmath.dylib libgcc_s.1.dylib libgcc_s.1.1.dylib; do + cp $lib_dir/$lib $SOLVER_LIB_PATH/ + install_name_tool -id $SOLVER_LIB_PATH/$lib $SOLVER_LIB_PATH/$lib + codesign --force --sign - $SOLVER_LIB_PATH/$lib + done + elif [[ $(uname -m) == "arm64" ]]; then + lib_dir=$(pwd)/gfortran_installed/gfortran-darwin-arm64-native/lib + for lib in libgfortran.5.dylib libgfortran.dylib libquadmath.0.dylib libquadmath.dylib libgcc_s.1.1.dylib; do + cp $lib_dir/$lib $SOLVER_LIB_PATH/. + install_name_tool -id $SOLVER_LIB_PATH/$lib $SOLVER_LIB_PATH/$lib + codesign --force --sign - $SOLVER_LIB_PATH/$lib + done + fi + + export SDKROOT=${SDKROOT:-$(xcrun --show-sdk-path)} + + # Can't download LLVM-OpenMP directly, use conda/mamba and set environment variables + brew install miniforge + mamba create -n pybammsolvers-dev $OPENMP_URL + if [[ $(uname -m) == "x86_64" ]]; then + PREFIX="/usr/local/Caskroom/miniforge/base/envs/pybammsolvers-dev" + elif [[ $(uname -m) == "arm64" ]]; then + PREFIX="/opt/homebrew/Caskroom/miniforge/base/envs/pybammsolvers-dev" + fi + + # Copy libomp.dylib from PREFIX to $SOLVER_LIB_PATH, needed for wheel repair + cp $PREFIX/lib/libomp.dylib $SOLVER_LIB_PATH/. + install_name_tool -id $SOLVER_LIB_PATH/libomp.dylib $SOLVER_LIB_PATH/libomp.dylib + codesign --force --sign - $SOLVER_LIB_PATH/libomp.dylib + + export CC=/usr/bin/clang + export CXX=/usr/bin/clang++ + export CPPFLAGS="$CPPFLAGS -Xpreprocessor -fopenmp" + export CFLAGS="$CFLAGS -I$PREFIX/include" + export CXXFLAGS="$CXXFLAGS -I$PREFIX/include" + export LDFLAGS="$LDFLAGS -L$PREFIX/lib -lomp" + + # cibuildwheel not recognising its environment variable, so set manually + export CIBUILDWHEEL="1" + + python install_KLU_Sundials.py + python -m cibuildwheel --output-dir wheelhouse + env: + # 10.13 for Intel (macos-12/macos-13), 11.0 for Apple Silicon (macos-14 and macos-latest) + MACOSX_DEPLOYMENT_TARGET: ${{ matrix.os == 'macos-14' && '11.0' || '10.13' }} + CIBW_ARCHS_MACOS: auto + CIBW_BEFORE_BUILD: python -m pip install cmake casadi setuptools wheel delocate + CIBW_REPAIR_WHEEL_COMMAND: | + if [[ $(uname -m) == "x86_64" ]]; then + delocate-listdeps {wheel} && delocate-wheel -v -w {dest_dir} {wheel} + elif [[ $(uname -m) == "arm64" ]]; then + # Use higher macOS target for now since casadi/libc++.1.0.dylib is still not fixed + delocate-listdeps {wheel} && delocate-wheel -v -w {dest_dir} {wheel} --require-target-macos-version 11.1 + for file in {dest_dir}/*.whl; do mv "$file" "${file//macosx_11_1/macosx_11_0}"; done + fi + CIBW_TEST_EXTRAS: "dev" + CIBW_TEST_COMMAND: | + set -e -x + python -m pytest {project}/tests + + - name: Upload wheels for macOS (amd64, arm64) + uses: actions/upload-artifact@v4 + with: + name: wheels_${{ matrix.os }} + path: ./wheelhouse/*.whl + if-no-files-found: error + + build_sdist: + name: Build SDist + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: 3.12 + + - name: Build SDist + run: pipx run build --sdist + + - name: Upload SDist + uses: actions/upload-artifact@v4.4.0 + with: + name: sdist + path: ./dist/*.tar.gz + if-no-files-found: error + + publish_pypi: + if: github.repository == 'pybamm-team/pybammsolvers' + name: Upload package to PyPI + needs: [ + build_manylinux_wheels, + build_macos_wheels, + build_windows_wheels, + build_sdist + ] + runs-on: ubuntu-latest + environment: + name: pypi + url: https://pypi.org/p/pybammsolvers + permissions: + id-token: write + + steps: + - name: Download all artifacts + uses: actions/download-artifact@v4 + with: + path: artifacts + merge-multiple: true + + - name: Sanity check downloaded artifacts + run: ls -lA artifacts/ + + - name: Publish to PyPI + if: github.event_name == 'release' + uses: pypa/gh-action-pypi-publish@release/v1 + with: + packages-dir: artifacts/ diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml deleted file mode 100644 index fa569e8..0000000 --- a/.github/workflows/release.yml +++ /dev/null @@ -1,39 +0,0 @@ -name: Release - -on: - release: - types: [published] - - -concurrency: - group: ${{ github.workflow }}-${{ github.ref }} - cancel-in-progress: true - -jobs: - Release: - runs-on: ubuntu-latest - strategy: - matrix: - python-version: ["3.12"] - - steps: - - uses: actions/checkout@v3 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 - with: - python-version: ${{ matrix.python-version }} - - name: Build wheel - run: | - pip wheel . --no-deps -w deploy - - - name: Upload package - uses: actions/upload-artifact@v2 - with: - name: wheels - path: deploy/ - if-no-files-found: error - - - name: Publish on PyPI - uses: pypa/gh-action-pypi-publish@release/v1 - with: - packages-dir: deploy/ diff --git a/.github/workflows/unit_tests.yml b/.github/workflows/unit_tests.yml index ccc2bbe..958e3ca 100644 --- a/.github/workflows/unit_tests.yml +++ b/.github/workflows/unit_tests.yml @@ -10,40 +10,25 @@ concurrency: cancel-in-progress: true jobs: - pytest-linux: + pytest: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.10", "3.11", "3.12"] + python-version: ["3.11"] steps: - - uses: actions/checkout@v3 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 + - uses: actions/checkout@v4 with: - python-version: ${{ matrix.python-version }} - - name: Install dependencies - run: | - pip install -e ".[dev]" - - name: Test with pytest - run: | - pytest - - pytest-windows: - runs-on: windows-latest - strategy: - matrix: - python-version: ["3.10", "3.11", "3.12"] - - steps: - - uses: actions/checkout@v3 + submodules: 'true' - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v4 with: python-version: ${{ matrix.python-version }} - name: Install dependencies run: | - pip install -e ".[dev]" - - name: Test with pytest + sudo apt-get install gfortran gcc libopenblas-dev + pip install nox + + - name: Build and test run: | - pytest + nox diff --git a/.gitignore b/.gitignore index 138f392..7caf519 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,17 @@ .idea .vscode +cmake-build-* *.egg-info __pycache__ +.venv +.env +.nox + +setup.log +install_KLU_Sundials +.idaklu +build + +# Extra directories for local work +workspace +deploy diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000..8e1aff6 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "pybind11"] + path = pybind11 + url = git@github.com:pybind/pybind11.git diff --git a/CMakeLists.txt b/CMakeLists.txt new file mode 100644 index 0000000..377b371 --- /dev/null +++ b/CMakeLists.txt @@ -0,0 +1,199 @@ +cmake_minimum_required(VERSION 3.13) +cmake_policy(SET CMP0074 NEW) +set(CMAKE_VERBOSE_MAKEFILE ON) + +if(DEFINED Python_EXECUTABLE) + set(PYTHON_EXECUTABLE ${Python_EXECUTABLE}) +endif() + +if(DEFINED ENV{VCPKG_ROOT_DIR} AND NOT DEFINED VCPKG_ROOT_DIR) + set(VCPKG_ROOT_DIR "$ENV{VCPKG_ROOT_DIR}" + CACHE STRING "Vcpkg root directory") +endif() + +if(DEFINED VCPKG_ROOT_DIR) + set(CMAKE_TOOLCHAIN_FILE ${VCPKG_ROOT_DIR}/scripts/buildsystems/vcpkg.cmake + CACHE STRING "Vcpkg toolchain file") +endif() + +if(DEFINED ENV{VCPKG_DEFAULT_TRIPLET} AND NOT DEFINED VCPKG_TARGET_TRIPLET) + set(VCPKG_TARGET_TRIPLET "$ENV{VCPKG_DEFAULT_TRIPLET}" + CACHE STRING "Vcpkg target triplet") +endif() + +project(idaklu) + +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED ON) +set(CMAKE_CXX_EXTENSIONS OFF) +set(CMAKE_EXPORT_COMPILE_COMMANDS 1) +set(CMAKE_POSITION_INDEPENDENT_CODE ON) +if(NOT MSVC) + # MSVC does not support variable length arrays (vla) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Werror=vla") +endif() + +# casadi seems to compile without the newer versions of std::string +add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=0) + +if(NOT PYBIND11_DIR) + set(PYBIND11_DIR pybind11) +endif() +add_subdirectory(${PYBIND11_DIR}) + +# Check Casadi build flag +if(NOT DEFINED PYBAMM_IDAKLU_EXPR_CASADI) + set(PYBAMM_IDAKLU_EXPR_CASADI ON) +endif() +message("PYBAMM_IDAKLU_EXPR_CASADI: ${PYBAMM_IDAKLU_EXPR_CASADI}") + +# Casadi PyBaMM source files +set(IDAKLU_EXPR_CASADI_SOURCE_FILES "") +if(${PYBAMM_IDAKLU_EXPR_CASADI} STREQUAL "ON" ) + add_compile_definitions(CASADI_ENABLE) + set(IDAKLU_EXPR_CASADI_SOURCE_FILES + src/pybammsolvers/idaklu_source/Expressions/Casadi/CasadiFunctions.cpp + src/pybammsolvers/idaklu_source/Expressions/Casadi/CasadiFunctions.hpp + ) +endif() + +# Check IREE build flag +if(NOT DEFINED PYBAMM_IDAKLU_EXPR_IREE) + set(PYBAMM_IDAKLU_EXPR_IREE OFF) +endif() +message("PYBAMM_IDAKLU_EXPR_IREE: ${PYBAMM_IDAKLU_EXPR_IREE}") + +# IREE (MLIR expression evaluation) PyBaMM source files +set(IDAKLU_EXPR_IREE_SOURCE_FILES "") +if(${PYBAMM_IDAKLU_EXPR_IREE} STREQUAL "ON" ) + add_compile_definitions(IREE_ENABLE) + # Source file list + set(IDAKLU_EXPR_IREE_SOURCE_FILES + src/pybammsolvers/idaklu_source/Expressions/IREE/iree_jit.cpp + src/pybammsolvers/idaklu_source/Expressions/IREE/iree_jit.hpp + src/pybammsolvers/idaklu_source/Expressions/IREE/IREEFunctions.cpp + src/pybammsolvers/idaklu_source/Expressions/IREE/IREEFunctions.hpp + src/pybammsolvers/idaklu_source/Expressions/IREE/ModuleParser.cpp + src/pybammsolvers/idaklu_source/Expressions/IREE/ModuleParser.hpp + ) +endif() + +# The complete (all dependencies) sources list should be mirrored in setup.py +pybind11_add_module(idaklu + # pybind11 interface + src/pybammsolvers/idaklu.cpp + # IDAKLU solver (SUNDIALS) + src/pybammsolvers/idaklu_source/idaklu_solver.hpp + src/pybammsolvers/idaklu_source/IDAKLUSolver.cpp + src/pybammsolvers/idaklu_source/IDAKLUSolver.hpp + src/pybammsolvers/idaklu_source/IDAKLUSolverGroup.cpp + src/pybammsolvers/idaklu_source/IDAKLUSolverGroup.hpp + src/pybammsolvers/idaklu_source/IDAKLUSolverOpenMP.inl + src/pybammsolvers/idaklu_source/IDAKLUSolverOpenMP.hpp + src/pybammsolvers/idaklu_source/IDAKLUSolverOpenMP_solvers.cpp + src/pybammsolvers/idaklu_source/IDAKLUSolverOpenMP_solvers.hpp + src/pybammsolvers/idaklu_source/sundials_functions.inl + src/pybammsolvers/idaklu_source/sundials_functions.hpp + src/pybammsolvers/idaklu_source/IdakluJax.cpp + src/pybammsolvers/idaklu_source/IdakluJax.hpp + src/pybammsolvers/idaklu_source/common.hpp + src/pybammsolvers/idaklu_source/common.cpp + src/pybammsolvers/idaklu_source/Solution.cpp + src/pybammsolvers/idaklu_source/Solution.hpp + src/pybammsolvers/idaklu_source/SolutionData.cpp + src/pybammsolvers/idaklu_source/SolutionData.hpp + src/pybammsolvers/idaklu_source/observe.cpp + src/pybammsolvers/idaklu_source/observe.hpp + src/pybammsolvers/idaklu_source/Options.hpp + src/pybammsolvers/idaklu_source/Options.cpp + # IDAKLU expressions / function evaluation [abstract] + src/pybammsolvers/idaklu_source/Expressions/Expressions.hpp + src/pybammsolvers/idaklu_source/Expressions/Base/Expression.hpp + src/pybammsolvers/idaklu_source/Expressions/Base/ExpressionSet.hpp + src/pybammsolvers/idaklu_source/Expressions/Base/ExpressionTypes.hpp + # IDAKLU expressions - concrete implementations + ${IDAKLU_EXPR_CASADI_SOURCE_FILES} + ${IDAKLU_EXPR_IREE_SOURCE_FILES} +) + +if (NOT DEFINED USE_PYTHON_CASADI) + set(USE_PYTHON_CASADI TRUE) +endif() + + +execute_process( + COMMAND "${PYTHON_EXECUTABLE}" -c + "import os; import sysconfig; print(os.path.join(sysconfig.get_path('purelib'), 'casadi', 'cmake'))" + OUTPUT_VARIABLE CASADI_DIR + OUTPUT_STRIP_TRAILING_WHITESPACE) + +if (CASADI_DIR) + file(TO_CMAKE_PATH ${CASADI_DIR} CASADI_DIR) + message("Found Python casadi path: ${CASADI_DIR}") +else() + message(FATAL_ERROR "Did not find casadi path}") +endif() + +if(${USE_PYTHON_CASADI}) + message("Trying to link against Python casadi package in ${CASADI_DIR}") + find_package(casadi CONFIG PATHS ${CASADI_DIR} REQUIRED NO_DEFAULT_PATH) +else() + message("Trying to link against any casadi package apart from the Python one") + set(CMAKE_IGNORE_PATH "${CASADI_DIR}/cmake") + find_package(casadi CONFIG REQUIRED) +endif() + +set_target_properties( + idaklu PROPERTIES + INSTALL_RPATH "${CASADI_DIR}" + INSTALL_RPATH_USE_LINK_PATH TRUE +) + +# openmp +if (${CMAKE_SYSTEM_NAME} MATCHES "Darwin") + execute_process( + COMMAND "brew" "--prefix" + OUTPUT_VARIABLE HOMEBREW_PREFIX + OUTPUT_STRIP_TRAILING_WHITESPACE) + if (OpenMP_ROOT) + set(OpenMP_ROOT "${OpenMP_ROOT}:${HOMEBREW_PREFIX}/opt/libomp") + else() + set(OpenMP_ROOT "${HOMEBREW_PREFIX}/opt/libomp") + endif() +endif() +find_package(OpenMP) +if(OpenMP_CXX_FOUND) + target_link_libraries(idaklu PRIVATE OpenMP::OpenMP_CXX) +endif() + +set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} ${PROJECT_SOURCE_DIR}) +# Sundials +find_package(SUNDIALS REQUIRED) +message("SUNDIALS found in ${SUNDIALS_INCLUDE_DIR}: ${SUNDIALS_LIBRARIES}") +target_include_directories(idaklu PRIVATE ${SUNDIALS_INCLUDE_DIR}) +target_link_libraries(idaklu PRIVATE ${SUNDIALS_LIBRARIES} casadi) + +# link suitesparse +# if using vcpkg, use config mode to +# find suitesparse. Otherwise, use FindSuiteSparse module +if(DEFINED VCPKG_ROOT_DIR) + find_package(SuiteSparse CONFIG REQUIRED) +else() + find_package(SuiteSparse REQUIRED) + message("SuiteSparse found in ${SuiteSparse_INCLUDE_DIRS}: ${SuiteSparse_LIBRARIES}") +endif() +include_directories(${SuiteSparse_INCLUDE_DIRS}) +target_link_libraries(idaklu PRIVATE ${SuiteSparse_LIBRARIES}) + +# IREE (MLIR compiler and runtime library) build settings +if(${PYBAMM_IDAKLU_EXPR_IREE} STREQUAL "ON" ) + set(IREE_BUILD_COMPILER ON) + set(IREE_BUILD_TESTS OFF) + set(IREE_BUILD_SAMPLES OFF) + add_subdirectory(iree EXCLUDE_FROM_ALL) + set(IREE_COMPILER_ROOT "${CMAKE_CURRENT_SOURCE_DIR}/iree/compiler") + target_include_directories(idaklu SYSTEM PRIVATE "${IREE_COMPILER_ROOT}/bindings/c/iree/compiler") + target_compile_options(idaklu PRIVATE ${IREE_DEFAULT_COPTS}) + target_link_libraries(idaklu PRIVATE iree_compiler_bindings_c_loader) + target_link_libraries(idaklu PRIVATE iree_runtime_runtime) +endif() diff --git a/FindSUNDIALS.cmake b/FindSUNDIALS.cmake new file mode 100644 index 0000000..4f8c52b --- /dev/null +++ b/FindSUNDIALS.cmake @@ -0,0 +1,99 @@ +# This module is adapted from that in CADET (``_): + +# .. cmake_module:: + +# Find SUNDIALS, the SUite of Nonlinear and DIfferential/ALgebraic equation Solvers. +# +# The module looks for the following sundials components +# +# * sundials_ida +# * sundials_sunlinsolklu +# * sundials_sunlinsoldense +# * sundials_sunlinsollapackdense +# * sundials_sunmatrix_sparse +# * sundials_nvecserial +# +# To provide the module with a hint about where to find your SUNDIALS installation, +# you can set the environment variable :code:`SUNDIALS_ROOT`. The FindSUNDIALS module will +# then look in this path when searching for SUNDIALS paths and libraries. +# This behavior is defined in CMake >= 3.12, see policy CMP0074. +# It is replicated for older versions by adding the :code:`SUNDIALS_ROOT` variable to the +# :code:`PATHS` entry. +# +# This module will define the following variables: +# :code:`SUNDIALS_INCLUDE_DIRS` - Location of the SUNDIALS includes +# :code:`SUNDIALS_LIBRARIES` - Required libraries for all requested components + +# List of the valid SUNDIALS components + +# find the SUNDIALS include directories +find_path(SUNDIALS_INCLUDE_DIR + NAMES + idas/idas.h + sundials/sundials_math.h + sundials/sundials_types.h + sunlinsol/sunlinsol_klu.h + sunlinsol/sunlinsol_dense.h + sunlinsol/sunlinsol_spbcgs.h + sunlinsol/sunlinsol_lapackdense.h + sunmatrix/sunmatrix_sparse.h + PATH_SUFFIXES + include + PATHS + ${SUNDIALS_ROOT} + ) + +set(SUNDIALS_WANT_COMPONENTS + sundials_idas + sundials_sunlinsolklu + sundials_sunlinsoldense + sundials_sunlinsolspbcgs + sundials_sunlinsollapackdense + sundials_sunmatrixsparse + sundials_nvecserial + sundials_nvecopenmp + ) + +# find the SUNDIALS libraries +foreach(LIB ${SUNDIALS_WANT_COMPONENTS}) + if (UNIX AND SUNDIALS_PREFER_STATIC_LIBRARIES) + # According to bug 1643 on the CMake bug tracker, this is the + # preferred method for searching for a static library. + # See http://www.cmake.org/Bug/view.php?id=1643. We search + # first for the full static library name, but fall back to a + # generic search on the name if the static search fails. + set(THIS_LIBRARY_SEARCH lib${LIB}.a ${LIB}) + else() + set(THIS_LIBRARY_SEARCH ${LIB}) + endif() + + find_library(SUNDIALS_${LIB}_LIBRARY + NAMES ${THIS_LIBRARY_SEARCH} + PATH_SUFFIXES + lib + Lib + PATHS + ${SUNDIALS_ROOT} + ) + + set(SUNDIALS_${LIB}_FOUND FALSE) + if (SUNDIALS_${LIB}_LIBRARY) + list(APPEND SUNDIALS_LIBRARIES ${SUNDIALS_${LIB}_LIBRARY}) + set(SUNDIALS_${LIB}_FOUND TRUE) + endif() + mark_as_advanced(SUNDIALS_${LIB}_LIBRARY) +endforeach() + +mark_as_advanced( + SUNDIALS_LIBRARIES + SUNDIALS_INCLUDE_DIR +) + +# behave like a CMake module is supposed to behave +include(FindPackageHandleStandardArgs) +find_package_handle_standard_args( + "SUNDIALS" + FOUND_VAR SUNDIALS_FOUND + REQUIRED_VARS SUNDIALS_INCLUDE_DIR SUNDIALS_LIBRARIES + HANDLE_COMPONENTS +) diff --git a/FindSuiteSparse.cmake b/FindSuiteSparse.cmake new file mode 100644 index 0000000..c5275c0 --- /dev/null +++ b/FindSuiteSparse.cmake @@ -0,0 +1,241 @@ +# This CMakeFile is adapted from that in dune-common: + +# .. cmake_module:: +# +# Find the SuiteSparse libraries like UMFPACK or SPQR. +# +# Example which tries to find Suite Sparse's UMFPack component: +# +# :code:`find_package(SuiteSparse OPTIONAL_COMPONENTS UMFPACK)` +# +# `OPTIONAL_COMPONENTS` +# A list of components. Components are: +# AMD, BTF, CAMD, CCOLAMD, CHOLMOD, COLAMD, CXSPARSE, +# KLU, LDL, RBIO, SPQR, UMFPACK +# +# :ref:`SuiteSparse_ROOT` +# Path list to search for SuiteSparse +# +# Sets the following variables: +# +# :code:`SuiteSparse_FOUND` +# True if SuiteSparse was found. +# +# :code:`SuiteSparse_INCLUDE_DIRS` +# Path to the SuiteSparse include dirs. +# +# :code:`SuiteSparse_LIBRARIES` +# Name of the SuiteSparse libraries. +# +# :code:`SuiteSparse__FOUND` +# Whether was found as part of SuiteSparse. +# +# .. cmake_variable:: SuiteSparse_ROOT +# +# You may set this variable to have :ref:`FindSuiteSparse` look +# for SuiteSparse in the given path before inspecting +# system paths. +# + +find_package(BLAS QUIET) + +# look for desired componenents +set(SUITESPARSE_COMPONENTS ${SuiteSparse_FIND_COMPONENTS}) + +# resolve inter-component dependencies +list(FIND SUITESPARSE_COMPONENTS "UMFPACK" WILL_USE_UMFPACK) +if(NOT WILL_USE_UMFPACK EQUAL -1) + list(APPEND SUITESPARSE_COMPONENTS AMD CHOLMOD) +endif() +list(FIND SUITESPARSE_COMPONENTS "CHOLMOD" WILL_USE_CHOLMOD) +if(NOT WILL_USE_CHOLMOD EQUAL -1) + list(APPEND SUITESPARSE_COMPONENTS AMD CAMD COLAMD CCOLAMD) +endif() + +if(SUITESPARSE_COMPONENTS) + list(REMOVE_DUPLICATES SUITESPARSE_COMPONENTS) +endif() + +# find SuiteSparse config: +# look for library at positions given by the user +find_library(SUITESPARSE_CONFIG_LIB + NAMES "suitesparseconfig" + PATHS ${SuiteSparse_ROOT} + PATH_SUFFIXES "lib" "lib32" "lib64" "Lib" + NO_DEFAULT_PATH +) +# now also include the default paths +find_library(SUITESPARSE_CONFIG_LIB + NAMES "suitesparseconfig" + PATH_SUFFIXES "lib" "lib32" "lib64" "Lib" +) + +#look for header files at positions given by the user +find_path(SUITESPARSE_INCLUDE_DIR + NAMES "SuiteSparse_config.h" + PATHS ${SuiteSparse_ROOT} + PATH_SUFFIXES "SuiteSparse_config" "SuiteSparse_config/include" "suitesparse" "include" "src" "SuiteSparse_config/Include" + NO_DEFAULT_PATH +) +#now also look for default paths +find_path(SUITESPARSE_INCLUDE_DIR + NAMES "SuiteSparse_config.h" + PATH_SUFFIXES "SuiteSparse_config" "SuiteSparse_config/include" "suitesparse" "include" "src" "SuiteSparse_config/Include" +) + +foreach(_component ${SUITESPARSE_COMPONENTS}) + string(TOLOWER ${_component} _componentLower) + + #look for library at positions given by the user + find_library(${_component}_LIBRARY + NAMES "${_componentLower}" + PATHS ${SuiteSparse_ROOT} + PATH_SUFFIXES "lib" "lib32" "lib64" "${_component}" "${_component}/Lib" + NO_DEFAULT_PATH + ) + #now also include the default paths + find_library(${_component}_LIBRARY + NAMES "${_componentLower}" + PATH_SUFFIXES "lib" "lib32" "lib64" "${_component}" "${_component}/Lib" + ) + + #look for header files at positions given by the user + find_path(${_component}_INCLUDE_DIR + NAMES "${_componentLower}.h" + PATHS ${SuiteSparse_ROOT} + PATH_SUFFIXES "${_componentLower}" "include/${_componentLower}" "suitesparse" "include" "src" "${_component}" "${_component}/Include" + NO_DEFAULT_PATH + ) + #now also look for default paths + find_path(${_component}_INCLUDE_DIR + NAMES "${_componentLower}.h" + PATH_SUFFIXES "${_componentLower}" "include/${_componentLower}" "suitesparse" "include" "${_component}" "${_component}/Include" + ) +endforeach() + +# SPQR has different header file name SuiteSparseQR.hpp +#look for header files at positions given by the user +find_path(SPQR_INCLUDE_DIR + NAMES "SuiteSparseQR.hpp" + PATHS ${SuiteSparse_ROOT} + PATH_SUFFIXES "spqr" "include/spqr" "suitesparse" "include" "src" "SPQR" "SPQR/Include" + NO_DEFAULT_PATH +) +#now also look for default paths +find_path(SPQR_INCLUDE_DIR + NAMES "SuiteSparseQR.hpp" + PATH_SUFFIXES "spqr" "include/spqr" "suitesparse" "include" "SPQR" "SPQR/Include" +) + +# resolve inter-modular dependencies + +# CHOLMOD requires AMD, COLAMD; CAMD and CCOLAMD are optional +if(CHOLMOD_LIBRARY) + if(NOT (AMD_LIBRARY AND COLAMD_LIBRARY)) + message(WARNING "CHOLMOD requires AMD and COLAMD which were not found, skipping the test.") + set(SuiteSparse_CHOLMOD_FOUND "CHOLMOD requires AMD and COLAMD-NOTFOUND") + endif() + + list(APPEND CHOLMOD_LIBRARY ${AMD_LIBRARY} ${COLAMD_LIBRARY}) + if(CAMD_LIBRARY) + list(APPEND CHOLMOD_LIBRARY ${CAMD_LIBRARY}) + endif() + if(CCOLAMD_LIBRARY) + list(APPEND CHOLMOD_LIBRARY ${CCOLAMD_LIBRARY}) + endif() + list(REVERSE CHOLMOD_LIBRARY) + # remove duplicates + list(REMOVE_DUPLICATES CHOLMOD_LIBRARY) + list(REVERSE CHOLMOD_LIBRARY) +endif() + +# UMFPack requires AMD, can depend on CHOLMOD +if(UMFPACK_LIBRARY) + # check wether cholmod was found + if(CHOLMOD_LIBRARY) + list(APPEND UMFPACK_LIBRARY ${CHOLMOD_LIBRARY}) + else() + list(APPEND UMFPACK_LIBRARY ${AMD_LIBRARY}) + endif() + list(REVERSE UMFPACK_LIBRARY) + # remove duplicates + list(REMOVE_DUPLICATES UMFPACK_LIBRARY) + list(REVERSE UMFPACK_LIBRARY) +endif() + +# check wether everything was found +foreach(_component ${SUITESPARSE_COMPONENTS}) + # variable used for component handling + set(SuiteSparse_${_component}_FOUND (${_component}_LIBRARY AND ${_component}_INCLUDE_DIR)) + set(HAVE_SUITESPARSE_${_component} SuiteSparse_${_component}_FOUND) + if(SuiteSparse_${_component}_FOUND) + list(APPEND SUITESPARSE_INCLUDE_DIR "${${_component}_INCLUDE_DIR}") + list(APPEND SUITESPARSE_LIBRARY "${${_component}_LIBRARY}") + endif() + + mark_as_advanced( + HAVE_SUITESPARSE_${_component} + SuiteSparse_${_component}_FOUND + ${_component}_INCLUDE_DIR + ${_component}_LIBRARY) +endforeach() + +list(APPEND SUITESPARSE_LIBRARY ${SUITESPARSE_CONFIG_LIB}) + +# make them unique +if(SUITESPARSE_INCLUDE_DIR) + list(REMOVE_DUPLICATES SUITESPARSE_INCLUDE_DIR) +endif() +if(SUITESPARSE_LIBRARY) + list(REVERSE SUITESPARSE_LIBRARY) + list(REMOVE_DUPLICATES SUITESPARSE_LIBRARY) + list(REVERSE SUITESPARSE_LIBRARY) +endif() + +# behave like a CMake module is supposed to behave +include(FindPackageHandleStandardArgs) +find_package_handle_standard_args( + "SuiteSparse" + FOUND_VAR SuiteSparse_FOUND + REQUIRED_VARS + BLAS_FOUND + SUITESPARSE_INCLUDE_DIR + SUITESPARSE_LIBRARY + HANDLE_COMPONENTS +) + +mark_as_advanced( + SUITESPARSE_INCLUDE_DIR + SUITESPARSE_LIBRARY + SUITESPARSE_CONFIG_LIB + WILL_USE_CHOLMOD + WILL_USE_UMFPACK) + +# if both headers and library are found, store results +if(SuiteSparse_FOUND) + set(SuiteSparse_LIBRARIES ${SUITESPARSE_LIBRARY}) + set(SuiteSparse_INCLUDE_DIRS ${SUITESPARSE_INCLUDE_DIR}) + # log result + file(APPEND ${CMAKE_BINARY_DIR}${CMAKE_FILES_DIRECTORY}/CMakeOutput.log + "Determining location of SuiteSparse succeded:\n" + "Include directory: ${SuiteSparse_INCLUDE_DIRS}\n" + "Library directory: ${SuiteSparse_LIBRARIES}\n\n") + set(SuiteSparse_COMPILER_FLAGS) + foreach(dir ${SuiteSparse_INCLUDE_DIRS}) + set(SuiteSparse_COMPILER_FLAGS "${SuiteSparse_COMPILER_FLAGS} -I${dir}/") + endforeach() + set(SuiteSparse_DUNE_COMPILE_FLAGS ${SuiteSparse_COMPILER_FLAGS} + CACHE STRING "Compile Flags used by DUNE when compiling with SuiteSparse programs") + set(SuiteSparse_DUNE_LIBRARIES ${BLAS_LIBRARIES} ${SuiteSparse_LIBRARIES} + CACHE STRING "Libraries used by DUNE when linking SuiteSparse programs") +else() + # log errornous result + file(APPEND ${CMAKE_BINARY_DIR}${CMAKES_FILES_DIRECTORY}/CMakeError.log + "Determing location of SuiteSparse failed:\n" + "Include directory: ${SuiteSparse_INCLUDE_DIRS}\n" + "Library directory: ${SuiteSparse_LIBRARIES}\n\n") +endif() + +#set HAVE_SUITESPARSE for config.h +set(HAVE_SUITESPARSE ${SuiteSparse_FOUND}) +set(HAVE_UMFPACK ${SuiteSparse_UMFPACK_FOUND}) diff --git a/README.md b/README.md index 0517c77..43ee0c2 100644 --- a/README.md +++ b/README.md @@ -5,5 +5,5 @@ Standalone repository for the IDAKLU solver used in PyBaMM ## Installation ```bash -pip install pyidaklu +pip install pybammsolvers ``` diff --git a/install_KLU_Sundials.py b/install_KLU_Sundials.py new file mode 100755 index 0000000..c69d8fd --- /dev/null +++ b/install_KLU_Sundials.py @@ -0,0 +1,338 @@ +import os +import subprocess +import tarfile +import argparse +import platform +import hashlib +import shutil +import urllib.request +from os.path import join, isfile +from urllib.parse import urlparse +from concurrent.futures import ThreadPoolExecutor +from multiprocessing import cpu_count +import pathlib + + +def build_solvers(): + SUITESPARSE_VERSION = "6.0.3" + SUNDIALS_VERSION = "6.5.0" + SUITESPARSE_URL = f"https://github.com/DrTimothyAldenDavis/SuiteSparse/archive/v{SUITESPARSE_VERSION}.tar.gz" + SUNDIALS_URL = f"https://github.com/LLNL/sundials/releases/download/v{SUNDIALS_VERSION}/sundials-{SUNDIALS_VERSION}.tar.gz" + SUITESPARSE_CHECKSUM = ( + "7111b505c1207f6f4bd0be9740d0b2897e1146b845d73787df07901b4f5c1fb7" + ) + SUNDIALS_CHECKSUM = "4e0b998dff292a2617e179609b539b511eb80836f5faacf800e688a886288502" + DEFAULT_INSTALL_DIR = str(pathlib.Path(__file__).parent.resolve() / ".idaklu") + + def safe_remove_dir(path): + if os.path.exists(path): + shutil.rmtree(path) + + def install_suitesparse(download_dir): + # The SuiteSparse KLU module has 4 dependencies: + # - suitesparseconfig + # - AMD + # - COLAMD + # - BTF + suitesparse_dir = f"SuiteSparse-{SUITESPARSE_VERSION}" + suitesparse_src = os.path.join(download_dir, suitesparse_dir) + print("-" * 10, "Building SuiteSparse_config", "-" * 40) + make_cmd = [ + "make", + "library", + ] + install_cmd = [ + "make", + f"-j{cpu_count()}", + "install", + ] + print("-" * 10, "Building SuiteSparse", "-" * 40) + # Set CMAKE_OPTIONS as environment variables to pass to the GNU Make command + env = os.environ.copy() + for libdir in ["SuiteSparse_config", "AMD", "COLAMD", "BTF", "KLU"]: + build_dir = os.path.join(suitesparse_src, libdir) + # We want to ensure that libsuitesparseconfig.dylib is not repeated in + # multiple paths at the time of wheel repair. Therefore, it should not be + # built with an RPATH since it is copied to the install prefix. + if libdir == "SuiteSparse_config": + # if in CI, set RPATH to the install directory for SuiteSparse_config + # dylibs to find libomp.dylib when repairing the wheel + if os.environ.get("CIBUILDWHEEL") == "1": + env["CMAKE_OPTIONS"] = ( + f"-DCMAKE_INSTALL_PREFIX={install_dir} -DCMAKE_INSTALL_RPATH={install_dir}/lib" + ) + else: + env["CMAKE_OPTIONS"] = f"-DCMAKE_INSTALL_PREFIX={install_dir}" + else: + # For AMD, COLAMD, BTF and KLU; do not set a BUILD RPATH but use an + # INSTALL RPATH in order to ensure that the dynamic libraries are found + # at runtime just once. Otherwise, delocate complains about multiple + # references to the SuiteSparse_config dynamic library (auditwheel does not). + env["CMAKE_OPTIONS"] = ( + f"-DCMAKE_INSTALL_PREFIX={install_dir} -DCMAKE_INSTALL_RPATH={install_dir}/lib -DCMAKE_INSTALL_RPATH_USE_LINK_PATH=FALSE -DCMAKE_BUILD_WITH_INSTALL_RPATH=FALSE" + ) + subprocess.run(make_cmd, cwd=build_dir, env=env, shell=True, check=True) + subprocess.run(install_cmd, cwd=build_dir, check=True) + + def install_sundials(download_dir, install_dir): + # Set install dir for SuiteSparse libs + # Ex: if install_dir -> "/usr/local/" then + # KLU_INCLUDE_DIR -> "/usr/local/include" + # KLU_LIBRARY_DIR -> "/usr/local/lib" + KLU_INCLUDE_DIR = os.path.join(install_dir, "include") + KLU_LIBRARY_DIR = os.path.join(install_dir, "lib") + cmake_args = [ + "-DENABLE_LAPACK=ON", + "-DSUNDIALS_INDEX_SIZE=32", + "-DEXAMPLES_ENABLE_C=OFF", + "-DEXAMPLES_ENABLE_CXX=OFF", + "-DEXAMPLES_INSTALL=OFF", + "-DENABLE_KLU=ON", + "-DENABLE_OPENMP=ON", + f"-DKLU_INCLUDE_DIR={KLU_INCLUDE_DIR}", + f"-DKLU_LIBRARY_DIR={KLU_LIBRARY_DIR}", + "-DCMAKE_INSTALL_PREFIX=" + install_dir, + # on macOS use fixed paths rather than rpath + "-DCMAKE_INSTALL_NAME_DIR=" + KLU_LIBRARY_DIR, + ] + + # try to find OpenMP on mac + if platform.system() == "Darwin": + # flags to find OpenMP on mac + if platform.processor() == "arm": + OpenMP_C_FLAGS = ( + "-Xpreprocessor -fopenmp -I/opt/homebrew/opt/libomp/include" + ) + OpenMP_C_LIB_NAMES = "omp" + OpenMP_omp_LIBRARY = "/opt/homebrew/opt/libomp/lib/libomp.dylib" + elif platform.processor() == "i386": + OpenMP_C_FLAGS = "-Xpreprocessor -fopenmp -I/usr/local/opt/libomp/include" + OpenMP_C_LIB_NAMES = "omp" + OpenMP_omp_LIBRARY = "/usr/local/opt/libomp/lib/libomp.dylib" + else: + raise NotImplementedError( + f"Unsupported processor architecture: {platform.processor()}. " + "Only 'arm' and 'i386' architectures are supported." + ) + + # Don't pass the following args to CMake when building wheels. We set a custom + # OpenMP installation for macOS wheels in the wheel build script. + # This is because we can't use Homebrew's OpenMP dylib due to the wheel + # repair process, where Homebrew binaries are not built for distribution and + # break MACOSX_DEPLOYMENT_TARGET. We use a custom OpenMP binary as described + # in CIBW_BEFORE_ALL in the wheel builder CI job. + # Check for CI environment variable to determine if we are building a wheel + if os.environ.get("CIBUILDWHEEL") != "1": + print("Using Homebrew OpenMP for macOS build") + cmake_args += [ + "-DOpenMP_C_FLAGS=" + OpenMP_C_FLAGS, + "-DOpenMP_C_LIB_NAMES=" + OpenMP_C_LIB_NAMES, + "-DOpenMP_omp_LIBRARY=" + OpenMP_omp_LIBRARY, + ] + + # SUNDIALS are built within download_dir 'build_sundials' in the PyBaMM root + # download_dir + build_dir = os.path.abspath(os.path.join(download_dir, "build_sundials")) + if not os.path.exists(build_dir): + print("\n-" * 10, "Creating build dir", "-" * 40) + os.makedirs(build_dir) + + sundials_src = f"../sundials-{SUNDIALS_VERSION}" + print("-" * 10, "Running CMake prepare", "-" * 40) + subprocess.run(["cmake", sundials_src, *cmake_args], cwd=build_dir, check=True) + + print("-" * 10, "Building SUNDIALS", "-" * 40) + make_cmd = ["make", f"-j{cpu_count()}", "install"] + subprocess.run(make_cmd, cwd=build_dir, check=True) + + def check_libraries_installed(install_dir): + # Define the directories to check for SUNDIALS and SuiteSparse libraries + lib_dirs = [install_dir] + + sundials_files = [ + "libsundials_idas", + "libsundials_sunlinsolklu", + "libsundials_sunlinsoldense", + "libsundials_sunlinsolspbcgs", + "libsundials_sunlinsollapackdense", + "libsundials_sunmatrixsparse", + "libsundials_nvecserial", + "libsundials_nvecopenmp", + ] + if platform.system() == "Linux": + sundials_files = [file + ".so" for file in sundials_files] + elif platform.system() == "Darwin": + sundials_files = [file + ".dylib" for file in sundials_files] + sundials_lib_found = True + # Check for SUNDIALS libraries in each directory + for lib_file in sundials_files: + file_found = False + for lib_dir in lib_dirs: + if isfile(join(lib_dir, "lib", lib_file)): + print(f"{lib_file} found in {lib_dir}.") + file_found = True + break + if not file_found: + print( + f"{lib_file} not found. Proceeding with SUNDIALS library installation." + ) + sundials_lib_found = False + break + + suitesparse_files = [ + "libsuitesparseconfig", + "libklu", + "libamd", + "libcolamd", + "libbtf", + ] + if platform.system() == "Linux": + suitesparse_files = [file + ".so" for file in suitesparse_files] + elif platform.system() == "Darwin": + suitesparse_files = [file + ".dylib" for file in suitesparse_files] + else: + raise NotImplementedError( + f"Unsupported operating system: {platform.system()}. This script currently supports only Linux and macOS." + ) + + suitesparse_lib_found = True + # Check for SuiteSparse libraries in each directory + for lib_file in suitesparse_files: + file_found = False + for lib_dir in lib_dirs: + if isfile(join(lib_dir, "lib", lib_file)): + print(f"{lib_file} found in {lib_dir}.") + file_found = True + break + if not file_found: + print( + f"{lib_file} not found. Proceeding with SuiteSparse library installation." + ) + suitesparse_lib_found = False + break + + return sundials_lib_found, suitesparse_lib_found + + def calculate_sha256(file_path): + sha256_hash = hashlib.sha256() + with open(file_path, "rb") as f: + # Read and update hash in chunks of 4K + for byte_block in iter(lambda: f.read(4096), b""): + sha256_hash.update(byte_block) + return sha256_hash.hexdigest() + + def download_extract_library(url, expected_checksum, download_dir): + file_name = url.split("/")[-1] + file_path = os.path.join(download_dir, file_name) + + # Check if file already exists and validate checksum + if os.path.exists(file_path): + print(f"Validating checksum for {file_name}...") + actual_checksum = calculate_sha256(file_path) + print(f"Found {actual_checksum} against {expected_checksum}") + if actual_checksum == expected_checksum: + print(f"Checksum valid. Skipping download for {file_name}.") + # Extract the archive as the checksum is valid + with tarfile.open(file_path) as tar: + tar.extractall(download_dir) + return + else: + print(f"Checksum invalid. Redownloading {file_name}.") + + # Download and extract archive at url + parsed_url = urlparse(url) + if parsed_url.scheme not in ["http", "https"]: + raise ValueError( + f"Invalid URL scheme: {parsed_url.scheme}. Only HTTP and HTTPS are allowed." + ) + with urllib.request.urlopen(url) as response: + os.makedirs(download_dir, exist_ok=True) + with open(file_path, "wb") as out_file: + out_file.write(response.read()) + with tarfile.open(file_path) as tar: + tar.extractall(download_dir) + + def parallel_download(urls, download_dir): + # Use 2 processes for parallel downloading + with ThreadPoolExecutor(max_workers=len(urls)) as executor: + futures = [ + executor.submit( + download_extract_library, url, expected_checksum, download_dir + ) + for (url, expected_checksum) in urls + ] + for future in futures: + future.result() + + # First check requirements: make and cmake + try: + subprocess.run(["make", "--version"]) + except OSError as error: + raise RuntimeError("Make must be installed.") from error + try: + subprocess.run(["cmake", "--version"]) + except OSError as error: + raise RuntimeError("CMake must be installed.") from error + + # Build in parallel wherever possible + os.environ["CMAKE_BUILD_PARALLEL_LEVEL"] = str(cpu_count()) + + # Create download directory in PyBaMM dir + pybamm_dir = pathlib.Path(__file__).parent.resolve() + download_dir = str(pybamm_dir / "install_KLU_Sundials") + if not os.path.exists(download_dir): + os.makedirs(download_dir) + + # Get installation location + parser = argparse.ArgumentParser( + description="Download, compile and install Sundials and SuiteSparse." + ) + parser.add_argument( + "--force", + action="store_true", + help="Force installation even if libraries are already found. This will overwrite the pre-existing files.", + ) + parser.add_argument("--install-dir", type=str, default=DEFAULT_INSTALL_DIR) + args = parser.parse_args() + install_dir = ( + args.install_dir + if os.path.isabs(args.install_dir) + else os.path.join(pybamm_dir, args.install_dir) + ) + + if args.force: + print( + "The '--force' option is activated: installation will be forced, ignoring any existing libraries." + ) + safe_remove_dir(os.path.join(download_dir, "build_sundials")) + safe_remove_dir(os.path.join(download_dir, f"SuiteSparse-{SUITESPARSE_VERSION}")) + safe_remove_dir(os.path.join(download_dir, f"sundials-{SUNDIALS_VERSION}")) + sundials_found, suitesparse_found = False, False + else: + # Check whether the libraries are installed + sundials_found, suitesparse_found = check_libraries_installed(install_dir) + + # Determine which libraries to download based on whether they were found + if not sundials_found and not suitesparse_found: + # Both SUNDIALS and SuiteSparse are missing, download and install both + parallel_download( + [ + (SUITESPARSE_URL, SUITESPARSE_CHECKSUM), + (SUNDIALS_URL, SUNDIALS_CHECKSUM), + ], + download_dir, + ) + install_suitesparse(download_dir) + install_sundials(download_dir, install_dir) + else: + if not sundials_found: + # Only SUNDIALS is missing, download and install it + parallel_download([(SUNDIALS_URL, SUNDIALS_CHECKSUM)], download_dir) + install_sundials(download_dir, install_dir) + if not suitesparse_found: + # Only SuiteSparse is missing, download and install it + parallel_download([(SUITESPARSE_URL, SUITESPARSE_CHECKSUM)], download_dir) + install_suitesparse(download_dir) + +if __name__ == "__main__": + build_solvers() diff --git a/install_sundials.sh b/install_sundials.sh new file mode 100644 index 0000000..655d47f --- /dev/null +++ b/install_sundials.sh @@ -0,0 +1,84 @@ +#!/bin/bash + +# This script installs both SuiteSparse +# (https://people.engr.tamu.edu/davis/suitesparse.html) and SUNDIALS +# (https://computing.llnl.gov/projects/sundials) from source. For each +# two library: +# - Archive downloaded and source code extracted in current working +# directory. +# - Library is built and installed. +# +# Usage: ./install_sundials.sh suitesparse_version sundials_version + +function prepend_python_bin_dir_to_PATH { + python_bin_dir_cmd="print(os.path.split(sys.executable)[0])" + python_bin_dir=$(python -c "import sys, os;$python_bin_dir_cmd") + export PATH=$python_bin_dir:$PATH +} + +function download { + ROOT_ADDR=$1 + FILENAME=$2 + + wget -q $ROOT_ADDR/$FILENAME +} + +function extract { + tar -xf $1 +} + +SUITESPARSE_VERSION=6.0.3 +SUNDIALS_VERSION=6.5.0 + +SUITESPARSE_ROOT_ADDR=https://github.com/DrTimothyAldenDavis/SuiteSparse/archive +SUNDIALS_ROOT_ADDR=https://github.com/LLNL/sundials/releases/download/v$SUNDIALS_VERSION + +SUITESPARSE_ARCHIVE_NAME=v$SUITESPARSE_VERSION.tar.gz +SUNDIALS_ARCHIVE_NAME=sundials-$SUNDIALS_VERSION.tar.gz + +yum -y update +yum -y install wget +download $SUITESPARSE_ROOT_ADDR $SUITESPARSE_ARCHIVE_NAME +download $SUNDIALS_ROOT_ADDR $SUNDIALS_ARCHIVE_NAME +extract $SUITESPARSE_ARCHIVE_NAME +extract $SUNDIALS_ARCHIVE_NAME + +# Build in parallel wherever possible +export MAKEFLAGS="-j$(nproc)" +export CMAKE_BUILD_PARALLEL_LEVEL=$(nproc) + +### Compile and install SUITESPARSE ### +# SuiteSparse is required to compile SUNDIALS's +# KLU solver. + +SUITESPARSE_DIR=SuiteSparse-$SUITESPARSE_VERSION +for dir in SuiteSparse_config AMD COLAMD BTF KLU +do + make -C $SUITESPARSE_DIR/$dir library + make -C $SUITESPARSE_DIR/$dir install INSTALL=/usr +done + +### Compile and install SUNDIALS ### + +# Building cmake requires cmake >= 3.5 +python -m pip install cmake +prepend_python_bin_dir_to_PATH + +# Building SUNDIALS requires a BLAS library +yum -y install openblas-devel + +mkdir -p build_sundials +cd build_sundials +KLU_INCLUDE_DIR=/usr/local/include +KLU_LIBRARY_DIR=/usr/local/lib +SUNDIALS_DIR=sundials-$SUNDIALS_VERSION +cmake -DENABLE_LAPACK=ON\ + -DSUNDIALS_INDEX_SIZE=32\ + -DEXAMPLES_ENABLE:BOOL=OFF\ + -DENABLE_KLU=ON\ + -DENABLE_OPENMP=ON\ + -DKLU_INCLUDE_DIR=$KLU_INCLUDE_DIR\ + -DKLU_LIBRARY_DIR=$KLU_LIBRARY_DIR\ + -DCMAKE_INSTALL_PREFIX=/usr\ + ../$SUNDIALS_DIR +make install diff --git a/noxfile.py b/noxfile.py new file mode 100644 index 0000000..80c5ad7 --- /dev/null +++ b/noxfile.py @@ -0,0 +1,134 @@ +import nox +import os +import sys +import warnings +import platform +from pathlib import Path + + +nox.options.default_venv_backend = "virtualenv" +nox.options.reuse_existing_virtualenvs = True +if sys.platform != "win32": + nox.options.sessions = ["idaklu-requires", "unit"] +else: + nox.options.sessions = ["unit"] + + +def set_iree_state(): + """ + Check if IREE is enabled and set the environment variable accordingly. + + Returns + ------- + str + "ON" if IREE is enabled, "OFF" otherwise. + + """ + state = "ON" if os.getenv("PYBAMM_IDAKLU_EXPR_IREE", "OFF") == "ON" else "OFF" + if state == "ON": + if sys.platform == "win32": + warnings.warn( + ( + "IREE is not enabled on Windows yet. " + "Setting PYBAMM_IDAKLU_EXPR_IREE=OFF." + ), + stacklevel=2, + ) + return "OFF" + if sys.platform == "darwin": + # iree-compiler is currently only available as a wheel on macOS 13 (or + # higher) and Python version 3.11 + mac_ver = int(platform.mac_ver()[0].split(".")[0]) + if (not sys.version_info[:2] == (3, 11)) or mac_ver < 13: + warnings.warn( + ( + "IREE is only supported on MacOS 13 (or higher) and Python" + "version 3.11. Setting PYBAMM_IDAKLU_EXPR_IREE=OFF." + ), + stacklevel=2, + ) + return "OFF" + return state + + +homedir = Path(__file__) +PYBAMM_ENV = { + "LD_LIBRARY_PATH": f"{homedir}/.idaklu/lib", + "PYTHONIOENCODING": "utf-8", + "MPLBACKEND": "Agg", + # Expression evaluators (...EXPR_CASADI cannot be fully disabled at this time) + "PYBAMM_IDAKLU_EXPR_CASADI": os.getenv("PYBAMM_IDAKLU_EXPR_CASADI", "ON"), + "PYBAMM_IDAKLU_EXPR_IREE": set_iree_state(), + "IREE_INDEX_URL": os.getenv( + "IREE_INDEX_URL", "https://iree.dev/pip-release-links.html" + ), +} +VENV_DIR = Path("./venv").resolve() + + +def set_environment_variables(env_dict, session): + """ + Sets environment variables for a nox Session object. + + Parameters + ----------- + session : nox.Session + The session to set the environment variables for. + env_dict : dict + A dictionary of environment variable names and values. + + """ + for key, value in env_dict.items(): + session.env[key] = value + + +@nox.session(name="idaklu-requires") +def run_pybamm_requires(session): + """Download, compile, and install the build-time requirements for Linux and macOS. Supports --install-dir for custom installation paths and --force to force installation.""" + set_environment_variables(PYBAMM_ENV, session=session) + if sys.platform != "win32": + session.run("python", "install_KLU_Sundials.py", *session.posargs) + if PYBAMM_ENV.get("PYBAMM_IDAKLU_EXPR_IREE") == "ON" and not os.path.exists( + "./iree" + ): + session.run( + "git", + "clone", + "--depth=1", + "--recurse-submodules", + "--shallow-submodules", + "--branch=candidate-20240507.886", + "https://github.com/openxla/iree", + "iree/", + external=True, + ) + with session.chdir("iree"): + session.run( + "git", + "submodule", + "update", + "--init", + "--recursive", + external=True, + ) + else: + session.error("nox -s idaklu-requires is only available on Linux & macOS.") + +@nox.session(name="unit") +def run_unit(session): + """Run the unit tests.""" + set_environment_variables(PYBAMM_ENV, session=session) + session.install("setuptools", silent=False) + session.install("casadi", silent=False) + session.install("-e", ".[dev]", silent=False) + if PYBAMM_ENV.get("PYBAMM_IDAKLU_EXPR_IREE") == "ON": + # See comments in 'dev' session + session.install( + "-e", + ".[iree]", + "--find-links", + PYBAMM_ENV.get("IREE_INDEX_URL"), + silent=False, + ) + session.run("pytest", "tests") + diff --git a/pybind11 b/pybind11 new file mode 160000 index 0000000..a2e59f0 --- /dev/null +++ b/pybind11 @@ -0,0 +1 @@ +Subproject commit a2e59f0e7065404b44dfe92a28aca47ba1378dc4 diff --git a/pyproject.toml b/pyproject.toml index cb792e5..b691ba3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,20 +1,35 @@ +[build-system] +requires = [ + "setuptools>=64", + "wheel", + # On Windows, use the CasADi vcpkg registry and CMake bundled from MSVC + "casadi>=3.6.5; platform_system!='Windows'", + # Note: the version of CasADi as a build-time dependency should be matched + # cross platforms, so updates to its minimum version here should be accompanied + # by a version bump in https://github.com/pybamm-team/casadi-vcpkg-registry. + "cmake; platform_system!='Windows'", +] +build-backend = "setuptools.build_meta" + [project] -name = "pyidaklu" +name = "pybammsolvers" description = "Python interface for the IDAKLU solver" -requires-python = ">=3.9" +requires-python = ">=3.9,<3.13" license = {file = "LICENSE"} dynamic = ["version", "readme"] -dependencies = [] +dependencies = [ + "casadi", +] [project.optional-dependencies] dev = [ "pytest", ] -[tools.setuptools.packages.find] +[tool.setuptools.packages.find] where = ["src"] -include = ["pyidaklu"] +include = ["pybammsolvers"] [tool.setuptools.dynamic] -version = {attr = "pyidaklu.version.__version__"} +version = {attr = "pybammsolvers.version.__version__"} readme = {file = ["README.md"], content-type = "text/markdown"} diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..bb9e68c --- /dev/null +++ b/setup.py @@ -0,0 +1,292 @@ +import os +import sys +import subprocess +from multiprocessing import cpu_count +from pathlib import Path +from platform import system +import wheel.bdist_wheel as orig + +from setuptools import setup, Extension +from setuptools.command.install import install +from setuptools.command.build_ext import build_ext + + +default_lib_dir = ( + "" if system() == "Windows" else str(Path(__file__).parent.resolve() / ".idaklu") +) + +# ---------- set environment variables for vcpkg on Windows ---------------------------- + + +def set_vcpkg_environment_variables(): + if not os.getenv("VCPKG_ROOT_DIR"): + raise OSError("Environment variable 'VCPKG_ROOT_DIR' is undefined.") + if not os.getenv("VCPKG_DEFAULT_TRIPLET"): + raise OSError("Environment variable 'VCPKG_DEFAULT_TRIPLET' is undefined.") + if not os.getenv("VCPKG_FEATURE_FLAGS"): + raise OSError("Environment variable 'VCPKG_FEATURE_FLAGS' is undefined.") + return ( + os.getenv("VCPKG_ROOT_DIR"), + os.getenv("VCPKG_DEFAULT_TRIPLET"), + os.getenv("VCPKG_FEATURE_FLAGS"), + ) + + +# ---------- CMakeBuild class (custom build_ext for IDAKLU target) --------------------- + + +class CMakeBuild(build_ext): + user_options = [ + *build_ext.user_options, + ("suitesparse-root=", None, "suitesparse source location"), + ("sundials-root=", None, "sundials source location"), + ] + + def initialize_options(self): + build_ext.initialize_options(self) + self.suitesparse_root = None + self.sundials_root = None + + def finalize_options(self): + build_ext.finalize_options(self) + # Determine the calling command to get the + # undefined options from. + # If build_ext was called directly then this + # doesn't matter. + try: + self.get_finalized_command("install", create=0) + calling_cmd = "install" + except AttributeError: + calling_cmd = "bdist_wheel" + self.set_undefined_options( + calling_cmd, + ("suitesparse_root", "suitesparse_root"), + ("sundials_root", "sundials_root"), + ) + if not self.suitesparse_root: + self.suitesparse_root = os.path.join(default_lib_dir) + if not self.sundials_root: + self.sundials_root = os.path.join(default_lib_dir) + + def get_build_directory(self): + # setuptools outputs object files in directory self.build_temp + # (typically build/temp.*). This is our CMake build directory. + # On Windows, setuptools is too smart and appends "Release" or + # "Debug" to self.build_temp. So in this case we want the + # build directory to be the parent directory. + if system() == "Windows": + return Path(self.build_temp).parents[0] + return self.build_temp + + def run(self): + if not self.extensions: + return + + # Build in parallel wherever possible + os.environ["CMAKE_BUILD_PARALLEL_LEVEL"] = str(cpu_count()) + + if system() == "Windows": + use_python_casadi = False + else: + use_python_casadi = True + + build_type = os.getenv("PYBAMM_CPP_BUILD_TYPE", "RELEASE") + idaklu_expr_casadi = os.getenv("PYBAMM_IDAKLU_EXPR_CASADI", "ON") + idaklu_expr_iree = os.getenv("PYBAMM_IDAKLU_EXPR_IREE", "OFF") + cmake_args = [ + f"-DCMAKE_BUILD_TYPE={build_type}", + f"-DPYTHON_EXECUTABLE={sys.executable}", + "-DUSE_PYTHON_CASADI={}".format("TRUE" if use_python_casadi else "FALSE"), + f"-DPYBAMM_IDAKLU_EXPR_CASADI={idaklu_expr_casadi}", + f"-DPYBAMM_IDAKLU_EXPR_IREE={idaklu_expr_iree}", + ] + if self.suitesparse_root: + cmake_args.append( + f"-DSuiteSparse_ROOT={os.path.abspath(self.suitesparse_root)}" + ) + if self.sundials_root: + cmake_args.append(f"-DSUNDIALS_ROOT={os.path.abspath(self.sundials_root)}") + + build_dir = self.get_build_directory() + if not os.path.exists(build_dir): + os.makedirs(build_dir) + + # The CMakeError.log file is generated by cmake is the configure step + # encounters error. In the following the existence of this file is used + # to determine whether the cmake configure step went smoothly. + # So must make sure this file does not remain from a previous failed build. + if os.path.isfile(os.path.join(build_dir, "CMakeError.log")): + os.remove(os.path.join(build_dir, "CMakeError.log")) + + # ---------- configuration for vcpkg on Windows ---------------------------------------- + + build_env = os.environ + if os.getenv("PYBAMMSOLVERS_USE_VCPKG"): + ( + vcpkg_root_dir, + vcpkg_default_triplet, + vcpkg_feature_flags, + ) = set_vcpkg_environment_variables() + build_env["vcpkg_root_dir"] = vcpkg_root_dir + build_env["vcpkg_default_triplet"] = vcpkg_default_triplet + build_env["vcpkg_feature_flags"] = vcpkg_feature_flags + + # ---------- Run CMake and build IDAKLU module ----------------------------------------- + + cmake_list_dir = os.path.abspath(os.path.dirname(__file__)) + print("-" * 10, "Running CMake for IDAKLU solver", "-" * 40) + subprocess.run( + ["cmake", cmake_list_dir, *cmake_args], + cwd=build_dir, + env=build_env, + check=True, + ) + + if os.path.isfile(os.path.join(build_dir, "CMakeError.log")): + msg = ( + "cmake configuration steps encountered errors, and the IDAKLU module" + " could not be built. Make sure dependencies are correctly " + "installed. See " + "https://docs.pybamm.org/en/latest/source/user_guide/installation/install-from-source.html" + ) + raise RuntimeError(msg) + else: + print("-" * 10, "Building IDAKLU module", "-" * 40) + subprocess.run( + ["cmake", "--build", ".", "--config", "Release"], + cwd=build_dir, + env=build_env, + check=True, + ) + + # Move from build temp to final position + for ext in self.extensions: + self.move_output(ext) + + def move_output(self, ext): + # Copy built module to dist/ directory + build_temp = Path(self.build_temp).resolve() + # Get destination location + # self.get_ext_fullpath(ext.name) --> + # build/lib.linux-x86_64-3.5/idaklu.cpython-37m-x86_64-linux-gnu.so + dest_path = Path(self.get_ext_fullpath(ext.name)).resolve() + source_path = build_temp / os.path.basename(self.get_ext_filename(ext.name)) + dest_directory = dest_path.parents[0] + dest_directory.mkdir(parents=True, exist_ok=True) + self.copy_file(source_path, dest_path) + + +# ---------- end of CMake steps -------------------------------------------------------- + +class CustomInstall(install): + """A custom install command to add 2 build options""" + + user_options = [ + *install.user_options, + ("suitesparse-root=", None, "suitesparse source location"), + ("sundials-root=", None, "sundials source location"), + ] + + def initialize_options(self): + install.initialize_options(self) + self.suitesparse_root = None + self.sundials_root = None + + def finalize_options(self): + install.finalize_options(self) + if not self.suitesparse_root: + self.suitesparse_root = default_lib_dir + if not self.sundials_root: + self.sundials_root = default_lib_dir + + def run(self): + install.run(self) + + +# ---------- Custom class for building wheels ------------------------------------------ + + +class bdist_wheel(orig.bdist_wheel): + """A custom install command to add 2 build options""" + + user_options = [ + *orig.bdist_wheel.user_options, + ("suitesparse-root=", None, "suitesparse source location"), + ("sundials-root=", None, "sundials source location"), + ] + + def initialize_options(self): + orig.bdist_wheel.initialize_options(self) + self.suitesparse_root = None + self.sundials_root = None + + def finalize_options(self): + orig.bdist_wheel.finalize_options(self) + if not self.suitesparse_root: + self.suitesparse_root = default_lib_dir + if not self.sundials_root: + self.sundials_root = default_lib_dir + + def run(self): + orig.bdist_wheel.run(self) + + +idaklu_ext = Extension( + name="pybammsolvers.idaklu", + # The sources list should mirror the list in CMakeLists.txt + sources=[ + "src/pybammsolvers/idaklu_source/Expressions/Expressions.hpp", + "src/pybammsolvers/idaklu_source/Expressions/Base/Expression.hpp", + "src/pybammsolvers/idaklu_source/Expressions/Base/ExpressionSet.hpp", + "src/pybammsolvers/idaklu_source/Expressions/Base/ExpressionTypes.hpp", + "src/pybammsolvers/idaklu_source/Expressions/Base/ExpressionSparsity.hpp", + "src/pybammsolvers/idaklu_source/Expressions/Casadi/CasadiFunctions.cpp", + "src/pybammsolvers/idaklu_source/Expressions/Casadi/CasadiFunctions.hpp", + "src/pybammsolvers/idaklu_source/Expressions/IREE/IREEBaseFunction.hpp", + "src/pybammsolvers/idaklu_source/Expressions/IREE/IREEFunction.hpp", + "src/pybammsolvers/idaklu_source/Expressions/IREE/IREEFunctions.cpp", + "src/pybammsolvers/idaklu_source/Expressions/IREE/IREEFunctions.hpp", + "src/pybammsolvers/idaklu_source/Expressions/IREE/iree_jit.cpp", + "src/pybammsolvers/idaklu_source/Expressions/IREE/iree_jit.hpp", + "src/pybammsolvers/idaklu_source/Expressions/IREE/ModuleParser.cpp", + "src/pybammsolvers/idaklu_source/Expressions/IREE/ModuleParser.hpp", + "src/pybammsolvers/idaklu_source/idaklu_solver.hpp", + "src/pybammsolvers/idaklu_source/IDAKLUSolver.cpp", + "src/pybammsolvers/idaklu_source/IDAKLUSolver.hpp", + "src/pybammsolvers/idaklu_source/IDAKLUSolverGroup.cpp", + "src/pybammsolvers/idaklu_source/IDAKLUSolverGroup.hpp", + "src/pybammsolvers/idaklu_source/IDAKLUSolverOpenMP.inl", + "src/pybammsolvers/idaklu_source/IDAKLUSolverOpenMP.hpp", + "src/pybammsolvers/idaklu_source/IDAKLUSolverOpenMP_solvers.cpp", + "src/pybammsolvers/idaklu_source/IDAKLUSolverOpenMP_solvers.hpp", + "src/pybammsolvers/idaklu_source/sundials_functions.inl", + "src/pybammsolvers/idaklu_source/sundials_functions.hpp", + "src/pybammsolvers/idaklu_source/IdakluJax.cpp", + "src/pybammsolvers/idaklu_source/IdakluJax.hpp", + "src/pybammsolvers/idaklu_source/common.hpp", + "src/pybammsolvers/idaklu_source/common.cpp", + "src/pybammsolvers/idaklu_source/Solution.cpp", + "src/pybammsolvers/idaklu_source/Solution.hpp", + "src/pybammsolvers/idaklu_source/SolutionData.cpp", + "src/pybammsolvers/idaklu_source/SolutionData.hpp", + "src/pybammsolvers/idaklu_source/observe.cpp", + "src/pybammsolvers/idaklu_source/observe.hpp", + "src/pybammsolvers/idaklu_source/Options.hpp", + "src/pybammsolvers/idaklu_source/Options.cpp", + "src/pybammsolvers/idaklu.cpp", + ], +) +ext_modules = [idaklu_ext] + +# Project metadata was moved to pyproject.toml (which is read by pip). However, custom +# build commands and setuptools extension modules are still defined here. +setup( + # silence "Package would be ignored" warnings + include_package_data=True, + ext_modules=ext_modules, + cmdclass={ + "build_ext": CMakeBuild, + "bdist_wheel": bdist_wheel, + "install": CustomInstall, + }, +) diff --git a/src/pybammsolvers/__init__.py b/src/pybammsolvers/__init__.py new file mode 100644 index 0000000..2b93acb --- /dev/null +++ b/src/pybammsolvers/__init__.py @@ -0,0 +1,6 @@ +import importlib.util as il + + +idaklu_spec = il.find_spec("pybammsolvers.idaklu") +idaklu = il.module_from_spec(idaklu_spec) +idaklu_spec.loader.exec_module(idaklu) diff --git a/src/pybammsolvers/idaklu.cpp b/src/pybammsolvers/idaklu.cpp new file mode 100644 index 0000000..b4ebc13 --- /dev/null +++ b/src/pybammsolvers/idaklu.cpp @@ -0,0 +1,199 @@ +#include +#include +#include + +#include +#include +#include +#include +#include + +#include "idaklu_source/idaklu_solver.hpp" +#include "idaklu_source/observe.hpp" +#include "idaklu_source/IDAKLUSolverGroup.hpp" +#include "idaklu_source/IdakluJax.hpp" +#include "idaklu_source/common.hpp" +#include "idaklu_source/Expressions/Casadi/CasadiFunctions.hpp" + +#ifdef IREE_ENABLE +#include "idaklu_source/Expressions/IREE/IREEFunctions.hpp" +#endif + + +casadi::Function generate_casadi_function(const std::string &data) +{ + return casadi::Function::deserialize(data); +} + +namespace py = pybind11; + +PYBIND11_MAKE_OPAQUE(std::vector); +PYBIND11_MAKE_OPAQUE(std::vector); +PYBIND11_MAKE_OPAQUE(std::vector); + +PYBIND11_MODULE(idaklu, m) +{ + m.doc() = "sundials solvers"; // optional module docstring + + py::bind_vector>(m, "VectorNdArray"); + py::bind_vector>(m, "VectorRealtypeNdArray"); + py::bind_vector>(m, "VectorSolution"); + + py::class_(m, "IDAKLUSolverGroup") + .def("solve", &IDAKLUSolverGroup::solve, + "perform a solve", + py::arg("t_eval"), + py::arg("t_interp"), + py::arg("y0"), + py::arg("yp0"), + py::arg("inputs"), + py::return_value_policy::take_ownership); + + m.def("create_casadi_solver_group", &create_idaklu_solver_group, + "Create a group of casadi idaklu solver objects", + py::arg("number_of_states"), + py::arg("number_of_parameters"), + py::arg("rhs_alg"), + py::arg("jac_times_cjmass"), + py::arg("jac_times_cjmass_colptrs"), + py::arg("jac_times_cjmass_rowvals"), + py::arg("jac_times_cjmass_nnz"), + py::arg("jac_bandwidth_lower"), + py::arg("jac_bandwidth_upper"), + py::arg("jac_action"), + py::arg("mass_action"), + py::arg("sens"), + py::arg("events"), + py::arg("number_of_events"), + py::arg("rhs_alg_id"), + py::arg("atol"), + py::arg("rtol"), + py::arg("inputs"), + py::arg("var_fcns"), + py::arg("dvar_dy_fcns"), + py::arg("dvar_dp_fcns"), + py::arg("options"), + py::return_value_policy::take_ownership); + + m.def("observe", &observe, + "Observe variables", + py::arg("ts"), + py::arg("ys"), + py::arg("inputs"), + py::arg("funcs"), + py::arg("is_f_contiguous"), + py::arg("shape"), + py::return_value_policy::take_ownership); + + m.def("observe_hermite_interp", &observe_hermite_interp, + "Observe and Hermite interpolate variables", + py::arg("t_interp"), + py::arg("ts"), + py::arg("ys"), + py::arg("yps"), + py::arg("inputs"), + py::arg("funcs"), + py::arg("shape"), + py::return_value_policy::take_ownership); + +#ifdef IREE_ENABLE + m.def("create_iree_solver_group", &create_idaklu_solver_group, + "Create a group of iree idaklu solver objects", + py::arg("number_of_states"), + py::arg("number_of_parameters"), + py::arg("rhs_alg"), + py::arg("jac_times_cjmass"), + py::arg("jac_times_cjmass_colptrs"), + py::arg("jac_times_cjmass_rowvals"), + py::arg("jac_times_cjmass_nnz"), + py::arg("jac_bandwidth_lower"), + py::arg("jac_bandwidth_upper"), + py::arg("jac_action"), + py::arg("mass_action"), + py::arg("sens"), + py::arg("events"), + py::arg("number_of_events"), + py::arg("rhs_alg_id"), + py::arg("atol"), + py::arg("rtol"), + py::arg("inputs"), + py::arg("var_fcns"), + py::arg("dvar_dy_fcns"), + py::arg("dvar_dp_fcns"), + py::arg("options"), + py::return_value_policy::take_ownership); +#endif + + m.def("generate_function", &generate_casadi_function, + "Generate a casadi function", + py::arg("string"), + py::return_value_policy::take_ownership); + + // IdakluJax interface routines + py::class_(m, "IdakluJax") + .def( + "register_callback_eval", + &IdakluJax::register_callback_eval, + "Register a callback for function evaluation", + py::arg("callback") + ) + .def( + "register_callback_jvp", + &IdakluJax::register_callback_jvp, + "Register a callback for JVP evaluation", + py::arg("callback") + ) + .def( + "register_callback_vjp", + &IdakluJax::register_callback_vjp, + "Register a callback for the VJP evaluation", + py::arg("callback") + ) + .def( + "register_callbacks", + &IdakluJax::register_callbacks, + "Register callbacks for function evaluation, JVP evaluation, and VJP evaluation", + py::arg("callback_eval"), + py::arg("callback_jvp"), + py::arg("callback_vjp") + ) + .def( + "get_index", + &IdakluJax::get_index, + "Get the index of the JAXified instance" + ); + m.def( + "create_idaklu_jax", + &create_idaklu_jax, + "Create an idaklu jax object" + ); + m.def( + "registrations", + &Registrations + ); + + py::class_(m, "Function"); + +#ifdef IREE_ENABLE + py::class_(m, "IREEBaseFunctionType") + .def(py::init<>()) + .def_readwrite("mlir", &IREEBaseFunctionType::mlir) + .def_readwrite("kept_var_idx", &IREEBaseFunctionType::kept_var_idx) + .def_readwrite("nnz", &IREEBaseFunctionType::nnz) + .def_readwrite("numel", &IREEBaseFunctionType::numel) + .def_readwrite("col", &IREEBaseFunctionType::col) + .def_readwrite("row", &IREEBaseFunctionType::row) + .def_readwrite("pytree_shape", &IREEBaseFunctionType::pytree_shape) + .def_readwrite("pytree_sizes", &IREEBaseFunctionType::pytree_sizes) + .def_readwrite("n_args", &IREEBaseFunctionType::n_args); +#endif + + py::class_(m, "solution") + .def_readwrite("t", &Solution::t) + .def_readwrite("y", &Solution::y) + .def_readwrite("yp", &Solution::yp) + .def_readwrite("yS", &Solution::yS) + .def_readwrite("ypS", &Solution::ypS) + .def_readwrite("y_term", &Solution::y_term) + .def_readwrite("flag", &Solution::flag); +} diff --git a/src/pybammsolvers/idaklu_source/Expressions/Base/Expression.hpp b/src/pybammsolvers/idaklu_source/Expressions/Base/Expression.hpp new file mode 100644 index 0000000..370d2c7 --- /dev/null +++ b/src/pybammsolvers/idaklu_source/Expressions/Base/Expression.hpp @@ -0,0 +1,69 @@ +#ifndef PYBAMM_EXPRESSION_HPP +#define PYBAMM_EXPRESSION_HPP + +#include "ExpressionTypes.hpp" +#include "../../common.hpp" +#include "../../Options.hpp" +#include +#include + +class Expression { +public: // method declarations + /** + * @brief Constructor + */ + Expression() = default; + + /** + * @brief Evaluation operator (for use after setting input and output data references) + */ + virtual void operator()() = 0; + + /** + * @brief Evaluation operator (supplying data references) + */ + virtual void operator()( + const std::vector& inputs, + const std::vector& results) = 0; + + /** + * @brief The maximum number of elements returned by the k'th output + * + * This is used to allocate memory for the output of the function and usually (but + * not always) corresponds to the number of non-zero elements (NNZ). + */ + virtual expr_int out_shape(int k) = 0; + + /** + * @brief Return the number of non-zero elements for the function output + */ + virtual expr_int nnz() = 0; + + /** + * @brief Return the number of non-zero elements for the function output + */ + virtual expr_int nnz_out() = 0; + + /** + * @brief Returns row indices in COO format (where the output data represents sparse matrix elements) + */ + virtual const std::vector& get_row() = 0; + + /** + * @brief Returns column indices in COO format (where the output data represents sparse matrix elements) + */ + virtual const std::vector& get_col() = 0; + +public: // data members + /** + * @brief Vector of pointers to the input data + */ + std::vector m_arg; // cppcheck-suppress unusedStructMember + + /** + * @brief Vector of pointers to the output data + */ + std::vector m_res; // cppcheck-suppress unusedStructMember +}; + +#endif // PYBAMM_EXPRESSION_HPP diff --git a/src/pybammsolvers/idaklu_source/Expressions/Base/ExpressionSet.hpp b/src/pybammsolvers/idaklu_source/Expressions/Base/ExpressionSet.hpp new file mode 100644 index 0000000..13c746a --- /dev/null +++ b/src/pybammsolvers/idaklu_source/Expressions/Base/ExpressionSet.hpp @@ -0,0 +1,86 @@ +#ifndef PYBAMM_IDAKLU_EXPRESSION_SET_HPP +#define PYBAMM_IDAKLU_EXPRESSION_SET_HPP + +#include "ExpressionTypes.hpp" +#include "Expression.hpp" +#include "../../common.hpp" +#include "../../Options.hpp" +#include + +template +class ExpressionSet +{ +public: + + /** + * @brief Constructor + */ + ExpressionSet( + Expression* rhs_alg, + Expression* jac_times_cjmass, + const int jac_times_cjmass_nnz, + const int jac_bandwidth_lower, + const int jac_bandwidth_upper, + const np_array_int &jac_times_cjmass_rowvals_arg, // cppcheck-suppress unusedStructMember + const np_array_int &jac_times_cjmass_colptrs_arg, // cppcheck-suppress unusedStructMember + const int inputs_length, + Expression* jac_action, + Expression* mass_action, + Expression* sens, + Expression* events, + const int n_s, + const int n_e, + const int n_p, + const SetupOptions& options) + : number_of_states(n_s), + number_of_events(n_e), + number_of_parameters(n_p), + number_of_nnz(jac_times_cjmass_nnz), + jac_bandwidth_lower(jac_bandwidth_lower), + jac_bandwidth_upper(jac_bandwidth_upper), + rhs_alg(rhs_alg), + jac_times_cjmass(jac_times_cjmass), + jac_action(jac_action), + mass_action(mass_action), + sens(sens), + events(events), + tmp_state_vector(number_of_states), + tmp_sparse_jacobian_data(jac_times_cjmass_nnz), + setup_opts(options) + {}; + + int number_of_states; + int number_of_parameters; + int number_of_events; + int number_of_nnz; + int jac_bandwidth_lower; + int jac_bandwidth_upper; + + Expression *rhs_alg = nullptr; + Expression *jac_times_cjmass = nullptr; + Expression *jac_action = nullptr; + Expression *mass_action = nullptr; + Expression *sens = nullptr; + Expression *events = nullptr; + + // `cppcheck-suppress unusedStructMember` is used because codacy reports + // these members as unused, but they are inherited through variadics + std::vector var_fcns; // cppcheck-suppress unusedStructMember + std::vector dvar_dy_fcns; // cppcheck-suppress unusedStructMember + std::vector dvar_dp_fcns; // cppcheck-suppress unusedStructMember + + std::vector jac_times_cjmass_rowvals; // cppcheck-suppress unusedStructMember + std::vector jac_times_cjmass_colptrs; // cppcheck-suppress unusedStructMember + std::vector inputs; // cppcheck-suppress unusedStructMember + + SetupOptions setup_opts; + + virtual realtype *get_tmp_state_vector() = 0; + virtual realtype *get_tmp_sparse_jacobian_data() = 0; + +protected: + std::vector tmp_state_vector; + std::vector tmp_sparse_jacobian_data; +}; + +#endif // PYBAMM_IDAKLU_EXPRESSION_SET_HPP diff --git a/src/pybammsolvers/idaklu_source/Expressions/Base/ExpressionTypes.hpp b/src/pybammsolvers/idaklu_source/Expressions/Base/ExpressionTypes.hpp new file mode 100644 index 0000000..c8d690c --- /dev/null +++ b/src/pybammsolvers/idaklu_source/Expressions/Base/ExpressionTypes.hpp @@ -0,0 +1,6 @@ +#ifndef PYBAMM_EXPRESSION_TYPES_HPP +#define PYBAMM_EXPRESSION_TYPES_HPP + +using expr_int = long long int; + +#endif // PYBAMM_EXPRESSION_TYPES_HPP diff --git a/src/pybammsolvers/idaklu_source/Expressions/Casadi/CasadiFunctions.cpp b/src/pybammsolvers/idaklu_source/Expressions/Casadi/CasadiFunctions.cpp new file mode 100644 index 0000000..9c4ce0f --- /dev/null +++ b/src/pybammsolvers/idaklu_source/Expressions/Casadi/CasadiFunctions.cpp @@ -0,0 +1,76 @@ +#include "CasadiFunctions.hpp" +#include + +CasadiFunction::CasadiFunction(const BaseFunctionType &f) : Expression(), m_func(f) +{ + DEBUG("CasadiFunction constructor: " << m_func.name()); + + size_t sz_arg; + size_t sz_res; + size_t sz_iw; + size_t sz_w; + m_func.sz_work(sz_arg, sz_res, sz_iw, sz_w); + + int nnz = (sz_res>0) ? m_func.nnz_out() : 0; // cppcheck-suppress unreadVariable + DEBUG("name = "<< m_func.name() << " arg = " << sz_arg << " res = " + << sz_res << " iw = " << sz_iw << " w = " << sz_w << " nnz = " << nnz); + + m_arg.resize(sz_arg, nullptr); + m_res.resize(sz_res, nullptr); + m_iw.resize(sz_iw, 0); + m_w.resize(sz_w, 0); + + if (m_func.n_out() > 0) { + casadi::Sparsity casadi_sparsity = m_func.sparsity_out(0); + m_rows = casadi_sparsity.get_row(); + m_cols = casadi_sparsity.get_col(); + } +} + +// only call this once m_arg and m_res have been set appropriately +void CasadiFunction::operator()() +{ + DEBUG("CasadiFunction operator(): " << m_func.name()); + int mem = m_func.checkout(); + m_func(m_arg.data(), m_res.data(), m_iw.data(), m_w.data(), mem); + m_func.release(mem); +} + +expr_int CasadiFunction::out_shape(int k) { + DEBUG("CasadiFunctions out_shape(): " << m_func.name() << " " << m_func.nnz_out()); + return static_cast(m_func.nnz_out()); +} + +expr_int CasadiFunction::nnz() { + DEBUG("CasadiFunction nnz(): " << m_func.name() << " " << static_cast(m_func.nnz_out())); + return static_cast(m_func.nnz_out()); +} + +expr_int CasadiFunction::nnz_out() { + DEBUG("CasadiFunction nnz_out(): " << m_func.name() << " " << static_cast(m_func.nnz_out())); + return static_cast(m_func.nnz_out()); +} + +const std::vector& CasadiFunction::get_row() { + DEBUG("CasadiFunction get_row(): " << m_func.name()); + return m_rows; +} + +const std::vector& CasadiFunction::get_col() { + DEBUG("CasadiFunction get_col(): " << m_func.name()); + return m_cols; +} + +void CasadiFunction::operator()(const std::vector& inputs, + const std::vector& results) +{ + DEBUG("CasadiFunction operator() with inputs and results: " << m_func.name()); + + // Set-up input arguments, provide result vector, then execute function + // Example call: fcn({in1, in2, in3}, {out1}) + for(size_t k=0; k +#include +#include +#include + +/** + * @brief Class for handling individual casadi functions + */ +class CasadiFunction : public Expression +{ +public: + + typedef casadi::Function BaseFunctionType; + + /** + * @brief Constructor + */ + explicit CasadiFunction(const BaseFunctionType &f); + + // Method overrides + void operator()() override; + void operator()(const std::vector& inputs, + const std::vector& results) override; + expr_int out_shape(int k) override; + expr_int nnz() override; + expr_int nnz_out() override; + const std::vector& get_row() override; + const std::vector& get_col() override; + +public: + /* + * @brief Casadi function + */ + BaseFunctionType m_func; + +private: + std::vector m_iw; // cppcheck-suppress unusedStructMember + std::vector m_w; // cppcheck-suppress unusedStructMember + std::vector m_rows; // cppcheck-suppress unusedStructMember + std::vector m_cols; // cppcheck-suppress unusedStructMember +}; + +/** + * @brief Class for handling casadi functions + */ +class CasadiFunctions : public ExpressionSet +{ +public: + + typedef CasadiFunction::BaseFunctionType BaseFunctionType; // expose typedef in class + + /** + * @brief Create a new CasadiFunctions object + */ + CasadiFunctions( + const BaseFunctionType &rhs_alg, + const BaseFunctionType &jac_times_cjmass, + const int jac_times_cjmass_nnz, + const int jac_bandwidth_lower, + const int jac_bandwidth_upper, + const np_array_int &jac_times_cjmass_rowvals_arg, + const np_array_int &jac_times_cjmass_colptrs_arg, + const int inputs_length, + const BaseFunctionType &jac_action, + const BaseFunctionType &mass_action, + const BaseFunctionType &sens, + const BaseFunctionType &events, + const int n_s, + const int n_e, + const int n_p, + const std::vector& var_fcns, + const std::vector& dvar_dy_fcns, + const std::vector& dvar_dp_fcns, + const SetupOptions& setup_opts + ) : + rhs_alg_casadi(rhs_alg), + jac_times_cjmass_casadi(jac_times_cjmass), + jac_action_casadi(jac_action), + mass_action_casadi(mass_action), + sens_casadi(sens), + events_casadi(events), + ExpressionSet( + static_cast(&rhs_alg_casadi), + static_cast(&jac_times_cjmass_casadi), + jac_times_cjmass_nnz, + jac_bandwidth_lower, + jac_bandwidth_upper, + jac_times_cjmass_rowvals_arg, + jac_times_cjmass_colptrs_arg, + inputs_length, + static_cast(&jac_action_casadi), + static_cast(&mass_action_casadi), + static_cast(&sens_casadi), + static_cast(&events_casadi), + n_s, n_e, n_p, + setup_opts) + { + // convert BaseFunctionType list to CasadiFunction list + // NOTE: You must allocate ALL std::vector elements before taking references + for (auto& var : var_fcns) + var_fcns_casadi.push_back(CasadiFunction(*var)); + for (int k = 0; k < var_fcns_casadi.size(); k++) + ExpressionSet::var_fcns.push_back(&this->var_fcns_casadi[k]); + + for (auto& var : dvar_dy_fcns) + dvar_dy_fcns_casadi.push_back(CasadiFunction(*var)); + for (int k = 0; k < dvar_dy_fcns_casadi.size(); k++) + this->dvar_dy_fcns.push_back(&this->dvar_dy_fcns_casadi[k]); + + for (auto& var : dvar_dp_fcns) + dvar_dp_fcns_casadi.push_back(CasadiFunction(*var)); + for (int k = 0; k < dvar_dp_fcns_casadi.size(); k++) + this->dvar_dp_fcns.push_back(&this->dvar_dp_fcns_casadi[k]); + + // copy across numpy array values + const int n_row_vals = jac_times_cjmass_rowvals_arg.request().size; + auto p_jac_times_cjmass_rowvals = jac_times_cjmass_rowvals_arg.unchecked<1>(); + jac_times_cjmass_rowvals.resize(n_row_vals); + for (int i = 0; i < n_row_vals; i++) { + jac_times_cjmass_rowvals[i] = p_jac_times_cjmass_rowvals[i]; + } + + const int n_col_ptrs = jac_times_cjmass_colptrs_arg.request().size; + auto p_jac_times_cjmass_colptrs = jac_times_cjmass_colptrs_arg.unchecked<1>(); + jac_times_cjmass_colptrs.resize(n_col_ptrs); + for (int i = 0; i < n_col_ptrs; i++) { + jac_times_cjmass_colptrs[i] = p_jac_times_cjmass_colptrs[i]; + } + + inputs.resize(inputs_length); + } + + CasadiFunction rhs_alg_casadi; + CasadiFunction jac_times_cjmass_casadi; + CasadiFunction jac_action_casadi; + CasadiFunction mass_action_casadi; + CasadiFunction sens_casadi; + CasadiFunction events_casadi; + + std::vector var_fcns_casadi; + std::vector dvar_dy_fcns_casadi; + std::vector dvar_dp_fcns_casadi; + + realtype* get_tmp_state_vector() override { + return tmp_state_vector.data(); + } + realtype* get_tmp_sparse_jacobian_data() override { + return tmp_sparse_jacobian_data.data(); + } +}; + +#endif // PYBAMM_IDAKLU_CASADI_FUNCTIONS_HPP diff --git a/src/pybammsolvers/idaklu_source/Expressions/Expressions.hpp b/src/pybammsolvers/idaklu_source/Expressions/Expressions.hpp new file mode 100644 index 0000000..70380ea --- /dev/null +++ b/src/pybammsolvers/idaklu_source/Expressions/Expressions.hpp @@ -0,0 +1,6 @@ +#ifndef PYBAMM_IDAKLU_EXPRESSIONS_HPP +#define PYBAMM_IDAKLU_EXPRESSIONS_HPP + +#include "Base/ExpressionSet.hpp" + +#endif // PYBAMM_IDAKLU_EXPRESSIONS_HPP diff --git a/src/pybammsolvers/idaklu_source/Expressions/IREE/IREEBaseFunction.hpp b/src/pybammsolvers/idaklu_source/Expressions/IREE/IREEBaseFunction.hpp new file mode 100644 index 0000000..d2ba7e4 --- /dev/null +++ b/src/pybammsolvers/idaklu_source/Expressions/IREE/IREEBaseFunction.hpp @@ -0,0 +1,27 @@ +#ifndef PYBAMM_IDAKLU_IREE_BASE_FUNCTION_HPP +#define PYBAMM_IDAKLU_IREE_BASE_FUNCTION_HPP + +#include +#include + +/* + * @brief Function definition passed from PyBaMM + */ +class IREEBaseFunctionType +{ +public: // methods + const std::string& get_mlir() const { return mlir; } + +public: // data members + std::string mlir; // cppcheck-suppress unusedStructMember + std::vector kept_var_idx; // cppcheck-suppress unusedStructMember + expr_int nnz; // cppcheck-suppress unusedStructMember + expr_int numel; // cppcheck-suppress unusedStructMember + std::vector col; // cppcheck-suppress unusedStructMember + std::vector row; // cppcheck-suppress unusedStructMember + std::vector pytree_shape; // cppcheck-suppress unusedStructMember + std::vector pytree_sizes; // cppcheck-suppress unusedStructMember + expr_int n_args; // cppcheck-suppress unusedStructMember +}; + +#endif // PYBAMM_IDAKLU_IREE_BASE_FUNCTION_HPP diff --git a/src/pybammsolvers/idaklu_source/Expressions/IREE/IREEFunction.hpp b/src/pybammsolvers/idaklu_source/Expressions/IREE/IREEFunction.hpp new file mode 100644 index 0000000..bcdae5e --- /dev/null +++ b/src/pybammsolvers/idaklu_source/Expressions/IREE/IREEFunction.hpp @@ -0,0 +1,59 @@ +#ifndef PYBAMM_IDAKLU_IREE_FUNCTION_HPP +#define PYBAMM_IDAKLU_IREE_FUNCTION_HPP + +#include "../../Options.hpp" +#include "../Expressions.hpp" +#include +#include "iree_jit.hpp" +#include "IREEBaseFunction.hpp" + +/** + * @brief Class for handling individual iree functions + */ +class IREEFunction : public Expression +{ +public: + typedef IREEBaseFunctionType BaseFunctionType; + + /* + * @brief Constructor + */ + explicit IREEFunction(const BaseFunctionType &f); + + // Method overrides + void operator()() override; + void operator()(const std::vector& inputs, + const std::vector& results) override; + expr_int out_shape(int k) override; + expr_int nnz() override; + expr_int nnz_out() override; + const std::vector& get_col() override; + const std::vector& get_row() override; + + /* + * @brief Evaluate the MLIR function + */ + void evaluate(); + + /* + * @brief Evaluate the MLIR function + * @param n_outputs The number of outputs to return + */ + void evaluate(int n_outputs); + +public: + std::unique_ptr session; + std::vector> result; // cppcheck-suppress unusedStructMember + std::vector> input_shape; // cppcheck-suppress unusedStructMember + std::vector> output_shape; // cppcheck-suppress unusedStructMember + std::vector> input_data; // cppcheck-suppress unusedStructMember + + BaseFunctionType m_func; // cppcheck-suppress unusedStructMember + std::string module_name; // cppcheck-suppress unusedStructMember + std::string function_name; // cppcheck-suppress unusedStructMember + std::vector m_arg_argno; // cppcheck-suppress unusedStructMember + std::vector m_arg_argix; // cppcheck-suppress unusedStructMember + std::vector numel; // cppcheck-suppress unusedStructMember +}; + +#endif // PYBAMM_IDAKLU_IREE_FUNCTION_HPP diff --git a/src/pybammsolvers/idaklu_source/Expressions/IREE/IREEFunctions.cpp b/src/pybammsolvers/idaklu_source/Expressions/IREE/IREEFunctions.cpp new file mode 100644 index 0000000..3bde647 --- /dev/null +++ b/src/pybammsolvers/idaklu_source/Expressions/IREE/IREEFunctions.cpp @@ -0,0 +1,230 @@ +#include +#include +#include +#include +#include + +#include "IREEFunctions.hpp" +#include "iree_jit.hpp" +#include "ModuleParser.hpp" + +IREEFunction::IREEFunction(const BaseFunctionType &f) : Expression(), m_func(f) +{ + DEBUG("IreeFunction constructor"); + const std::string& mlir = f.get_mlir(); + + // Parse IREE (MLIR) function string + if (mlir.size() == 0) { + DEBUG("Empty function --- skipping..."); + return; + } + + // Parse MLIR for module name, input and output shapes + ModuleParser parser(mlir); + module_name = parser.getModuleName(); + function_name = parser.getFunctionName(); + input_shape = parser.getInputShape(); + output_shape = parser.getOutputShape(); + + DEBUG("Compiling module: '" << module_name << "'"); + const char* device_uri = "local-sync"; + session = std::make_unique(device_uri, mlir); + DEBUG("compile complete."); + // Create index vectors into m_arg + // This is required since Jax expands input arguments through PyTrees, which need to + // be remapped to the corresponding expression call. For example: + // fcn(t, y, inputs, cj) with inputs = [[in1], [in2], [in3]] + // will produce a function with six inputs; we therefore need to be able to map + // arguments to their 1) corresponding input argument, and 2) the correct position + // within that argument. + m_arg_argno.clear(); + m_arg_argix.clear(); + int current_element = 0; + for (int i=0; i 2) || + ((input_shape[j].size() == 2) && (input_shape[j][1] > 1)) + ) { + std::cerr << "Unsupported input shape: " << input_shape[j].size() << " ["; + for (int k=0; k {res0} signature (i.e. x and z are reduced out) + // with kept_var_idx = [1] + // + // *********************************************************************************** + + DEBUG("Copying inputs, shape " << input_shape.size() << " - " << m_func.kept_var_idx.size()); + for (int j=0; j 1) { + // Index into argument using appropriate shape + for(int k=0; k(m_arg[m_arg_from][m_arg_argix[mlir_arg]+k]); + } + } else { + // Copy the entire vector + for(int k=0; k(m_arg[m_arg_from][k]); + } + } + } + + // Call the 'main' function of the module + const std::string mlir = m_func.get_mlir(); + DEBUG("Calling function '" << function_name << "'"); + auto status = session->iree_runtime_exec(function_name, input_shape, input_data, result); + if (!iree_status_is_ok(status)) { + iree_status_fprint(stderr, status); + std::cerr << "MLIR: " << mlir.substr(0,1000) << std::endl; + throw std::runtime_error("Execution failed"); + } + + // Copy results to output array + for(size_t k=0; k(result[k][j]); + } + } + + DEBUG("IreeFunction operator() complete"); +} + +expr_int IREEFunction::out_shape(int k) { + DEBUG("IreeFunction nnz(" << k << "): " << m_func.nnz); + auto elements = 1; + for (auto i : output_shape[k]) { + elements *= i; + } + return elements; +} + +expr_int IREEFunction::nnz() { + DEBUG("IreeFunction nnz: " << m_func.nnz); + return nnz_out(); +} + +expr_int IREEFunction::nnz_out() { + DEBUG("IreeFunction nnz_out" << m_func.nnz); + return m_func.nnz; +} + +const std::vector& IREEFunction::get_row() { + DEBUG("IreeFunction get_row" << m_func.row.size()); + return m_func.row; +} + +const std::vector& IREEFunction::get_col() { + DEBUG("IreeFunction get_col" << m_func.col.size()); + return m_func.col; +} + +void IREEFunction::operator()(const std::vector& inputs, + const std::vector& results) +{ + DEBUG("IreeFunction operator() with inputs and results"); + // Set-up input arguments, provide result vector, then execute function + // Example call: fcn({in1, in2, in3}, {out1}) + ASSERT(inputs.size() == m_func.n_args); + for(size_t k=0; k +#include "iree_jit.hpp" +#include "IREEFunction.hpp" + +/** + * @brief Class for handling iree functions + */ +class IREEFunctions : public ExpressionSet +{ +public: + std::unique_ptr iree_compiler; + + typedef IREEFunction::BaseFunctionType BaseFunctionType; // expose typedef in class + + int iree_init_status; + + int iree_init(const std::string& device_uri, const std::string& target_backends) { + // Initialise IREE + DEBUG("IREEFunctions: Initialising IREECompiler"); + iree_compiler = std::make_unique(device_uri.c_str()); + + int iree_argc = 2; + std::string target_backends_str = "--iree-hal-target-backends=" + target_backends; + const char* iree_argv[2] = {"iree", target_backends_str.c_str()}; + iree_compiler->init(iree_argc, iree_argv); + DEBUG("IREEFunctions: Initialised IREECompiler"); + return 0; + } + + int iree_init() { + return iree_init("local-sync", "llvm-cpu"); + } + + + /** + * @brief Create a new IREEFunctions object + */ + IREEFunctions( + const BaseFunctionType &rhs_alg, + const BaseFunctionType &jac_times_cjmass, + const int jac_times_cjmass_nnz, + const int jac_bandwidth_lower, + const int jac_bandwidth_upper, + const np_array_int &jac_times_cjmass_rowvals_arg, + const np_array_int &jac_times_cjmass_colptrs_arg, + const int inputs_length, + const BaseFunctionType &jac_action, + const BaseFunctionType &mass_action, + const BaseFunctionType &sens, + const BaseFunctionType &events, + const int n_s, + const int n_e, + const int n_p, + const std::vector& var_fcns, + const std::vector& dvar_dy_fcns, + const std::vector& dvar_dp_fcns, + const SetupOptions& setup_opts + ) : + iree_init_status(iree_init()), + rhs_alg_iree(rhs_alg), + jac_times_cjmass_iree(jac_times_cjmass), + jac_action_iree(jac_action), + mass_action_iree(mass_action), + sens_iree(sens), + events_iree(events), + ExpressionSet( + static_cast(&rhs_alg_iree), + static_cast(&jac_times_cjmass_iree), + jac_times_cjmass_nnz, + jac_bandwidth_lower, + jac_bandwidth_upper, + jac_times_cjmass_rowvals_arg, + jac_times_cjmass_colptrs_arg, + inputs_length, + static_cast(&jac_action_iree), + static_cast(&mass_action_iree), + static_cast(&sens_iree), + static_cast(&events_iree), + n_s, n_e, n_p, + setup_opts) + { + // convert BaseFunctionType list to IREEFunction list + // NOTE: You must allocate ALL std::vector elements before taking references + for (auto& var : var_fcns) + var_fcns_iree.push_back(IREEFunction(*var)); + for (int k = 0; k < var_fcns_iree.size(); k++) + ExpressionSet::var_fcns.push_back(&this->var_fcns_iree[k]); + + for (auto& var : dvar_dy_fcns) + dvar_dy_fcns_iree.push_back(IREEFunction(*var)); + for (int k = 0; k < dvar_dy_fcns_iree.size(); k++) + this->dvar_dy_fcns.push_back(&this->dvar_dy_fcns_iree[k]); + + for (auto& var : dvar_dp_fcns) + dvar_dp_fcns_iree.push_back(IREEFunction(*var)); + for (int k = 0; k < dvar_dp_fcns_iree.size(); k++) + this->dvar_dp_fcns.push_back(&this->dvar_dp_fcns_iree[k]); + + // copy across numpy array values + const int n_row_vals = jac_times_cjmass_rowvals_arg.request().size; + auto p_jac_times_cjmass_rowvals = jac_times_cjmass_rowvals_arg.unchecked<1>(); + jac_times_cjmass_rowvals.resize(n_row_vals); + for (int i = 0; i < n_row_vals; i++) { + jac_times_cjmass_rowvals[i] = p_jac_times_cjmass_rowvals[i]; + } + + const int n_col_ptrs = jac_times_cjmass_colptrs_arg.request().size; + auto p_jac_times_cjmass_colptrs = jac_times_cjmass_colptrs_arg.unchecked<1>(); + jac_times_cjmass_colptrs.resize(n_col_ptrs); + for (int i = 0; i < n_col_ptrs; i++) { + jac_times_cjmass_colptrs[i] = p_jac_times_cjmass_colptrs[i]; + } + + inputs.resize(inputs_length); + } + + IREEFunction rhs_alg_iree; + IREEFunction jac_times_cjmass_iree; + IREEFunction jac_action_iree; + IREEFunction mass_action_iree; + IREEFunction sens_iree; + IREEFunction events_iree; + + std::vector var_fcns_iree; + std::vector dvar_dy_fcns_iree; + std::vector dvar_dp_fcns_iree; + + realtype* get_tmp_state_vector() override { + return tmp_state_vector.data(); + } + realtype* get_tmp_sparse_jacobian_data() override { + return tmp_sparse_jacobian_data.data(); + } + + ~IREEFunctions() { + // cleanup IREE + iree_compiler->cleanup(); + } +}; + +#endif // PYBAMM_IDAKLU_IREE_FUNCTIONS_HPP diff --git a/src/pybammsolvers/idaklu_source/Expressions/IREE/ModuleParser.cpp b/src/pybammsolvers/idaklu_source/Expressions/IREE/ModuleParser.cpp new file mode 100644 index 0000000..d1c5575 --- /dev/null +++ b/src/pybammsolvers/idaklu_source/Expressions/IREE/ModuleParser.cpp @@ -0,0 +1,91 @@ +#include "ModuleParser.hpp" + +ModuleParser::ModuleParser(const std::string& mlir) : mlir(mlir) +{ + parse(); +} + +void ModuleParser::parse() +{ + // Parse module name + std::regex module_name_regex("module @([^\\s]+)"); // Match until first whitespace + std::smatch module_name_match; + std::regex_search(this->mlir, module_name_match, module_name_regex); + if (module_name_match.size() == 0) { + std::cerr << "Could not find module name in module" << std::endl; + std::cerr << "Module snippet: " << this->mlir.substr(0, 1000) << std::endl; + throw std::runtime_error("Could not find module name in module"); + } + module_name = module_name_match[1].str(); + DEBUG("Module name: " << module_name); + + // Assign function name + function_name = module_name + ".main"; + + // Isolate 'main' function call signature + std::regex main_func("public @main\\((.*?)\\) -> \\((.*?)\\)"); + std::smatch match; + std::regex_search(this->mlir, match, main_func); + if (match.size() == 0) { + std::cerr << "Could not find 'main' function in module" << std::endl; + std::cerr << "Module snippet: " << this->mlir.substr(0, 1000) << std::endl; + throw std::runtime_error("Could not find 'main' function in module"); + } + std::string main_sig_inputs = match[1].str(); + std::string main_sig_outputs = match[2].str(); + DEBUG( + "Main function signature: " << main_sig_inputs << " -> " << main_sig_outputs << '\n' + ); + + // Parse input sizes + input_shape.clear(); + std::regex input_size("tensor<(.*?)>"); + for(std::sregex_iterator i = std::sregex_iterator(main_sig_inputs.begin(), main_sig_inputs.end(), input_size); + i != std::sregex_iterator(); + ++i) + { + std::smatch matchi = *i; + std::string match_str = matchi.str(); + std::string shape_str = match_str.substr(7, match_str.size() - 8); // Remove 'tensor<>' from string + std::vector shape; + std::string dim_str; + for (char c : shape_str) { + if (c == 'x') { + shape.push_back(std::stoi(dim_str)); + dim_str = ""; + } else { + dim_str += c; + } + } + input_shape.push_back(shape); + } + + // Parse output sizes + output_shape.clear(); + std::regex output_size("tensor<(.*?)>"); + for( + std::sregex_iterator i = std::sregex_iterator(main_sig_outputs.begin(), main_sig_outputs.end(), output_size); + i != std::sregex_iterator(); + ++i + ) { + std::smatch matchi = *i; + std::string match_str = matchi.str(); + std::string shape_str = match_str.substr(7, match_str.size() - 8); // Remove 'tensor<>' from string + std::vector shape; + std::string dim_str; + for (char c : shape_str) { + if (c == 'x') { + shape.push_back(std::stoi(dim_str)); + dim_str = ""; + } else { + dim_str += c; + } + } + // If shape is empty, assume scalar (i.e. "tensor" or some singleton variant) + if (shape.size() == 0) { + shape.push_back(1); + } + // Add output to list + output_shape.push_back(shape); + } +} diff --git a/src/pybammsolvers/idaklu_source/Expressions/IREE/ModuleParser.hpp b/src/pybammsolvers/idaklu_source/Expressions/IREE/ModuleParser.hpp new file mode 100644 index 0000000..2fbfdc0 --- /dev/null +++ b/src/pybammsolvers/idaklu_source/Expressions/IREE/ModuleParser.hpp @@ -0,0 +1,55 @@ +#ifndef PYBAMM_IDAKLU_IREE_MODULE_PARSER_HPP +#define PYBAMM_IDAKLU_IREE_MODULE_PARSER_HPP + +#include +#include +#include +#include +#include + +#include "../../common.hpp" + +class ModuleParser { +private: + std::string mlir; // cppcheck-suppress unusedStructMember + // codacy fix: member is referenced as this->mlir in parse() + std::string module_name; + std::string function_name; + std::vector> input_shape; + std::vector> output_shape; +public: + /** + * @brief Constructor + * @param mlir: string representation of MLIR code for the module + */ + explicit ModuleParser(const std::string& mlir); + + /** + * @brief Get the module name + * @return module name + */ + const std::string& getModuleName() const { return module_name; } + + /** + * @brief Get the function name + * @return function name + */ + const std::string& getFunctionName() const { return function_name; } + + /** + * @brief Get the input shape + * @return input shape + */ + const std::vector>& getInputShape() const { return input_shape; } + + /** + * @brief Get the output shape + * @return output shape + */ + const std::vector>& getOutputShape() const { return output_shape; } + +private: + void parse(); +}; + +#endif // PYBAMM_IDAKLU_IREE_MODULE_PARSER_HPP diff --git a/src/pybammsolvers/idaklu_source/Expressions/IREE/iree_jit.cpp b/src/pybammsolvers/idaklu_source/Expressions/IREE/iree_jit.cpp new file mode 100644 index 0000000..c84c392 --- /dev/null +++ b/src/pybammsolvers/idaklu_source/Expressions/IREE/iree_jit.cpp @@ -0,0 +1,408 @@ +#include "iree_jit.hpp" +#include "iree/hal/buffer_view.h" +#include "iree/hal/buffer_view_util.h" +#include "../../common.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +// Used to suppress stderr output (see initIREE below) +#ifdef _WIN32 +#include +#define close _close +#define dup _dup +#define fileno _fileno +#define open _open +#define dup2 _dup2 +#define NULL_DEVICE "NUL" +#else +#define NULL_DEVICE "/dev/null" +#endif + +void IREESession::handle_compiler_error(iree_compiler_error_t *error) { + const char *msg = ireeCompilerErrorGetMessage(error); + fprintf(stderr, "Error from compiler API:\n%s\n", msg); + ireeCompilerErrorDestroy(error); +} + +void IREESession::cleanup_compiler_state(compiler_state_t s) { + if (s.inv) + ireeCompilerInvocationDestroy(s.inv); + if (s.output) + ireeCompilerOutputDestroy(s.output); + if (s.source) + ireeCompilerSourceDestroy(s.source); + if (s.session) + ireeCompilerSessionDestroy(s.session); +} + +IREECompiler::IREECompiler() { + this->device_uri = "local-sync"; +}; + +IREECompiler::~IREECompiler() { + ireeCompilerGlobalShutdown(); +}; + +int IREECompiler::init(int argc, const char **argv) { + return initIREE(argc, argv); // Initialisation and version checking +}; + +int IREECompiler::cleanup() { + return 0; +}; + +IREESession::IREESession() { + s.session = NULL; + s.source = NULL; + s.output = NULL; + s.inv = NULL; +}; + +IREESession::IREESession(const char *device_uri, const std::string& mlir_code) : IREESession() { + this->device_uri=device_uri; + this->mlir_code=mlir_code; + init(); +} + +int IREESession::init() { + if (initCompiler() != 0) // Prepare compiler inputs and outputs + return 1; + if (initCompileToByteCode() != 0) // Compile to bytecode + return 1; + if (initRuntime() != 0) // Initialise runtime environment + return 1; + return 0; +}; + +int IREECompiler::initIREE(int argc, const char **argv) { + + if (device_uri == NULL) { + DEBUG("No device URI provided, using local-sync\n"); + this->device_uri = "local-sync"; + } + + int cl_argc = argc; + const char *iree_compiler_lib = std::getenv("IREE_COMPILER_LIB"); + + // Load the compiler library and initialize it + // NOTE: On second and subsequent calls, the function will return false and display + // a message on stderr, but it is safe to ignore this message. For an improved user + // experience we actively suppress stderr during the call to this function but since + // this also suppresses any other error message, we actively check for the presence + // of the library file prior to the call. + + // Check if the library file exists + if (iree_compiler_lib == NULL) { + fprintf(stderr, "Error: IREE_COMPILER_LIB environment variable not set\n"); + return 1; + } + if (access(iree_compiler_lib, F_OK) == -1) { + fprintf(stderr, "Error: IREE_COMPILER_LIB file not found\n"); + return 1; + } + // Suppress stderr + int saved_stderr = dup(fileno(stderr)); + if (!freopen(NULL_DEVICE, "w", stderr)) + DEBUG("Error: failed redirecting stderr"); + // Load library + bool result = ireeCompilerLoadLibrary(iree_compiler_lib); + // Restore stderr + fflush(stderr); + dup2(saved_stderr, fileno(stderr)); + close(saved_stderr); + // Process result + if (!result) { + // Library may have already been loaded (can be safely ignored), + // or may not be found (critical error), we cannot tell which from the return value. + return 1; + } + // Must be balanced with a call to ireeCompilerGlobalShutdown() + ireeCompilerGlobalInitialize(); + + // To set global options (see `iree-compile --help` for possibilities), use + // |ireeCompilerGetProcessCLArgs| and |ireeCompilerSetupGlobalCL| + ireeCompilerGetProcessCLArgs(&cl_argc, &argv); + ireeCompilerSetupGlobalCL(cl_argc, argv, "iree-jit", false); + + // Check the API version before proceeding any further + uint32_t api_version = (uint32_t)ireeCompilerGetAPIVersion(); + uint16_t api_version_major = (uint16_t)((api_version >> 16) & 0xFFFFUL); + uint16_t api_version_minor = (uint16_t)(api_version & 0xFFFFUL); + DEBUG("Compiler API version: " << api_version_major << "." << api_version_minor); + if (api_version_major > IREE_COMPILER_EXPECTED_API_MAJOR || + api_version_minor < IREE_COMPILER_EXPECTED_API_MINOR) { + fprintf(stderr, + "Error: incompatible API version; built for version %" PRIu16 + ".%" PRIu16 " but loaded version %" PRIu16 ".%" PRIu16 "\n", + IREE_COMPILER_EXPECTED_API_MAJOR, IREE_COMPILER_EXPECTED_API_MINOR, + api_version_major, api_version_minor); + ireeCompilerGlobalShutdown(); + return 1; + } + + // Check for a build tag with release version information + const char *revision = ireeCompilerGetRevision(); // cppcheck-suppress unreadVariable + DEBUG("Compiler revision: '" << revision << "'"); + return 0; +}; + +int IREESession::initCompiler() { + + // A session provides a scope where one or more invocations can be executed + s.session = ireeCompilerSessionCreate(); + + // Read the MLIR from memory + error = ireeCompilerSourceWrapBuffer( + s.session, + "expr_buffer", // name of the buffer (does not need to match MLIR) + mlir_code.c_str(), + mlir_code.length() + 1, + true, + &s.source + ); + if (error) { + fprintf(stderr, "Error wrapping source buffer\n"); + handle_compiler_error(error); + cleanup_compiler_state(s); + return 1; + } + DEBUG("Wrapped buffer as a compiler source"); + + return 0; +}; + +int IREESession::initCompileToByteCode() { + // Use an invocation to compile from the input source to the output stream + iree_compiler_invocation_t *inv = ireeCompilerInvocationCreate(s.session); + ireeCompilerInvocationEnableConsoleDiagnostics(inv); + + if (!ireeCompilerInvocationParseSource(inv, s.source)) { + fprintf(stderr, "Error parsing input source into invocation\n"); + cleanup_compiler_state(s); + return 1; + } + + // Compile, specifying the target dialect phase + ireeCompilerInvocationSetCompileToPhase(inv, "end"); + + // Run the compiler invocation pipeline + if (!ireeCompilerInvocationPipeline(inv, IREE_COMPILER_PIPELINE_STD)) { + fprintf(stderr, "Error running compiler invocation\n"); + cleanup_compiler_state(s); + return 1; + } + DEBUG("Compilation successful"); + + // Create compiler 'output' to a memory buffer + error = ireeCompilerOutputOpenMembuffer(&s.output); + if (error) { + fprintf(stderr, "Error opening output membuffer\n"); + handle_compiler_error(error); + cleanup_compiler_state(s); + return 1; + } + + // Create bytecode in memory + error = ireeCompilerInvocationOutputVMBytecode(inv, s.output); + if (error) { + fprintf(stderr, "Error creating VM bytecode\n"); + handle_compiler_error(error); + cleanup_compiler_state(s); + return 1; + } + + // Once the bytecode has been written, retrieve the memory map + ireeCompilerOutputMapMemory(s.output, &contents, &size); + + return 0; +}; + +int IREESession::initRuntime() { + // Setup the shared runtime instance + iree_runtime_instance_options_t instance_options; + iree_runtime_instance_options_initialize(&instance_options); + iree_runtime_instance_options_use_all_available_drivers(&instance_options); + status = iree_runtime_instance_create( + &instance_options, iree_allocator_system(), &instance); + + // Create the HAL device used to run the workloads + if (iree_status_is_ok(status)) { + status = iree_hal_create_device( + iree_runtime_instance_driver_registry(instance), + iree_make_cstring_view(device_uri), + iree_runtime_instance_host_allocator(instance), &device); + } + + // Set up the session to run the module + if (iree_status_is_ok(status)) { + iree_runtime_session_options_t session_options; + iree_runtime_session_options_initialize(&session_options); + status = iree_runtime_session_create_with_device( + instance, &session_options, device, + iree_runtime_instance_host_allocator(instance), &session); + } + + // Load the compiled user module from a file + if (iree_status_is_ok(status)) { + /*status = iree_runtime_session_append_bytecode_module_from_file(session, module_path);*/ + status = iree_runtime_session_append_bytecode_module_from_memory( + session, + iree_make_const_byte_span(contents, size), + iree_allocator_null()); + } + + if (!iree_status_is_ok(status)) + return 1; + + return 0; +}; + +// Release the session and free all cached resources. +int IREESession::cleanup() { + iree_runtime_session_release(session); + iree_hal_device_release(device); + iree_runtime_instance_release(instance); + + int ret = (int)iree_status_code(status); + if (!iree_status_is_ok(status)) { + iree_status_fprint(stderr, status); + iree_status_ignore(status); + } + cleanup_compiler_state(s); + return ret; +} + +iree_status_t IREESession::iree_runtime_exec( + const std::string& function_name, + const std::vector>& inputs, + const std::vector>& data, + std::vector>& result +) { + + // Initialize the call to the function. + status = iree_runtime_call_initialize_by_name( + session, iree_make_cstring_view(function_name.c_str()), &call); + if (!iree_status_is_ok(status)) { + std::cerr << "Error: iree_runtime_call_initialize_by_name failed" << std::endl; + iree_status_fprint(stderr, status); + return status; + } + + // Append the function inputs with the HAL device allocator in use by the + // session. The buffers will be usable within the session and _may_ be usable + // in other sessions depending on whether they share a compatible device. + iree_hal_allocator_t* device_allocator = + iree_runtime_session_device_allocator(session); + host_allocator = iree_runtime_session_host_allocator(session); + status = iree_ok_status(); + if (iree_status_is_ok(status)) { + + for(int k=0; k arg_shape(input_shape.size()); + for (int i = 0; i < input_shape.size(); i++) { + arg_shape[i] = input_shape[i]; + } + int numel = 1; + for(int i = 0; i < input_shape.size(); i++) { + numel *= input_shape[i]; + } + std::vector arg_data(numel); + for(int i = 0; i < numel; i++) { + arg_data[i] = input_data[i]; + } + + status = iree_hal_buffer_view_allocate_buffer_copy( + device, device_allocator, + // Shape rank and dimensions: + arg_shape.size(), arg_shape.data(), + // Element type: + IREE_HAL_ELEMENT_TYPE_FLOAT_32, + // Encoding type: + IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, + (iree_hal_buffer_params_t){ + // Intended usage of the buffer (transfers, dispatches, etc): + .usage = IREE_HAL_BUFFER_USAGE_DEFAULT, + // Access to allow to this memory: + .access = IREE_HAL_MEMORY_ACCESS_ALL, + // Where to allocate (host or device): + .type = IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL, + }, + // The actual heap buffer to wrap or clone and its allocator: + iree_make_const_byte_span(&arg_data[0], sizeof(float) * arg_data.size()), + // Buffer view + storage are returned and owned by the caller: + &arg); + } + if (iree_status_is_ok(status)) { + // Add to the call inputs list (which retains the buffer view). + status = iree_runtime_call_inputs_push_back_buffer_view(&call, arg); + if (!iree_status_is_ok(status)) { + std::cerr << "Error: iree_runtime_call_inputs_push_back_buffer_view failed" << std::endl; + iree_status_fprint(stderr, status); + } + } + // Since the call retains the buffer view we can release it here. + iree_hal_buffer_view_release(arg); + } + } + + // Synchronously perform the call. + if (iree_status_is_ok(status)) { + status = iree_runtime_call_invoke(&call, /*flags=*/0); + } + if (!iree_status_is_ok(status)) { + std::cerr << "Error: iree_runtime_call_invoke failed" << std::endl; + iree_status_fprint(stderr, status); + } + + for(int k=0; k +#include +#include +#include + +#include +#include +#include + +#define IREE_COMPILER_EXPECTED_API_MAJOR 1 // At most this major version +#define IREE_COMPILER_EXPECTED_API_MINOR 2 // At least this minor version + +// Forward declaration +class IREESession; + +/* + * @brief IREECompiler class + * @details This class is used to compile MLIR code to IREE bytecode and + * create IREE sessions. + */ +class IREECompiler { +private: + /* + * @brief Device Uniform Resource Identifier (URI) + * @details The device URI is used to specify the device to be used by the + * IREE runtime. E.g. "local-sync" for CPU, "vulkan" for GPU, etc. + */ + const char *device_uri = NULL; + +private: + /* + * @brief Initialize the IREE runtime + */ + int initIREE(int argc, const char **argv); + +public: + /* + * @brief Default constructor + */ + IREECompiler(); + + /* + * @brief Destructor + */ + ~IREECompiler(); + + /* + * @brief Constructor with device URI + * @param device_uri Device URI + */ + explicit IREECompiler(const char *device_uri) + : IREECompiler() { this->device_uri=device_uri; } + + /* + * @brief Initialize the compiler + */ + int init(int argc, const char **argv); + + /* + * @brief Cleanup the compiler + * @details This method cleans up the compiler and all the IREE sessions + * created by the compiler. Returns 0 on success. + */ + int cleanup(); +}; + +/* + * @brief Compiler state + */ +typedef struct compiler_state_t { + iree_compiler_session_t *session; // cppcheck-suppress unusedStructMember + iree_compiler_source_t *source; // cppcheck-suppress unusedStructMember + iree_compiler_output_t *output; // cppcheck-suppress unusedStructMember + iree_compiler_invocation_t *inv; // cppcheck-suppress unusedStructMember +} compiler_state_t; + +/* + * @brief IREE session class + */ +class IREESession { +private: // data members + const char *device_uri = NULL; + compiler_state_t s; + iree_compiler_error_t *error = NULL; + void *contents = NULL; + uint64_t size = 0; + iree_runtime_session_t* session = NULL; + iree_status_t status; + iree_hal_device_t* device = NULL; + iree_runtime_instance_t* instance = NULL; + std::string mlir_code; // cppcheck-suppress unusedStructMember + iree_runtime_call_t call; + iree_allocator_t host_allocator; + +private: // private methods + void handle_compiler_error(iree_compiler_error_t *error); + void cleanup_compiler_state(compiler_state_t s); + int init(); + int initCompiler(); + int initCompileToByteCode(); + int initRuntime(); + +public: // public methods + + /* + * @brief Default constructor + */ + IREESession(); + + /* + * @brief Constructor with device URI and MLIR code + * @param device_uri Device URI + * @param mlir_code MLIR code + */ + explicit IREESession(const char *device_uri, const std::string& mlir_code); + + /* + * @brief Cleanup the IREE session + */ + int cleanup(); + + /* + * @brief Execute the pre-compiled byte-code with the given inputs + * @param function_name Function name to execute + * @param inputs List of input shapes + * @param data List of input data + * @param result List of output data + */ + iree_status_t iree_runtime_exec( + const std::string& function_name, + const std::vector>& inputs, + const std::vector>& data, + std::vector>& result + ); +}; + +#endif // IREE_JIT_HPP diff --git a/src/pybammsolvers/idaklu_source/IDAKLUSolver.cpp b/src/pybammsolvers/idaklu_source/IDAKLUSolver.cpp new file mode 100644 index 0000000..b769d4d --- /dev/null +++ b/src/pybammsolvers/idaklu_source/IDAKLUSolver.cpp @@ -0,0 +1 @@ +#include "IDAKLUSolver.hpp" diff --git a/src/pybammsolvers/idaklu_source/IDAKLUSolver.hpp b/src/pybammsolvers/idaklu_source/IDAKLUSolver.hpp new file mode 100644 index 0000000..379d647 --- /dev/null +++ b/src/pybammsolvers/idaklu_source/IDAKLUSolver.hpp @@ -0,0 +1,48 @@ +#ifndef PYBAMM_IDAKLU_CASADI_SOLVER_HPP +#define PYBAMM_IDAKLU_CASADI_SOLVER_HPP + +#include "common.hpp" +#include "SolutionData.hpp" + + +/** + * Abstract base class for solutions that can use different solvers and vector + * implementations. + * @brief An abstract base class for the Idaklu solver + */ +class IDAKLUSolver +{ +public: + + /** + * @brief Default constructor + */ + IDAKLUSolver() = default; + + /** + * @brief Default destructor + */ + ~IDAKLUSolver() = default; + + /** + * @brief Abstract solver method that executes the solver + */ + virtual SolutionData solve( + const std::vector &t_eval, + const std::vector &t_interp, + const realtype *y0, + const realtype *yp0, + const realtype *inputs, + bool save_adaptive_steps, + bool save_interp_steps + ) = 0; + + /** + * Abstract method to initialize the solver, once vectors and solver classes + * are set + * @brief Abstract initialization method + */ + virtual void Initialize() = 0; +}; + +#endif // PYBAMM_IDAKLU_CASADI_SOLVER_HPP diff --git a/src/pybammsolvers/idaklu_source/IDAKLUSolverGroup.cpp b/src/pybammsolvers/idaklu_source/IDAKLUSolverGroup.cpp new file mode 100644 index 0000000..8a76d73 --- /dev/null +++ b/src/pybammsolvers/idaklu_source/IDAKLUSolverGroup.cpp @@ -0,0 +1,145 @@ +#include "IDAKLUSolverGroup.hpp" +#include +#include + +std::vector IDAKLUSolverGroup::solve( + np_array t_eval_np, + np_array t_interp_np, + np_array y0_np, + np_array yp0_np, + np_array inputs) { + DEBUG("IDAKLUSolverGroup::solve"); + + // If t_interp is empty, save all adaptive steps + bool save_adaptive_steps = t_interp_np.size() == 0; + + const realtype* t_eval_begin = t_eval_np.data(); + const realtype* t_eval_end = t_eval_begin + t_eval_np.size(); + const realtype* t_interp_begin = t_interp_np.data(); + const realtype* t_interp_end = t_interp_begin + t_interp_np.size(); + + // Process the time inputs + // 1. Get the sorted and unique t_eval vector + auto const t_eval = makeSortedUnique(t_eval_begin, t_eval_end); + + // 2.1. Get the sorted and unique t_interp vector + auto const t_interp_unique_sorted = makeSortedUnique(t_interp_begin, t_interp_end); + + // 2.2 Remove the t_eval values from t_interp + auto const t_interp_setdiff = setDiff(t_interp_unique_sorted.begin(), t_interp_unique_sorted.end(), t_eval_begin, t_eval_end); + + // 2.3 Finally, get the sorted and unique t_interp vector with t_eval values removed + auto const t_interp = makeSortedUnique(t_interp_setdiff.begin(), t_interp_setdiff.end()); + + int const number_of_evals = t_eval.size(); + int const number_of_interps = t_interp.size(); + + // setDiff removes entries of t_interp that overlap with + // t_eval, so we need to check if we need to interpolate any unique points. + // This is not the same as save_adaptive_steps since some entries of t_interp + // may be removed by setDiff + bool save_interp_steps = number_of_interps > 0; + + // 3. Check if the timestepping entries are valid + if (number_of_evals < 2) { + throw std::invalid_argument( + "t_eval must have at least 2 entries" + ); + } else if (save_interp_steps) { + if (t_interp.front() < t_eval.front()) { + throw std::invalid_argument( + "t_interp values must be greater than the smallest t_eval value: " + + std::to_string(t_eval.front()) + ); + } else if (t_interp.back() > t_eval.back()) { + throw std::invalid_argument( + "t_interp values must be less than the greatest t_eval value: " + + std::to_string(t_eval.back()) + ); + } + } + + auto n_coeffs = number_of_states + number_of_parameters * number_of_states; + + // check y0 and yp0 and inputs have the correct dimensions + if (y0_np.ndim() != 2) + throw std::domain_error("y0 has wrong number of dimensions. Expected 2 but got " + std::to_string(y0_np.ndim())); + if (yp0_np.ndim() != 2) + throw std::domain_error("yp0 has wrong number of dimensions. Expected 2 but got " + std::to_string(yp0_np.ndim())); + if (inputs.ndim() != 2) + throw std::domain_error("inputs has wrong number of dimensions. Expected 2 but got " + std::to_string(inputs.ndim())); + + auto number_of_groups = y0_np.shape()[0]; + + // check y0 and yp0 and inputs have the correct shape + if (y0_np.shape()[1] != n_coeffs) + throw std::domain_error( + "y0 has wrong number of cols. Expected " + std::to_string(n_coeffs) + + " but got " + std::to_string(y0_np.shape()[1])); + + if (yp0_np.shape()[1] != n_coeffs) + throw std::domain_error( + "yp0 has wrong number of cols. Expected " + std::to_string(n_coeffs) + + " but got " + std::to_string(yp0_np.shape()[1])); + + if (yp0_np.shape()[0] != number_of_groups) + throw std::domain_error( + "yp0 has wrong number of rows. Expected " + std::to_string(number_of_groups) + + " but got " + std::to_string(yp0_np.shape()[0])); + + if (inputs.shape()[0] != number_of_groups) + throw std::domain_error( + "inputs has wrong number of rows. Expected " + std::to_string(number_of_groups) + + " but got " + std::to_string(inputs.shape()[0])); + + const std::size_t solves_per_thread = number_of_groups / m_solvers.size(); + const std::size_t remainder_solves = number_of_groups % m_solvers.size(); + + const realtype *y0 = y0_np.data(); + const realtype *yp0 = yp0_np.data(); + const realtype *inputs_data = inputs.data(); + + std::vector results(number_of_groups); + + std::optional exception; + + omp_set_num_threads(m_solvers.size()); + #pragma omp parallel for + for (int i = 0; i < m_solvers.size(); i++) { + try { + for (int j = 0; j < solves_per_thread; j++) { + const std::size_t index = i * solves_per_thread + j; + const realtype *y = y0 + index * y0_np.shape(1); + const realtype *yp = yp0 + index * yp0_np.shape(1); + const realtype *input = inputs_data + index * inputs.shape(1); + results[index] = m_solvers[i]->solve(t_eval, t_interp, y, yp, input, save_adaptive_steps, save_interp_steps); + } + } catch (std::exception &e) { + // If an exception is thrown, we need to catch it and rethrow it outside the parallel region + #pragma omp critical + { + exception = e; + } + } + } + + if (exception.has_value()) { + py::set_error(PyExc_ValueError, exception->what()); + throw py::error_already_set(); + } + + for (int i = 0; i < remainder_solves; i++) { + const std::size_t index = number_of_groups - remainder_solves + i; + const realtype *y = y0 + index * y0_np.shape(1); + const realtype *yp = yp0 + index * yp0_np.shape(1); + const realtype *input = inputs_data + index * inputs.shape(1); + results[index] = m_solvers[i]->solve(t_eval, t_interp, y, yp, input, save_adaptive_steps, save_interp_steps); + } + + // create solutions (needs to be serial as we're using the Python GIL) + std::vector solutions(number_of_groups); + for (int i = 0; i < number_of_groups; i++) { + solutions[i] = results[i].generate_solution(); + } + return solutions; +} diff --git a/src/pybammsolvers/idaklu_source/IDAKLUSolverGroup.hpp b/src/pybammsolvers/idaklu_source/IDAKLUSolverGroup.hpp new file mode 100644 index 0000000..609b3b6 --- /dev/null +++ b/src/pybammsolvers/idaklu_source/IDAKLUSolverGroup.hpp @@ -0,0 +1,48 @@ +#ifndef PYBAMM_IDAKLU_SOLVER_GROUP_HPP +#define PYBAMM_IDAKLU_SOLVER_GROUP_HPP + +#include "IDAKLUSolver.hpp" +#include "common.hpp" + +/** + * @brief class for a group of solvers. + */ +class IDAKLUSolverGroup +{ +public: + + /** + * @brief Default constructor + */ + IDAKLUSolverGroup(std::vector> solvers, int number_of_states, int number_of_parameters): + m_solvers(std::move(solvers)), + number_of_states(number_of_states), + number_of_parameters(number_of_parameters) + {} + + // no copy constructor (unique_ptr cannot be copied) + IDAKLUSolverGroup(IDAKLUSolverGroup &) = delete; + + /** + * @brief Default destructor + */ + ~IDAKLUSolverGroup() = default; + + /** + * @brief solver method that returns a vector of Solutions + */ + std::vector solve( + np_array t_eval_np, + np_array t_interp_np, + np_array y0_np, + np_array yp0_np, + np_array inputs); + + + private: + std::vector> m_solvers; + int number_of_states; + int number_of_parameters; +}; + +#endif // PYBAMM_IDAKLU_SOLVER_GROUP_HPP diff --git a/src/pybammsolvers/idaklu_source/IDAKLUSolverOpenMP.hpp b/src/pybammsolvers/idaklu_source/IDAKLUSolverOpenMP.hpp new file mode 100644 index 0000000..ee2c03a --- /dev/null +++ b/src/pybammsolvers/idaklu_source/IDAKLUSolverOpenMP.hpp @@ -0,0 +1,299 @@ +#ifndef PYBAMM_IDAKLU_SOLVEROPENMP_HPP +#define PYBAMM_IDAKLU_SOLVEROPENMP_HPP + +#include "IDAKLUSolver.hpp" +#include "common.hpp" +#include +using std::vector; + +#include "Options.hpp" +#include "Solution.hpp" +#include "sundials_legacy_wrapper.hpp" + +/** + * @brief Abstract solver class based on OpenMP vectors + * + * An abstract class that implements a solution based on OpenMP + * vectors but needs to be provided with a suitable linear solver. + * + * This class broadly implements the following skeleton workflow: + * (https://sundials.readthedocs.io/en/latest/ida/Usage/index.html) + * 1. (N/A) Initialize parallel or multi-threaded environment + * 2. Create the SUNDIALS context object + * 3. Create the vector of initial values + * 4. Create matrix object (if appropriate) + * 5. Create linear solver object + * 6. (N/A) Create nonlinear solver object + * 7. Create IDA object + * 8. Initialize IDA solver + * 9. Specify integration tolerances + * 10. Attach the linear solver + * 11. Set linear solver optional inputs + * 12. (N/A) Attach nonlinear solver module + * 13. (N/A) Set nonlinear solver optional inputs + * 14. Specify rootfinding problem (optional) + * 15. Set optional inputs + * 16. Correct initial values (optional) + * 17. Advance solution in time + * 18. Get optional outputs + * 19. Destroy objects + * 20. (N/A) Finalize MPI + */ +template +class IDAKLUSolverOpenMP : public IDAKLUSolver +{ + // NB: cppcheck-suppress unusedStructMember is used because codacy reports + // these members as unused even though they are important in child + // classes, but are passed by variadic arguments (and are therefore unnamed) +public: + void *ida_mem = nullptr; + np_array atol_np; + np_array rhs_alg_id; + int const number_of_states; // cppcheck-suppress unusedStructMember + int const number_of_parameters; // cppcheck-suppress unusedStructMember + int const number_of_events; // cppcheck-suppress unusedStructMember + int number_of_timesteps; + int precon_type; // cppcheck-suppress unusedStructMember + N_Vector yy, yyp, y_cache, avtol; // y, y', y cache vector, and absolute tolerance + N_Vector *yyS; // cppcheck-suppress unusedStructMember + N_Vector *yypS; // cppcheck-suppress unusedStructMember + N_Vector id; // rhs_alg_id + realtype rtol; + int const jac_times_cjmass_nnz; // cppcheck-suppress unusedStructMember + int const jac_bandwidth_lower; // cppcheck-suppress unusedStructMember + int const jac_bandwidth_upper; // cppcheck-suppress unusedStructMember + SUNMatrix J; + SUNLinearSolver LS = nullptr; + std::unique_ptr functions; + vector res; + vector res_dvar_dy; + vector res_dvar_dp; + bool const sensitivity; // cppcheck-suppress unusedStructMember + bool const save_outputs_only; // cppcheck-suppress unusedStructMember + bool save_hermite; // cppcheck-suppress unusedStructMember + bool is_ODE; // cppcheck-suppress unusedStructMember + int length_of_return_vector; // cppcheck-suppress unusedStructMember + vector t; // cppcheck-suppress unusedStructMember + vector> y; // cppcheck-suppress unusedStructMember + vector> yp; // cppcheck-suppress unusedStructMember + vector>> yS; // cppcheck-suppress unusedStructMember + vector>> ypS; // cppcheck-suppress unusedStructMember + SetupOptions const setup_opts; + SolverOptions const solver_opts; + +#if SUNDIALS_VERSION_MAJOR >= 6 + SUNContext sunctx; +#endif + +public: + /** + * @brief Constructor + */ + IDAKLUSolverOpenMP( + np_array atol_np, + double rel_tol, + np_array rhs_alg_id, + int number_of_parameters, + int number_of_events, + int jac_times_cjmass_nnz, + int jac_bandwidth_lower, + int jac_bandwidth_upper, + std::unique_ptr functions, + const SetupOptions &setup_opts, + const SolverOptions &solver_opts + ); + + /** + * @brief Destructor + */ + ~IDAKLUSolverOpenMP(); + + /** + * @brief The main solve method that solves for each variable and time step + */ + SolutionData solve( + const std::vector &t_eval, + const std::vector &t_interp, + const realtype *y0, + const realtype *yp0, + const realtype *inputs, + bool save_adaptive_steps, + bool save_interp_steps + ) override; + + + /** + * @brief Concrete implementation of initialization method + */ + void Initialize() override; + + /** + * @brief Allocate memory for OpenMP vectors + */ + void AllocateVectors(); + + /** + * @brief Allocate memory for matrices (noting appropriate matrix format/types) + */ + void SetMatrix(); + + /** + * @brief Get the length of the return vector + */ + int ReturnVectorLength(); + + /** + * @brief Initialize the storage for the solution + */ + void InitializeStorage(int const N); + + /** + * @brief Initialize the storage for Hermite interpolation + */ + void InitializeHermiteStorage(int const N); + + /** + * @brief Apply user-configurable IDA options + */ + void SetSolverOptions(); + + /** + * @brief Check the return flag for errors + */ + void CheckErrors(int const & flag); + + /** + * @brief Print the solver statistics + */ + void PrintStats(); + + /** + * @brief Set a consistent initialization for ODEs + */ + void ReinitializeIntegrator(const realtype& t_val); + + /** + * @brief Set a consistent initialization for the system of equations + */ + void ConsistentInitialization( + const realtype& t_val, + const realtype& t_next, + const int& icopt); + + /** + * @brief Set a consistent initialization for DAEs + */ + void ConsistentInitializationDAE( + const realtype& t_val, + const realtype& t_next, + const int& icopt); + + /** + * @brief Set a consistent initialization for ODEs + */ + void ConsistentInitializationODE(const realtype& t_val); + + /** + * @brief Extend the adaptive arrays by 1 + */ + void ExtendAdaptiveArrays(); + + /** + * @brief Extend the Hermite interpolation info by 1 + */ + void ExtendHermiteArrays(); + + /** + * @brief Set the step values + */ + void SetStep( + realtype &tval, + realtype *y_val, + realtype *yp_val, + vector const &yS_val, + vector const &ypS_val, + int &i_save + ); + + /** + * @brief Save the interpolated step values + */ + void SetStepInterp( + int &i_interp, + realtype &t_interp_next, + vector const &t_interp, + realtype &t_val, + realtype &t_prev, + realtype const &t_next, + realtype *y_val, + realtype *yp_val, + vector const &yS_val, + vector const &ypS_val, + int &i_save + ); + + /** + * @brief Save y and yS at the current time + */ + void SetStepFull( + realtype &t_val, + realtype *y_val, + vector const &yS_val, + int &i_save + ); + + /** + * @brief Save yS at the current time + */ + void SetStepFullSensitivities( + realtype &t_val, + realtype *y_val, + vector const &yS_val, + int &i_save + ); + + /** + * @brief Save the output function results at the requested time + */ + void SetStepOutput( + realtype &t_val, + realtype *y_val, + const vector &yS_val, + int &i_save + ); + + /** + * @brief Save the output function sensitivities at the requested time + */ + void SetStepOutputSensitivities( + realtype &t_val, + realtype *y_val, + const vector &yS_val, + int &i_save + ); + + /** + * @brief Save the output function results at the requested time + */ + void SetStepHermite( + realtype &t_val, + realtype *yp_val, + const vector &ypS_val, + int &i_save + ); + + /** + * @brief Save the output function sensitivities at the requested time + */ + void SetStepHermiteSensitivities( + realtype &t_val, + realtype *yp_val, + const vector &ypS_val, + int &i_save + ); + +}; + +#include "IDAKLUSolverOpenMP.inl" + +#endif // PYBAMM_IDAKLU_SOLVEROPENMP_HPP diff --git a/src/pybammsolvers/idaklu_source/IDAKLUSolverOpenMP.inl b/src/pybammsolvers/idaklu_source/IDAKLUSolverOpenMP.inl new file mode 100644 index 0000000..d128ae1 --- /dev/null +++ b/src/pybammsolvers/idaklu_source/IDAKLUSolverOpenMP.inl @@ -0,0 +1,977 @@ +#include "Expressions/Expressions.hpp" +#include "sundials_functions.hpp" +#include +#include "common.hpp" +#include "SolutionData.hpp" + +template +IDAKLUSolverOpenMP::IDAKLUSolverOpenMP( + np_array atol_np_input, + double rel_tol, + np_array rhs_alg_id_input, + int number_of_parameters_input, + int number_of_events_input, + int jac_times_cjmass_nnz_input, + int jac_bandwidth_lower_input, + int jac_bandwidth_upper_input, + std::unique_ptr functions_arg, + const SetupOptions &setup_input, + const SolverOptions &solver_input +) : + atol_np(atol_np_input), + rhs_alg_id(rhs_alg_id_input), + number_of_states(atol_np_input.request().size), + number_of_parameters(number_of_parameters_input), + number_of_events(number_of_events_input), + jac_times_cjmass_nnz(jac_times_cjmass_nnz_input), + jac_bandwidth_lower(jac_bandwidth_lower_input), + jac_bandwidth_upper(jac_bandwidth_upper_input), + functions(std::move(functions_arg)), + sensitivity(number_of_parameters > 0), + save_outputs_only(functions->var_fcns.size() > 0), + setup_opts(setup_input), + solver_opts(solver_input) +{ + // Construction code moved to Initialize() which is called from the + // (child) IDAKLUSolver_* class constructors. + DEBUG("IDAKLUSolverOpenMP:IDAKLUSolverOpenMP"); + auto atol = atol_np.unchecked<1>(); + + // create SUNDIALS context object + SUNContext_Create(NULL, &sunctx); // calls null-wrapper if Sundials Ver<6 + + // allocate memory for solver + ida_mem = IDACreate(sunctx); + + // create the vector of initial values + AllocateVectors(); + if (sensitivity) { + yyS = N_VCloneVectorArray(number_of_parameters, yy); + yypS = N_VCloneVectorArray(number_of_parameters, yyp); + } + // set initial values + realtype *atval = N_VGetArrayPointer(avtol); + for (int i = 0; i < number_of_states; i++) { + atval[i] = atol[i]; + } + + for (int is = 0; is < number_of_parameters; is++) { + N_VConst(RCONST(0.0), yyS[is]); + N_VConst(RCONST(0.0), yypS[is]); + } + + // create Matrix objects + SetMatrix(); + + // initialise solver + IDAInit(ida_mem, residual_eval, 0, yy, yyp); + + // set tolerances + rtol = RCONST(rel_tol); + IDASVtolerances(ida_mem, rtol, avtol); + + // Set events + IDARootInit(ida_mem, number_of_events, events_eval); + + // Set user data + void *user_data = functions.get(); + IDASetUserData(ida_mem, user_data); + + // Specify preconditioner type + precon_type = SUN_PREC_NONE; + if (this->setup_opts.preconditioner != "none") { + precon_type = SUN_PREC_LEFT; + } + + // The default is to solve a DAE for generality. This may be changed + // to an ODE during the Initialize() call + is_ODE = false; + + // Will be overwritten during the solve() call + save_hermite = solver_opts.hermite_interpolation; +} + +template +void IDAKLUSolverOpenMP::AllocateVectors() { + DEBUG("IDAKLUSolverOpenMP::AllocateVectors (num_threads = " << setup_opts.num_threads << ")"); + // Create vectors + if (setup_opts.num_threads == 1) { + yy = N_VNew_Serial(number_of_states, sunctx); + yyp = N_VNew_Serial(number_of_states, sunctx); + y_cache = N_VNew_Serial(number_of_states, sunctx); + avtol = N_VNew_Serial(number_of_states, sunctx); + id = N_VNew_Serial(number_of_states, sunctx); + } else { + DEBUG("IDAKLUSolverOpenMP::AllocateVectors OpenMP"); + yy = N_VNew_OpenMP(number_of_states, setup_opts.num_threads, sunctx); + yyp = N_VNew_OpenMP(number_of_states, setup_opts.num_threads, sunctx); + y_cache = N_VNew_OpenMP(number_of_states, setup_opts.num_threads, sunctx); + avtol = N_VNew_OpenMP(number_of_states, setup_opts.num_threads, sunctx); + id = N_VNew_OpenMP(number_of_states, setup_opts.num_threads, sunctx); + } +} + +template +void IDAKLUSolverOpenMP::InitializeStorage(int const N) { + length_of_return_vector = ReturnVectorLength(); + + t = vector(N, 0.0); + + y = vector>( + N, + vector(length_of_return_vector, 0.0) + ); + + yS = vector>>( + N, + vector>( + number_of_parameters, + vector(length_of_return_vector, 0.0) + ) + ); + + if (save_hermite) { + InitializeHermiteStorage(N); + } +} + +template +void IDAKLUSolverOpenMP::InitializeHermiteStorage(int const N) { + yp = vector>( + N, + vector(number_of_states, 0.0) + ); + + ypS = vector>>( + N, + vector>( + number_of_parameters, + vector(number_of_states, 0.0) + ) + ); +} + +template +int IDAKLUSolverOpenMP::ReturnVectorLength() { + if (!save_outputs_only) { + return number_of_states; + } + + // set return vectors + int length_of_return_vector = 0; + size_t max_res_size = 0; // maximum result size (for common result buffer) + size_t max_res_dvar_dy = 0, max_res_dvar_dp = 0; + // return only the requested variables list after computation + for (auto& var_fcn : functions->var_fcns) { + max_res_size = std::max(max_res_size, size_t(var_fcn->out_shape(0))); + length_of_return_vector += var_fcn->nnz_out(); + for (auto& dvar_fcn : functions->dvar_dy_fcns) { + max_res_dvar_dy = std::max(max_res_dvar_dy, size_t(dvar_fcn->out_shape(0))); + } + for (auto& dvar_fcn : functions->dvar_dp_fcns) { + max_res_dvar_dp = std::max(max_res_dvar_dp, size_t(dvar_fcn->out_shape(0))); + } + + res.resize(max_res_size); + res_dvar_dy.resize(max_res_dvar_dy); + res_dvar_dp.resize(max_res_dvar_dp); + } + + return length_of_return_vector; +} + +template +void IDAKLUSolverOpenMP::SetSolverOptions() { + // Maximum order of the linear multistep method + CheckErrors(IDASetMaxOrd(ida_mem, solver_opts.max_order_bdf)); + + // Maximum number of steps to be taken by the solver in its attempt to reach + // the next output time + CheckErrors(IDASetMaxNumSteps(ida_mem, solver_opts.max_num_steps)); + + // Initial step size + CheckErrors(IDASetInitStep(ida_mem, solver_opts.dt_init)); + + // Maximum absolute step size + CheckErrors(IDASetMaxStep(ida_mem, solver_opts.dt_max)); + + // Maximum number of error test failures in attempting one step + CheckErrors(IDASetMaxErrTestFails(ida_mem, solver_opts.max_error_test_failures)); + + // Maximum number of nonlinear solver iterations at one step + CheckErrors(IDASetMaxNonlinIters(ida_mem, solver_opts.max_nonlinear_iterations)); + + // Maximum number of nonlinear solver convergence failures at one step + CheckErrors(IDASetMaxConvFails(ida_mem, solver_opts.max_convergence_failures)); + + // Safety factor in the nonlinear convergence test + CheckErrors(IDASetNonlinConvCoef(ida_mem, solver_opts.nonlinear_convergence_coefficient)); + + // Suppress algebraic variables from error test + CheckErrors(IDASetSuppressAlg(ida_mem, solver_opts.suppress_algebraic_error)); + + // Positive constant in the Newton iteration convergence test within the initial + // condition calculation + CheckErrors(IDASetNonlinConvCoefIC(ida_mem, solver_opts.nonlinear_convergence_coefficient_ic)); + + // Maximum number of steps allowed when icopt=IDA_YA_YDP_INIT in IDACalcIC + CheckErrors(IDASetMaxNumStepsIC(ida_mem, solver_opts.max_num_steps_ic)); + + // Maximum number of the approximate Jacobian or preconditioner evaluations + // allowed when the Newton iteration appears to be slowly converging + CheckErrors(IDASetMaxNumJacsIC(ida_mem, solver_opts.max_num_jacobians_ic)); + + // Maximum number of Newton iterations allowed in any one attempt to solve + // the initial conditions calculation problem + CheckErrors(IDASetMaxNumItersIC(ida_mem, solver_opts.max_num_iterations_ic)); + + // Maximum number of linesearch backtracks allowed in any Newton iteration, + // when solving the initial conditions calculation problem + CheckErrors(IDASetMaxBacksIC(ida_mem, solver_opts.max_linesearch_backtracks_ic)); + + // Turn off linesearch + CheckErrors(IDASetLineSearchOffIC(ida_mem, solver_opts.linesearch_off_ic)); + + // Ratio between linear and nonlinear tolerances + CheckErrors(IDASetEpsLin(ida_mem, solver_opts.epsilon_linear_tolerance)); + + // Increment factor used in DQ Jv approximation + CheckErrors(IDASetIncrementFactor(ida_mem, solver_opts.increment_factor)); + + int LS_type = SUNLinSolGetType(LS); + if (LS_type == SUNLINEARSOLVER_DIRECT || LS_type == SUNLINEARSOLVER_MATRIX_ITERATIVE) { + // Enable or disable linear solution scaling + CheckErrors(IDASetLinearSolutionScaling(ida_mem, solver_opts.linear_solution_scaling)); + } +} + +template +void IDAKLUSolverOpenMP::SetMatrix() { + // Create Matrix object + if (setup_opts.jacobian == "sparse") { + DEBUG("\tsetting sparse matrix"); + J = SUNSparseMatrix( + number_of_states, + number_of_states, + jac_times_cjmass_nnz, + CSC_MAT, + sunctx + ); + } else if (setup_opts.jacobian == "banded") { + DEBUG("\tsetting banded matrix"); + J = SUNBandMatrix( + number_of_states, + jac_bandwidth_upper, + jac_bandwidth_lower, + sunctx + ); + } else if (setup_opts.jacobian == "dense" || setup_opts.jacobian == "none") { + DEBUG("\tsetting dense matrix"); + J = SUNDenseMatrix( + number_of_states, + number_of_states, + sunctx + ); + } else if (setup_opts.jacobian == "matrix-free") { + DEBUG("\tsetting matrix-free"); + J = NULL; + } else { + throw std::invalid_argument("Unsupported matrix requested"); + } +} + +template +void IDAKLUSolverOpenMP::Initialize() { + // Call after setting the solver + + // attach the linear solver + if (LS == nullptr) { + throw std::invalid_argument("Linear solver not set"); + } + CheckErrors(IDASetLinearSolver(ida_mem, LS, J)); + + if (setup_opts.preconditioner != "none") { + DEBUG("\tsetting IDADDB preconditioner"); + // setup preconditioner + CheckErrors(IDABBDPrecInit( + ida_mem, number_of_states, setup_opts.precon_half_bandwidth, + setup_opts.precon_half_bandwidth, setup_opts.precon_half_bandwidth_keep, + setup_opts.precon_half_bandwidth_keep, 0.0, residual_eval_approx, NULL)); + } + + if (setup_opts.jacobian == "matrix-free") { + CheckErrors(IDASetJacTimes(ida_mem, NULL, jtimes_eval)); + } else if (setup_opts.jacobian != "none") { + CheckErrors(IDASetJacFn(ida_mem, jacobian_eval)); + } + + if (sensitivity) { + CheckErrors(IDASensInit(ida_mem, number_of_parameters, IDA_SIMULTANEOUS, + sensitivities_eval, yyS, yypS)); + CheckErrors(IDASensEEtolerances(ida_mem)); + } + + CheckErrors(SUNLinSolInitialize(LS)); + + auto id_np_val = rhs_alg_id.unchecked<1>(); + realtype *id_val; + id_val = N_VGetArrayPointer(id); + + // Determine if the system is an ODE + is_ODE = number_of_states > 0; + for (int ii = 0; ii < number_of_states; ii++) { + id_val[ii] = id_np_val[ii]; + // check if id_val[ii] approximately equals 1 (>0.999) handles + // cases where id_val[ii] is not exactly 1 due to numerical errors + is_ODE &= id_val[ii] > 0.999; + } + + // Variable types: differential (1) and algebraic (0) + CheckErrors(IDASetId(ida_mem, id)); +} + +template +IDAKLUSolverOpenMP::~IDAKLUSolverOpenMP() { + DEBUG("IDAKLUSolverOpenMP::~IDAKLUSolverOpenMP"); + // Free memory + if (sensitivity) { + IDASensFree(ida_mem); + } + + CheckErrors(SUNLinSolFree(LS)); + + SUNMatDestroy(J); + N_VDestroy(avtol); + N_VDestroy(yy); + N_VDestroy(yyp); + N_VDestroy(y_cache); + N_VDestroy(id); + + if (sensitivity) { + N_VDestroyVectorArray(yyS, number_of_parameters); + N_VDestroyVectorArray(yypS, number_of_parameters); + } + + IDAFree(&ida_mem); + SUNContext_Free(&sunctx); +} + +template +SolutionData IDAKLUSolverOpenMP::solve( + const std::vector &t_eval, + const std::vector &t_interp, + const realtype *y0, + const realtype *yp0, + const realtype *inputs, + bool save_adaptive_steps, + bool save_interp_steps +) +{ + DEBUG("IDAKLUSolver::solve"); + const int number_of_evals = t_eval.size(); + const int number_of_interps = t_interp.size(); + + // Hermite interpolation is only available when saving + // 1. adaptive steps and 2. the full solution + save_hermite = ( + solver_opts.hermite_interpolation && + save_adaptive_steps && + !save_outputs_only + ); + + InitializeStorage(number_of_evals + number_of_interps); + + int i_save = 0; + + realtype t0 = t_eval.front(); + realtype tf = t_eval.back(); + + realtype t_val = t0; + realtype t_prev = t0; + int i_eval = 0; + + realtype t_interp_next; + int i_interp = 0; + // If t_interp is empty, save all adaptive steps + if (save_interp_steps) { + t_interp_next = t_interp[0]; + } + + auto n_coeffs = number_of_states + number_of_parameters * number_of_states; + + // set inputs + for (int i = 0; i < functions->inputs.size(); i++) { + functions->inputs[i] = inputs[i]; + } + + // Setup consistent initialization + realtype *y_val = N_VGetArrayPointer(yy); + realtype *yp_val = N_VGetArrayPointer(yyp); + vector yS_val(number_of_parameters); + vector ypS_val(number_of_parameters); + for (int p = 0 ; p < number_of_parameters; p++) { + yS_val[p] = N_VGetArrayPointer(yyS[p]); + ypS_val[p] = N_VGetArrayPointer(yypS[p]); + for (int i = 0; i < number_of_states; i++) { + yS_val[p][i] = y0[i + (p + 1) * number_of_states]; + ypS_val[p][i] = yp0[i + (p + 1) * number_of_states]; + } + } + + for (int i = 0; i < number_of_states; i++) { + y_val[i] = y0[i]; + yp_val[i] = yp0[i]; + } + + SetSolverOptions(); + + // Prepare first time step + i_eval = 1; + realtype t_eval_next = t_eval[i_eval]; + + + // Consistent initialization + ReinitializeIntegrator(t0); + int const init_type = solver_opts.init_all_y_ic ? IDA_Y_INIT : IDA_YA_YDP_INIT; + if (solver_opts.calc_ic) { + ConsistentInitialization(t0, t_eval_next, init_type); + } + + // Set the initial stop time + IDASetStopTime(ida_mem, t_eval_next); + + // Progress one step. This must be done before the while loop to ensure + // that we can run IDAGetDky at t0 for dky = 1 + int retval = IDASolve(ida_mem, tf, &t_val, yy, yyp, IDA_ONE_STEP); + + // Store consistent initialization + CheckErrors(IDAGetDky(ida_mem, t0, 0, yy)); + if (sensitivity) { + CheckErrors(IDAGetSensDky(ida_mem, t0, 0, yyS)); + } + + SetStep(t0, y_val, yp_val, yS_val, ypS_val, i_save); + + // Reset the states at t = t_val. Sensitivities are handled in the while-loop + CheckErrors(IDAGetDky(ida_mem, t_val, 0, yy)); + + // Solve the system + DEBUG("IDASolve"); + while (true) { + if (retval < 0) { + // failed + break; + } else if (t_prev == t_val) { + // IDA sometimes returns an identical time point twice + // instead of erroring. Assign a retval and break + retval = IDA_ERR_FAIL; + break; + } + + bool hit_tinterp = save_interp_steps && t_interp_next >= t_prev; + bool hit_teval = retval == IDA_TSTOP_RETURN; + bool hit_final_time = t_val >= tf || (hit_teval && i_eval == number_of_evals); + bool hit_event = retval == IDA_ROOT_RETURN; + bool hit_adaptive = save_adaptive_steps && retval == IDA_SUCCESS; + + if (sensitivity) { + CheckErrors(IDAGetSensDky(ida_mem, t_val, 0, yyS)); + } + + if (hit_tinterp) { + // Save the interpolated state at t_prev < t < t_val, for all t in t_interp + SetStepInterp( + i_interp, + t_interp_next, + t_interp, + t_val, + t_prev, + t_eval_next, + y_val, + yp_val, + yS_val, + ypS_val, + i_save); + } + + if (hit_adaptive || hit_teval || hit_event || hit_final_time) { + if (hit_tinterp) { + // Reset the states and sensitivities at t = t_val + CheckErrors(IDAGetDky(ida_mem, t_val, 0, yy)); + if (sensitivity) { + CheckErrors(IDAGetSensDky(ida_mem, t_val, 0, yyS)); + } + } + + // Save the current state at t_val + // First, check to make sure that the t_val is not equal to the current t value + // If it is, we don't want to save the current state twice + if (!hit_tinterp || t_val != t.back()) { + if (hit_adaptive) { + // Dynamically allocate memory for the adaptive step + ExtendAdaptiveArrays(); + } + + SetStep(t_val, y_val, yp_val, yS_val, ypS_val, i_save); + } + } + + if (hit_final_time || hit_event) { + // Successful simulation. Exit the while loop + break; + } else if (hit_teval) { + // Set the next stop time + i_eval++; + t_eval_next = t_eval[i_eval]; + CheckErrors(IDASetStopTime(ida_mem, t_eval_next)); + + // Reinitialize the solver to deal with the discontinuity at t = t_val. + ReinitializeIntegrator(t_val); + ConsistentInitialization(t_val, t_eval_next, IDA_YA_YDP_INIT); + } + + t_prev = t_val; + + // Progress one step + retval = IDASolve(ida_mem, tf, &t_val, yy, yyp, IDA_ONE_STEP); + } + + int const length_of_final_sv_slice = save_outputs_only ? number_of_states : 0; + realtype *yterm_return = new realtype[length_of_final_sv_slice]; + if (save_outputs_only) { + // store final state slice if outout variables are specified + std::memcpy(yterm_return, y_val, length_of_final_sv_slice * sizeof(realtype*)); + } + + if (solver_opts.print_stats) { + PrintStats(); + } + + // store number of timesteps so we can generate the solution later + number_of_timesteps = i_save; + + // Copy the data to return as numpy arrays + + // Time, t + realtype *t_return = new realtype[number_of_timesteps]; + for (size_t i = 0; i < number_of_timesteps; i++) { + t_return[i] = t[i]; + } + + // States, y + realtype *y_return = new realtype[number_of_timesteps * length_of_return_vector]; + int count = 0; + for (size_t i = 0; i < number_of_timesteps; i++) { + for (size_t j = 0; j < length_of_return_vector; j++) { + y_return[count] = y[i][j]; + count++; + } + } + + // Sensitivity states, yS + // Note: Ordering of vector is different if computing outputs vs returning + // the complete state vector + auto const arg_sens0 = (save_outputs_only ? number_of_timesteps : number_of_parameters); + auto const arg_sens1 = (save_outputs_only ? length_of_return_vector : number_of_timesteps); + auto const arg_sens2 = (save_outputs_only ? number_of_parameters : length_of_return_vector); + + realtype *yS_return = new realtype[arg_sens0 * arg_sens1 * arg_sens2]; + count = 0; + for (size_t idx0 = 0; idx0 < arg_sens0; idx0++) { + for (size_t idx1 = 0; idx1 < arg_sens1; idx1++) { + for (size_t idx2 = 0; idx2 < arg_sens2; idx2++) { + auto i = (save_outputs_only ? idx0 : idx1); + auto j = (save_outputs_only ? idx1 : idx2); + auto k = (save_outputs_only ? idx2 : idx0); + + yS_return[count] = yS[i][k][j]; + count++; + } + } + } + + realtype *yp_return = new realtype[(save_hermite ? 1 : 0) * (number_of_timesteps * number_of_states)]; + realtype *ypS_return = new realtype[(save_hermite ? 1 : 0) * (arg_sens0 * arg_sens1 * arg_sens2)]; + if (save_hermite) { + count = 0; + for (size_t i = 0; i < number_of_timesteps; i++) { + for (size_t j = 0; j < number_of_states; j++) { + yp_return[count] = yp[i][j]; + count++; + } + } + + // Sensitivity states, ypS + // Note: Ordering of vector is different if computing outputs vs returning + // the complete state vector + count = 0; + for (size_t idx0 = 0; idx0 < arg_sens0; idx0++) { + for (size_t idx1 = 0; idx1 < arg_sens1; idx1++) { + for (size_t idx2 = 0; idx2 < arg_sens2; idx2++) { + auto i = (save_outputs_only ? idx0 : idx1); + auto j = (save_outputs_only ? idx1 : idx2); + auto k = (save_outputs_only ? idx2 : idx0); + + ypS_return[count] = ypS[i][k][j]; + count++; + } + } + } + } + + return SolutionData( + retval, + number_of_timesteps, + length_of_return_vector, + arg_sens0, + arg_sens1, + arg_sens2, + length_of_final_sv_slice, + save_hermite, + t_return, + y_return, + yp_return, + yS_return, + ypS_return, + yterm_return); +} + +template +void IDAKLUSolverOpenMP::ExtendAdaptiveArrays() { + DEBUG("IDAKLUSolver::ExtendAdaptiveArrays"); + // Time + t.emplace_back(0.0); + + // States + y.emplace_back(length_of_return_vector, 0.0); + + // Sensitivity + if (sensitivity) { + yS.emplace_back(number_of_parameters, vector(length_of_return_vector, 0.0)); + } + + if (save_hermite) { + ExtendHermiteArrays(); + } +} + +template +void IDAKLUSolverOpenMP::ExtendHermiteArrays() { + DEBUG("IDAKLUSolver::ExtendHermiteArrays"); + // States + yp.emplace_back(number_of_states, 0.0); + + // Sensitivity + if (sensitivity) { + ypS.emplace_back(number_of_parameters, vector(number_of_states, 0.0)); + } +} + +template +void IDAKLUSolverOpenMP::ReinitializeIntegrator(const realtype& t_val) { + DEBUG("IDAKLUSolver::ReinitializeIntegrator"); + CheckErrors(IDAReInit(ida_mem, t_val, yy, yyp)); + if (sensitivity) { + CheckErrors(IDASensReInit(ida_mem, IDA_SIMULTANEOUS, yyS, yypS)); + } +} + +template +void IDAKLUSolverOpenMP::ConsistentInitialization( + const realtype& t_val, + const realtype& t_next, + const int& icopt) { + DEBUG("IDAKLUSolver::ConsistentInitialization"); + + if (is_ODE && icopt == IDA_YA_YDP_INIT) { + ConsistentInitializationODE(t_val); + } else { + ConsistentInitializationDAE(t_val, t_next, icopt); + } +} + +template +void IDAKLUSolverOpenMP::ConsistentInitializationDAE( + const realtype& t_val, + const realtype& t_next, + const int& icopt) { + DEBUG("IDAKLUSolver::ConsistentInitializationDAE"); + IDACalcIC(ida_mem, icopt, t_next); +} + +template +void IDAKLUSolverOpenMP::ConsistentInitializationODE( + const realtype& t_val) { + DEBUG("IDAKLUSolver::ConsistentInitializationODE"); + + // For ODEs where the mass matrix M = I, we can simplify the problem + // by analytically computing the yp values. If we take our implicit + // DAE system res(t,y,yp) = f(t,y) - I*yp, then yp = res(t,y,0). This + // avoids an expensive call to IDACalcIC. + realtype *y_cache_val = N_VGetArrayPointer(y_cache); + std::memset(y_cache_val, 0, number_of_states * sizeof(realtype)); + // Overwrite yp + residual_eval(t_val, yy, y_cache, yyp, functions.get()); +} + +template +void IDAKLUSolverOpenMP::SetStep( + realtype &tval, + realtype *y_val, + realtype *yp_val, + vector const &yS_val, + vector const &ypS_val, + int &i_save +) { + // Set adaptive step results for y and yS + DEBUG("IDAKLUSolver::SetStep"); + + // Time + t[i_save] = tval; + + if (save_outputs_only) { + SetStepOutput(tval, y_val, yS_val, i_save); + } else { + SetStepFull(tval, y_val, yS_val, i_save); + + if (save_hermite) { + SetStepHermite(tval, yp_val, ypS_val, i_save); + } + } + + i_save++; +} + + +template +void IDAKLUSolverOpenMP::SetStepInterp( + int &i_interp, + realtype &t_interp_next, + vector const &t_interp, + realtype &t_val, + realtype &t_prev, + realtype const &t_eval_next, + realtype *y_val, + realtype *yp_val, + vector const &yS_val, + vector const &ypS_val, + int &i_save + ) { + // Save the state at the requested time + DEBUG("IDAKLUSolver::SetStepInterp"); + + while (i_interp <= (t_interp.size()-1) && t_interp_next <= t_val) { + CheckErrors(IDAGetDky(ida_mem, t_interp_next, 0, yy)); + if (sensitivity) { + CheckErrors(IDAGetSensDky(ida_mem, t_interp_next, 0, yyS)); + } + + // Memory is already allocated for the interpolated values + SetStep(t_interp_next, y_val, yp_val, yS_val, ypS_val, i_save); + + i_interp++; + if (i_interp == (t_interp.size())) { + // Reached the final t_interp value + break; + } + t_interp_next = t_interp[i_interp]; + } +} + +template +void IDAKLUSolverOpenMP::SetStepFull( + realtype &tval, + realtype *y_val, + vector const &yS_val, + int &i_save +) { + // Set adaptive step results for y and yS + DEBUG("IDAKLUSolver::SetStepFull"); + + // States + auto &y_back = y[i_save]; + for (size_t j = 0; j < number_of_states; ++j) { + y_back[j] = y_val[j]; + } + + // Sensitivity + if (sensitivity) { + SetStepFullSensitivities(tval, y_val, yS_val, i_save); + } +} + +template +void IDAKLUSolverOpenMP::SetStepFullSensitivities( + realtype &tval, + realtype *y_val, + vector const &yS_val, + int &i_save +) { + DEBUG("IDAKLUSolver::SetStepFullSensitivities"); + + // Calculate sensitivities for the full yS array + for (size_t j = 0; j < number_of_parameters; ++j) { + auto &yS_back_j = yS[i_save][j]; + auto &ySval_j = yS_val[j]; + for (size_t k = 0; k < number_of_states; ++k) { + yS_back_j[k] = ySval_j[k]; + } + } +} + +template +void IDAKLUSolverOpenMP::SetStepOutput( + realtype &tval, + realtype *y_val, + const vector& yS_val, + int &i_save +) { + DEBUG("IDAKLUSolver::SetStepOutput"); + // Evaluate functions for each requested variable and store + + size_t j = 0; + for (auto& var_fcn : functions->var_fcns) { + (*var_fcn)({&tval, y_val, functions->inputs.data()}, {&res[0]}); + // store in return vector + for (size_t jj=0; jjnnz_out(); jj++) { + y[i_save][j++] = res[jj]; + } + } + // calculate sensitivities + if (sensitivity) { + SetStepOutputSensitivities(tval, y_val, yS_val, i_save); + } +} + +template +void IDAKLUSolverOpenMP::SetStepOutputSensitivities( + realtype &tval, + realtype *y_val, + const vector& yS_val, + int &i_save + ) { + DEBUG("IDAKLUSolver::SetStepOutputSensitivities"); + // Calculate sensitivities + vector dens_dvar_dp = vector(number_of_parameters, 0); + for (size_t dvar_k=0; dvar_kdvar_dy_fcns.size(); dvar_k++) { + // Isolate functions + Expression* dvar_dy = functions->dvar_dy_fcns[dvar_k]; + Expression* dvar_dp = functions->dvar_dp_fcns[dvar_k]; + // Calculate dvar/dy + (*dvar_dy)({&tval, y_val, functions->inputs.data()}, {&res_dvar_dy[0]}); + // Calculate dvar/dp and convert to dense array for indexing + (*dvar_dp)({&tval, y_val, functions->inputs.data()}, {&res_dvar_dp[0]}); + for (int k=0; knnz_out(); k++) { + dens_dvar_dp[dvar_dp->get_row()[k]] = res_dvar_dp[k]; + } + // Calculate sensitivities + for (int paramk=0; paramknnz_out(); spk++) { + yS_back_paramk[dvar_k] += res_dvar_dy[spk] * yS_val[paramk][dvar_dy->get_col()[spk]]; + } + } + } +} + +template +void IDAKLUSolverOpenMP::SetStepHermite( + realtype &tval, + realtype *yp_val, + vector const &ypS_val, + int &i_save +) { + // Set adaptive step results for yp and ypS + DEBUG("IDAKLUSolver::SetStepHermite"); + + // States + CheckErrors(IDAGetDky(ida_mem, tval, 1, yyp)); + auto &yp_back = yp[i_save]; + for (size_t j = 0; j < length_of_return_vector; ++j) { + yp_back[j] = yp_val[j]; + + } + + // Sensitivity + if (sensitivity) { + SetStepHermiteSensitivities(tval, yp_val, ypS_val, i_save); + } +} + +template +void IDAKLUSolverOpenMP::SetStepHermiteSensitivities( + realtype &tval, + realtype *yp_val, + vector const &ypS_val, + int &i_save +) { + DEBUG("IDAKLUSolver::SetStepHermiteSensitivities"); + + // Calculate sensitivities for the full ypS array + CheckErrors(IDAGetSensDky(ida_mem, tval, 1, yypS)); + for (size_t j = 0; j < number_of_parameters; ++j) { + auto &ypS_back_j = ypS[i_save][j]; + auto &ypSval_j = ypS_val[j]; + for (size_t k = 0; k < number_of_states; ++k) { + ypS_back_j[k] = ypSval_j[k]; + } + } +} + +template +void IDAKLUSolverOpenMP::CheckErrors(int const & flag) { + if (flag < 0) { + auto message = std::string("IDA failed with flag ") + std::to_string(flag); + throw std::runtime_error(message.c_str()); + } +} + +template +void IDAKLUSolverOpenMP::PrintStats() { + long nsteps, nrevals, nlinsetups, netfails; + int klast, kcur; + realtype hinused, hlast, hcur, tcur; + + CheckErrors(IDAGetIntegratorStats( + ida_mem, + &nsteps, + &nrevals, + &nlinsetups, + &netfails, + &klast, + &kcur, + &hinused, + &hlast, + &hcur, + &tcur + )); + + long nniters, nncfails; + CheckErrors(IDAGetNonlinSolvStats(ida_mem, &nniters, &nncfails)); + + long int ngevalsBBDP = 0; + if (setup_opts.using_iterative_solver) { + CheckErrors(IDABBDPrecGetNumGfnEvals(ida_mem, &ngevalsBBDP)); + } + + py::print("Solver Stats:"); + py::print("\tNumber of steps =", nsteps); + py::print("\tNumber of calls to residual function =", nrevals); + py::print("\tNumber of calls to residual function in preconditioner =", + ngevalsBBDP); + py::print("\tNumber of linear solver setup calls =", nlinsetups); + py::print("\tNumber of error test failures =", netfails); + py::print("\tMethod order used on last step =", klast); + py::print("\tMethod order used on next step =", kcur); + py::print("\tInitial step size =", hinused); + py::print("\tStep size on last step =", hlast); + py::print("\tStep size on next step =", hcur); + py::print("\tCurrent internal time reached =", tcur); + py::print("\tNumber of nonlinear iterations performed =", nniters); + py::print("\tNumber of nonlinear convergence failures =", nncfails); +} diff --git a/src/pybammsolvers/idaklu_source/IDAKLUSolverOpenMP_solvers.cpp b/src/pybammsolvers/idaklu_source/IDAKLUSolverOpenMP_solvers.cpp new file mode 100644 index 0000000..45ceed0 --- /dev/null +++ b/src/pybammsolvers/idaklu_source/IDAKLUSolverOpenMP_solvers.cpp @@ -0,0 +1 @@ +#include "IDAKLUSolverOpenMP_solvers.hpp" diff --git a/src/pybammsolvers/idaklu_source/IDAKLUSolverOpenMP_solvers.hpp b/src/pybammsolvers/idaklu_source/IDAKLUSolverOpenMP_solvers.hpp new file mode 100644 index 0000000..5f6f29b --- /dev/null +++ b/src/pybammsolvers/idaklu_source/IDAKLUSolverOpenMP_solvers.hpp @@ -0,0 +1,131 @@ +#ifndef PYBAMM_IDAKLU_CASADI_SOLVER_OPENMP_HPP +#define PYBAMM_IDAKLU_CASADI_SOLVER_OPENMP_HPP + +#include "IDAKLUSolverOpenMP.hpp" + +/** + * @brief IDAKLUSolver Dense implementation with OpenMP class + */ +template +class IDAKLUSolverOpenMP_Dense : public IDAKLUSolverOpenMP { +public: + using Base = IDAKLUSolverOpenMP; + template + IDAKLUSolverOpenMP_Dense(Args&& ... args) : Base(std::forward(args) ...) + { + Base::LS = SUNLinSol_Dense(Base::yy, Base::J, Base::sunctx); + Base::Initialize(); + } +}; + +/** + * @brief IDAKLUSolver KLU implementation with OpenMP class + */ +template +class IDAKLUSolverOpenMP_KLU : public IDAKLUSolverOpenMP { +public: + using Base = IDAKLUSolverOpenMP; + template + IDAKLUSolverOpenMP_KLU(Args&& ... args) : Base(std::forward(args) ...) + { + Base::LS = SUNLinSol_KLU(Base::yy, Base::J, Base::sunctx); + Base::Initialize(); + } +}; + +/** + * @brief IDAKLUSolver Banded implementation with OpenMP class + */ +template +class IDAKLUSolverOpenMP_Band : public IDAKLUSolverOpenMP { +public: + using Base = IDAKLUSolverOpenMP; + template + IDAKLUSolverOpenMP_Band(Args&& ... args) : Base(std::forward(args) ...) + { + Base::LS = SUNLinSol_Band(Base::yy, Base::J, Base::sunctx); + Base::Initialize(); + } +}; + +/** + * @brief IDAKLUSolver SPBCGS implementation with OpenMP class + */ +template +class IDAKLUSolverOpenMP_SPBCGS : public IDAKLUSolverOpenMP { +public: + using Base = IDAKLUSolverOpenMP; + template + IDAKLUSolverOpenMP_SPBCGS(Args&& ... args) : Base(std::forward(args) ...) + { + Base::LS = SUNLinSol_SPBCGS( + Base::yy, + Base::precon_type, + Base::setup_opts.linsol_max_iterations, + Base::sunctx + ); + Base::Initialize(); + } +}; + +/** + * @brief IDAKLUSolver SPFGMR implementation with OpenMP class + */ +template +class IDAKLUSolverOpenMP_SPFGMR : public IDAKLUSolverOpenMP { +public: + using Base = IDAKLUSolverOpenMP; + template + IDAKLUSolverOpenMP_SPFGMR(Args&& ... args) : Base(std::forward(args) ...) + { + Base::LS = SUNLinSol_SPFGMR( + Base::yy, + Base::precon_type, + Base::setup_opts.linsol_max_iterations, + Base::sunctx + ); + Base::Initialize(); + } +}; + +/** + * @brief IDAKLUSolver SPGMR implementation with OpenMP class + */ +template +class IDAKLUSolverOpenMP_SPGMR : public IDAKLUSolverOpenMP { +public: + using Base = IDAKLUSolverOpenMP; + template + IDAKLUSolverOpenMP_SPGMR(Args&& ... args) : Base(std::forward(args) ...) + { + Base::LS = SUNLinSol_SPGMR( + Base::yy, + Base::precon_type, + Base::setup_opts.linsol_max_iterations, + Base::sunctx + ); + Base::Initialize(); + } +}; + +/** + * @brief IDAKLUSolver SPTFQMR implementation with OpenMP class + */ +template +class IDAKLUSolverOpenMP_SPTFQMR : public IDAKLUSolverOpenMP { +public: + using Base = IDAKLUSolverOpenMP; + template + IDAKLUSolverOpenMP_SPTFQMR(Args&& ... args) : Base(std::forward(args) ...) + { + Base::LS = SUNLinSol_SPTFQMR( + Base::yy, + Base::precon_type, + Base::setup_opts.linsol_max_iterations, + Base::sunctx + ); + Base::Initialize(); + } +}; + +#endif // PYBAMM_IDAKLU_CASADI_SOLVER_OPENMP_HPP diff --git a/src/pybammsolvers/idaklu_source/IdakluJax.cpp b/src/pybammsolvers/idaklu_source/IdakluJax.cpp new file mode 100644 index 0000000..15c2b2d --- /dev/null +++ b/src/pybammsolvers/idaklu_source/IdakluJax.cpp @@ -0,0 +1,264 @@ +#include "IdakluJax.hpp" + +#include +#include +#include +#include +#include + +#include +#include +#include + +// Initialise static variable +std::int64_t IdakluJax::universal_count = 0; + +// Repository of instantiated IdakluJax objects +std::map idaklu_jax_instances; + +// Create a new IdakluJax object, assign identifier, add to the objects list and return as pointer +IdakluJax *create_idaklu_jax() { + IdakluJax *p = new IdakluJax(); + idaklu_jax_instances[p->get_index()] = p; + return p; +} + +IdakluJax::IdakluJax() { + index = universal_count++; +} + +IdakluJax::~IdakluJax() { + idaklu_jax_instances.erase(index); +} + +void IdakluJax::register_callback_eval(CallbackEval h) { + callback_eval = h; +} + +void IdakluJax::register_callback_jvp(CallbackJvp h) { + callback_jvp = h; +} + +void IdakluJax::register_callback_vjp(CallbackVjp h) { + callback_vjp = h; +} + +void IdakluJax::register_callbacks(CallbackEval h_eval, CallbackJvp h_jvp, CallbackVjp h_vjp) { + register_callback_eval(h_eval); + register_callback_jvp(h_jvp); + register_callback_vjp(h_vjp); +} + +void IdakluJax::cpu_idaklu_eval(void *out_tuple, const void **in) { + // Parse the inputs --- note that these come from jax lowering and are NOT np_array's + int k = 1; // Start indexing at 1 to skip idaklu_jax index + const std::int64_t n_t = *reinterpret_cast(in[k++]); + const std::int64_t n_vars = *reinterpret_cast(in[k++]); + const std::int64_t n_inputs = *reinterpret_cast(in[k++]); + const realtype *t = reinterpret_cast(in[k++]); + realtype *inputs = new realtype(n_inputs); + for (int i = 0; i < n_inputs; i++) + inputs[i] = reinterpret_cast(in[k++])[0]; + void *out = reinterpret_cast(out_tuple); + + // Log + DEBUG("cpu_idaklu"); + DEBUG_n(index); + DEBUG_n(n_t); + DEBUG_n(n_vars); + DEBUG_n(n_inputs); + DEBUG_v(t, n_t); + DEBUG_v(inputs, n_inputs); + + // Acquire GIL (since this function is called as a capsule) + py::gil_scoped_acquire acquire; + PyGILState_STATE state = PyGILState_Ensure(); + + // Convert time vector to an np_array + py::capsule t_capsule(t, "t_capsule"); + np_array t_np = np_array({n_t}, {sizeof(realtype)}, t, t_capsule); + + // Convert inputs to an np_array + py::capsule in_capsule(inputs, "in_capsule"); + np_array in_np = np_array({n_inputs}, {sizeof(realtype)}, inputs, in_capsule); + + // Call solve function in python to obtain an np_array + np_array out_np = callback_eval(t_np, in_np); + auto out_buf = out_np.request(); + const realtype *out_ptr = reinterpret_cast(out_buf.ptr); + + // Arrange into 'out' array + memcpy(out, out_ptr, n_t * n_vars * sizeof(realtype)); + + // Release GIL + PyGILState_Release(state); +} + +void IdakluJax::cpu_idaklu_jvp(void *out_tuple, const void **in) { + // Parse the inputs --- note that these come from jax lowering and are NOT np_array's + int k = 1; // Start indexing at 1 to skip idaklu_jax index + const std::int64_t n_t = *reinterpret_cast(in[k++]); + const std::int64_t n_vars = *reinterpret_cast(in[k++]); + const std::int64_t n_inputs = *reinterpret_cast(in[k++]); + const realtype *primal_t = reinterpret_cast(in[k++]); + realtype *primal_inputs = new realtype(n_inputs); + for (int i = 0; i < n_inputs; i++) + primal_inputs[i] = reinterpret_cast(in[k++])[0]; + const realtype *tangent_t = reinterpret_cast(in[k++]); + realtype *tangent_inputs = new realtype(n_inputs); + for (int i = 0; i < n_inputs; i++) + tangent_inputs[i] = reinterpret_cast(in[k++])[0]; + void *out = reinterpret_cast(out_tuple); + + // Log + DEBUG("cpu_idaklu_jvp"); + DEBUG_n(n_t); + DEBUG_n(n_vars); + DEBUG_n(n_inputs); + DEBUG_v(primal_t, n_t); + DEBUG_v(primal_inputs, n_inputs); + DEBUG_v(tangent_t, n_t); + DEBUG_v(tangent_inputs, n_inputs); + + // Acquire GIL (since this function is called as a capsule) + py::gil_scoped_acquire acquire; + PyGILState_STATE state = PyGILState_Ensure(); + + // Form primals time vector as np_array + py::capsule primal_t_capsule(primal_t, "primal_t_capsule"); + np_array primal_t_np = np_array( + {n_t}, + {sizeof(realtype)}, + primal_t, + primal_t_capsule + ); + + // Pack primals as np_array + py::capsule primal_inputs_capsule(primal_inputs, "primal_inputs_capsule"); + np_array primal_inputs_np = np_array( + {n_inputs}, + {sizeof(realtype)}, + primal_inputs, + primal_inputs_capsule + ); + + // Form tangents time vector as np_array + py::capsule tangent_t_capsule(tangent_t, "tangent_t_capsule"); + np_array tangent_t_np = np_array( + {n_t}, + {sizeof(realtype)}, + tangent_t, + tangent_t_capsule + ); + + // Pack tangents as np_array + py::capsule tangent_inputs_capsule(tangent_inputs, "tangent_inputs_capsule"); + np_array tangent_inputs_np = np_array( + {n_inputs}, + {sizeof(realtype)}, + tangent_inputs, + tangent_inputs_capsule + ); + + // Call JVP function in python to obtain an np_array + np_array y_dot = callback_jvp( + primal_t_np, primal_inputs_np, + tangent_t_np, tangent_inputs_np + ); + auto buf = y_dot.request(); + const realtype *ptr = reinterpret_cast(buf.ptr); + + // Arrange into 'out' array + memcpy(out, ptr, n_t * n_vars * sizeof(realtype)); + + // Release GIL + PyGILState_Release(state); +} + +void IdakluJax::cpu_idaklu_vjp(void *out_tuple, const void **in) { + int k = 1; // Start indexing at 1 to skip idaklu_jax index + const std::int64_t n_t = *reinterpret_cast(in[k++]); + const std::int64_t n_inputs = *reinterpret_cast(in[k++]); + const std::int64_t n_y_bar0 = *reinterpret_cast(in[k++]); + const std::int64_t n_y_bar1 = *reinterpret_cast(in[k++]); + const std::int64_t n_y_bar = (n_y_bar1 > 0) ? (n_y_bar0*n_y_bar1) : n_y_bar0; + const realtype *y_bar = reinterpret_cast(in[k++]); + const std::int64_t *invar = reinterpret_cast(in[k++]); + const realtype *t = reinterpret_cast(in[k++]); + realtype *inputs = new realtype(n_inputs); + for (int i = 0; i < n_inputs; i++) + inputs[i] = reinterpret_cast(in[k++])[0]; + realtype *out = reinterpret_cast(out_tuple); + + // Log + DEBUG("cpu_idaklu_vjp"); + DEBUG_n(n_t); + DEBUG_n(n_inputs); + DEBUG_n(n_y_bar0); + DEBUG_n(n_y_bar1); + DEBUG_v(y_bar, n_y_bar0*n_y_bar1); + DEBUG_v(invar, 1); + DEBUG_v(t, n_t); + DEBUG_v(inputs, n_inputs); + + // Acquire GIL (since this function is called as a capsule) + py::gil_scoped_acquire acquire; + PyGILState_STATE state = PyGILState_Ensure(); + + // Convert time vector to an np_array + py::capsule t_capsule(t, "t_capsule"); + np_array t_np = np_array({n_t}, {sizeof(realtype)}, t, t_capsule); + + // Convert y_bar to an np_array + py::capsule y_bar_capsule(y_bar, "y_bar_capsule"); + np_array y_bar_np = np_array( + {n_y_bar}, + {sizeof(realtype)}, + y_bar, + y_bar_capsule + ); + + // Convert inputs to an np_array + py::capsule in_capsule(inputs, "in_capsule"); + np_array in_np = np_array({n_inputs}, {sizeof(realtype)}, inputs, in_capsule); + + // Call VJP function in python to obtain an np_array + np_array y_dot = callback_vjp(y_bar_np, n_y_bar0, n_y_bar1, invar[0], t_np, in_np); + auto buf = y_dot.request(); + const realtype *ptr = reinterpret_cast(buf.ptr); + + // Arrange output + //memcpy(out, ptr, sizeof(realtype)); + out[0] = ptr[0]; // output is scalar + + // Release GIL + PyGILState_Release(state); +} + +template +pybind11::capsule EncapsulateFunction(T* fn) { + return pybind11::capsule(reinterpret_cast(fn), "xla._CUSTOM_CALL_TARGET"); +} + +void wrap_cpu_idaklu_eval_f64(void *out_tuple, const void **in) { + const std::int64_t index = *reinterpret_cast(in[0]); + idaklu_jax_instances[index]->cpu_idaklu_eval(out_tuple, in); +} + +void wrap_cpu_idaklu_jvp_f64(void *out_tuple, const void **in) { + const std::int64_t index = *reinterpret_cast(in[0]); + idaklu_jax_instances[index]->cpu_idaklu_jvp(out_tuple, in); +} + +void wrap_cpu_idaklu_vjp_f64(void *out_tuple, const void **in) { + const std::int64_t index = *reinterpret_cast(in[0]); + idaklu_jax_instances[index]->cpu_idaklu_vjp(out_tuple, in); +} + +pybind11::dict Registrations() { + pybind11::dict dict; + dict["cpu_idaklu_f64"] = EncapsulateFunction(wrap_cpu_idaklu_eval_f64); + dict["cpu_idaklu_jvp_f64"] = EncapsulateFunction(wrap_cpu_idaklu_jvp_f64); + dict["cpu_idaklu_vjp_f64"] = EncapsulateFunction(wrap_cpu_idaklu_vjp_f64); + return dict; +} diff --git a/src/pybammsolvers/idaklu_source/IdakluJax.hpp b/src/pybammsolvers/idaklu_source/IdakluJax.hpp new file mode 100644 index 0000000..251567d --- /dev/null +++ b/src/pybammsolvers/idaklu_source/IdakluJax.hpp @@ -0,0 +1,117 @@ +#ifndef PYBAMM_IDAKLU_JAX_SOLVER_HPP +#define PYBAMM_IDAKLU_JAX_SOLVER_HPP + +#include "common.hpp" + +/** + * @brief Callback function type for JAX evaluation + */ +using CallbackEval = std::function; + +/** + * @brief Callback function type for JVP evaluation + */ +using CallbackJvp = std::function; + +/** + * @brief Callback function type for VJP evaluation + */ +using CallbackVjp = std::function; + +/** + * @brief IDAKLU-JAX interface class. + * + * This class provides an interface to the IDAKLU-JAX solver. It is called by the + * IDAKLUJax class in python and provides the lowering rules for the JAX evaluation, + * JVP and VJP primitive functions. Each of these make use of the IDAKLU solver via + * the IDAKLUSolver python class, so this IDAKLUJax class provides a wrapper which + * redirects calls back to the IDAKLUSolver via python callbacks. + */ +class IdakluJax { +private: + static std::int64_t universal_count; // Universal count for IdakluJax objects + std::int64_t index; // Instance index +public: + /** + * @brief Constructor + */ + IdakluJax(); + + /** + * @brief Destructor + */ + ~IdakluJax(); + + /** + * @brief Callback for JAX evaluation + */ + CallbackEval callback_eval; + + /** + * @brief Callback for JVP evaluation + */ + CallbackJvp callback_jvp; + + /** + * @brief Callback for VJP evaluation + */ + CallbackVjp callback_vjp; + + /** + * @brief Register callbacks for JAX evaluation, JVP and VJP + */ + void register_callback_eval(CallbackEval h); + + /** + * @brief Register callback for JAX evaluation + */ + void register_callback_jvp(CallbackJvp h); + + /** + * @brief Register callback for JVP evaluation + */ + void register_callback_vjp(CallbackVjp h); + + /** + * @brief Register callback for VJP evaluation + */ + void register_callbacks(CallbackEval h, CallbackJvp h_jvp, CallbackVjp h_vjp); + + /** + * @brief JAX evaluation primitive function + */ + void cpu_idaklu_eval(void *out_tuple, const void **in); + + /** + * @brief JVP primitive function + */ + void cpu_idaklu_jvp(void *out_tuple, const void **in); + + /** + * @brief VJP primitive function + */ + void cpu_idaklu_vjp(void *out_tuple, const void **in); + + /** + * @brief Get the instance index + */ + std::int64_t get_index() { return index; }; +}; + +/** + * @brief Non-member function to create a new IdakluJax object + */ +IdakluJax *create_idaklu_jax(); + +/** + * @brief (Non-member) encapsulation helper function + */ +template +pybind11::capsule EncapsulateFunction(T* fn); + +/** + * @brief (Non-member) function dictionary + */ +pybind11::dict Registrations(); + +#endif // PYBAMM_IDAKLU_JAX_SOLVER_HPP diff --git a/src/pybammsolvers/idaklu_source/Options.cpp b/src/pybammsolvers/idaklu_source/Options.cpp new file mode 100644 index 0000000..8eb605f --- /dev/null +++ b/src/pybammsolvers/idaklu_source/Options.cpp @@ -0,0 +1,165 @@ +#include "Options.hpp" +#include +#include +#include + + +using namespace std::string_literals; + +SetupOptions::SetupOptions(py::dict &py_opts) + : jacobian(py_opts["jacobian"].cast()), + preconditioner(py_opts["preconditioner"].cast()), + precon_half_bandwidth(py_opts["precon_half_bandwidth"].cast()), + precon_half_bandwidth_keep(py_opts["precon_half_bandwidth_keep"].cast()), + num_threads(py_opts["num_threads"].cast()), + num_solvers(py_opts["num_solvers"].cast()), + linear_solver(py_opts["linear_solver"].cast()), + linsol_max_iterations(py_opts["linsol_max_iterations"].cast()) +{ + if (num_solvers > num_threads) + { + throw std::domain_error( + "Number of solvers must be less than or equal to the number of threads" + ); + } + + // input num_threads is the overall number of threads to use. num_solvers of these + // will be used to run solvers in parallel, leaving num_threads / num_solvers threads + // to be used by each solver. From here on num_threads is the number of threads to be used by each solver + num_threads = static_cast( + std::floor( + static_cast(num_threads) / static_cast(num_solvers) + ) + ); + + using_sparse_matrix = true; + using_banded_matrix = false; + if (jacobian == "sparse") + { + } + else if (jacobian == "banded") { + using_banded_matrix = true; + using_sparse_matrix = false; + } + else if (jacobian == "dense" || jacobian == "none") + { + using_sparse_matrix = false; + } + else if (jacobian == "matrix-free") + { + } + else + { + throw std::domain_error( + "Unknown jacobian type \""s + jacobian + + "\". Should be one of \"sparse\", \"banded\", \"dense\", \"matrix-free\" or \"none\"."s + ); + } + + using_iterative_solver = false; + if (linear_solver == "SUNLinSol_Dense" && (jacobian == "dense" || jacobian == "none")) + { + } + else if (linear_solver == "SUNLinSol_KLU" && jacobian == "sparse") + { + } + else if (linear_solver == "SUNLinSol_cuSolverSp_batchQR" && jacobian == "sparse") + { + } + else if (linear_solver == "SUNLinSol_Band" && jacobian == "banded") + { + } + else if (jacobian == "banded") { + throw std::domain_error( + "Unknown linear solver or incompatible options: " + "jacobian = \"" + jacobian + "\" linear solver = \"" + linear_solver + + "\". For a banded jacobian " + "please use the SUNLinSol_Band linear solver" + ); + } + else if ((linear_solver == "SUNLinSol_SPBCGS" || + linear_solver == "SUNLinSol_SPFGMR" || + linear_solver == "SUNLinSol_SPGMR" || + linear_solver == "SUNLinSol_SPTFQMR") && + (jacobian == "sparse" || jacobian == "matrix-free")) + { + using_iterative_solver = true; + } + else if (jacobian == "sparse") + { + throw std::domain_error( + "Unknown linear solver or incompatible options: " + "jacobian = \"" + jacobian + "\" linear solver = \"" + linear_solver + + "\". For a sparse jacobian " + "please use the SUNLinSol_KLU linear solver" + ); + } + else if (jacobian == "matrix-free") + { + throw std::domain_error( + "Unknown linear solver or incompatible options. " + "jacobian = \"" + jacobian + "\" linear solver = \"" + linear_solver + + "\". For a matrix-free jacobian " + "please use one of the iterative linear solvers: \"SUNLinSol_SPBCGS\", " + "\"SUNLinSol_SPFGMR\", \"SUNLinSol_SPGMR\", or \"SUNLinSol_SPTFQMR\"." + ); + } + else if (jacobian == "none") + { + throw std::domain_error( + "Unknown linear solver or incompatible options: " + "jacobian = \"" + jacobian + "\" linear solver = \"" + linear_solver + + "\". For no jacobian please use the SUNLinSol_Dense solver" + ); + } + else + { + throw std::domain_error( + "Unknown linear solver or incompatible options. " + "jacobian = \"" + jacobian + "\" linear solver = \"" + linear_solver + "\"" + ); + } + + if (using_iterative_solver) + { + if (preconditioner != "none" && preconditioner != "BBDP") + { + throw std::domain_error( + "Unknown preconditioner \""s + preconditioner + + "\", use one of \"BBDP\" or \"none\""s + ); + } + } + else + { + preconditioner = "none"; + } +} + +SolverOptions::SolverOptions(py::dict &py_opts) + : print_stats(py_opts["print_stats"].cast()), + // IDA main solver + max_order_bdf(py_opts["max_order_bdf"].cast()), + max_num_steps(py_opts["max_num_steps"].cast()), + dt_init(RCONST(py_opts["dt_init"].cast())), + dt_max(RCONST(py_opts["dt_max"].cast())), + max_error_test_failures(py_opts["max_error_test_failures"].cast()), + max_nonlinear_iterations(py_opts["max_nonlinear_iterations"].cast()), + max_convergence_failures(py_opts["max_convergence_failures"].cast()), + nonlinear_convergence_coefficient(RCONST(py_opts["nonlinear_convergence_coefficient"].cast())), + nonlinear_convergence_coefficient_ic(RCONST(py_opts["nonlinear_convergence_coefficient_ic"].cast())), + suppress_algebraic_error(py_opts["suppress_algebraic_error"].cast()), + hermite_interpolation(py_opts["hermite_interpolation"].cast()), + // IDA initial conditions calculation + calc_ic(py_opts["calc_ic"].cast()), + init_all_y_ic(py_opts["init_all_y_ic"].cast()), + max_num_steps_ic(py_opts["max_num_steps_ic"].cast()), + max_num_jacobians_ic(py_opts["max_num_jacobians_ic"].cast()), + max_num_iterations_ic(py_opts["max_num_iterations_ic"].cast()), + max_linesearch_backtracks_ic(py_opts["max_linesearch_backtracks_ic"].cast()), + linesearch_off_ic(py_opts["linesearch_off_ic"].cast()), + // IDALS linear solver interface + linear_solution_scaling(py_opts["linear_solution_scaling"].cast()), + epsilon_linear_tolerance(RCONST(py_opts["epsilon_linear_tolerance"].cast())), + increment_factor(RCONST(py_opts["increment_factor"].cast())) +{} diff --git a/src/pybammsolvers/idaklu_source/Options.hpp b/src/pybammsolvers/idaklu_source/Options.hpp new file mode 100644 index 0000000..7418c68 --- /dev/null +++ b/src/pybammsolvers/idaklu_source/Options.hpp @@ -0,0 +1,57 @@ +#ifndef PYBAMM_OPTIONS_HPP +#define PYBAMM_OPTIONS_HPP + +#include "common.hpp" + +/** + * @brief SetupOptions passed to the idaklu setup by pybamm + */ +struct SetupOptions { + bool using_sparse_matrix; + bool using_banded_matrix; + bool using_iterative_solver; + std::string jacobian; + std::string preconditioner; // spbcg + int precon_half_bandwidth; + int precon_half_bandwidth_keep; + int num_threads; + int num_solvers; + // IDALS linear solver interface + std::string linear_solver; // klu, lapack, spbcg + int linsol_max_iterations; + explicit SetupOptions(py::dict &py_opts); +}; + +/** + * @brief SolverOptions passed to the idaklu solver by pybamm + */ +struct SolverOptions { + bool print_stats; + // IDA main solver + int max_order_bdf; + int max_num_steps; + double dt_init; + double dt_max; + int max_error_test_failures; + int max_nonlinear_iterations; + int max_convergence_failures; + double nonlinear_convergence_coefficient; + double nonlinear_convergence_coefficient_ic; + sunbooleantype suppress_algebraic_error; + bool hermite_interpolation; + // IDA initial conditions calculation + bool calc_ic; + bool init_all_y_ic; + int max_num_steps_ic; + int max_num_jacobians_ic; + int max_num_iterations_ic; + int max_linesearch_backtracks_ic; + sunbooleantype linesearch_off_ic; + // IDALS linear solver interface + sunbooleantype linear_solution_scaling; + double epsilon_linear_tolerance; + double increment_factor; + explicit SolverOptions(py::dict &py_opts); +}; + +#endif // PYBAMM_OPTIONS_HPP diff --git a/src/pybammsolvers/idaklu_source/Solution.cpp b/src/pybammsolvers/idaklu_source/Solution.cpp new file mode 100644 index 0000000..7b50364 --- /dev/null +++ b/src/pybammsolvers/idaklu_source/Solution.cpp @@ -0,0 +1 @@ +#include "Solution.hpp" diff --git a/src/pybammsolvers/idaklu_source/Solution.hpp b/src/pybammsolvers/idaklu_source/Solution.hpp new file mode 100644 index 0000000..8227bb9 --- /dev/null +++ b/src/pybammsolvers/idaklu_source/Solution.hpp @@ -0,0 +1,39 @@ +#ifndef PYBAMM_IDAKLU_SOLUTION_HPP +#define PYBAMM_IDAKLU_SOLUTION_HPP + +#include "common.hpp" + +/** + * @brief Solution class + */ +class Solution +{ +public: + /** + * @brief Default Constructor + */ + Solution() = default; + + /** + * @brief Constructor + */ + Solution(int &retval, np_array &t_np, np_array &y_np, np_array &yp_np, np_array &yS_np, np_array &ypS_np, np_array &y_term_np) + : flag(retval), t(t_np), y(y_np), yp(yp_np), yS(yS_np), ypS(ypS_np), y_term(y_term_np) + { + } + + /** + * @brief Default copy from another Solution + */ + Solution(const Solution &solution) = default; + + int flag; + np_array t; + np_array y; + np_array yp; + np_array yS; + np_array ypS; + np_array y_term; +}; + +#endif // PYBAMM_IDAKLU_COMMON_HPP diff --git a/src/pybammsolvers/idaklu_source/SolutionData.cpp b/src/pybammsolvers/idaklu_source/SolutionData.cpp new file mode 100644 index 0000000..bc48c64 --- /dev/null +++ b/src/pybammsolvers/idaklu_source/SolutionData.cpp @@ -0,0 +1,99 @@ +#include "SolutionData.hpp" + +Solution SolutionData::generate_solution() { + py::capsule free_t_when_done( + t_return, + [](void *f) { + realtype *vect = reinterpret_cast(f); + delete[] vect; + } + ); + + np_array t_ret = np_array( + number_of_timesteps, + &t_return[0], + free_t_when_done + ); + + py::capsule free_y_when_done( + y_return, + [](void *f) { + realtype *vect = reinterpret_cast(f); + delete[] vect; + } + ); + + np_array y_ret = np_array( + number_of_timesteps * length_of_return_vector, + &y_return[0], + free_y_when_done + ); + + py::capsule free_yp_when_done( + yp_return, + [](void *f) { + realtype *vect = reinterpret_cast(f); + delete[] vect; + } + ); + + np_array yp_ret = np_array( + (save_hermite ? 1 : 0) * number_of_timesteps * length_of_return_vector, + &yp_return[0], + free_yp_when_done + ); + + py::capsule free_yS_when_done( + yS_return, + [](void *f) { + realtype *vect = reinterpret_cast(f); + delete[] vect; + } + ); + + np_array yS_ret = np_array( + std::vector { + arg_sens0, + arg_sens1, + arg_sens2 + }, + &yS_return[0], + free_yS_when_done + ); + + py::capsule free_ypS_when_done( + ypS_return, + [](void *f) { + realtype *vect = reinterpret_cast(f); + delete[] vect; + } + ); + + np_array ypS_ret = np_array( + std::vector { + (save_hermite ? 1 : 0) * arg_sens0, + arg_sens1, + arg_sens2 + }, + &ypS_return[0], + free_ypS_when_done + ); + + // Final state slice, yterm + py::capsule free_yterm_when_done( + yterm_return, + [](void *f) { + realtype *vect = reinterpret_cast(f); + delete[] vect; + } + ); + + np_array y_term = np_array( + length_of_final_sv_slice, + &yterm_return[0], + free_yterm_when_done + ); + + // Store the solution + return Solution(flag, t_ret, y_ret, yp_ret, yS_ret, ypS_ret, y_term); +} diff --git a/src/pybammsolvers/idaklu_source/SolutionData.hpp b/src/pybammsolvers/idaklu_source/SolutionData.hpp new file mode 100644 index 0000000..81ca7f5 --- /dev/null +++ b/src/pybammsolvers/idaklu_source/SolutionData.hpp @@ -0,0 +1,82 @@ +#ifndef PYBAMM_IDAKLU_SOLUTION_DATA_HPP +#define PYBAMM_IDAKLU_SOLUTION_DATA_HPP + + +#include "common.hpp" +#include "Solution.hpp" + +/** + * @brief SolutionData class. Contains all the data needed to create a Solution + */ +class SolutionData +{ + public: + /** + * @brief Default constructor + */ + SolutionData() = default; + + /** + * @brief constructor using fields + */ + SolutionData( + int flag, + int number_of_timesteps, + int length_of_return_vector, + int arg_sens0, + int arg_sens1, + int arg_sens2, + int length_of_final_sv_slice, + bool save_hermite, + realtype *t_return, + realtype *y_return, + realtype *yp_return, + realtype *yS_return, + realtype *ypS_return, + realtype *yterm_return): + flag(flag), + number_of_timesteps(number_of_timesteps), + length_of_return_vector(length_of_return_vector), + arg_sens0(arg_sens0), + arg_sens1(arg_sens1), + arg_sens2(arg_sens2), + length_of_final_sv_slice(length_of_final_sv_slice), + save_hermite(save_hermite), + t_return(t_return), + y_return(y_return), + yp_return(yp_return), + yS_return(yS_return), + ypS_return(ypS_return), + yterm_return(yterm_return) + {} + + + /** + * @brief Default copy from another SolutionData + */ + SolutionData(const SolutionData &solution_data) = default; + + /** + * @brief Create a solution object from this data + */ + Solution generate_solution(); + +private: + + int flag; + int number_of_timesteps; + int length_of_return_vector; + int arg_sens0; + int arg_sens1; + int arg_sens2; + int length_of_final_sv_slice; + bool save_hermite; + realtype *t_return; + realtype *y_return; + realtype *yp_return; + realtype *yS_return; + realtype *ypS_return; + realtype *yterm_return; +}; + +#endif // PYBAMM_IDAKLU_SOLUTION_DATA_HPP diff --git a/src/pybammsolvers/idaklu_source/common.cpp b/src/pybammsolvers/idaklu_source/common.cpp new file mode 100644 index 0000000..161c14f --- /dev/null +++ b/src/pybammsolvers/idaklu_source/common.cpp @@ -0,0 +1,19 @@ +#include "common.hpp" + +std::vector numpy2realtype(const np_array& input_np) { + std::vector output(input_np.request().size); + + auto const inputData = input_np.unchecked<1>(); + for (int i = 0; i < output.size(); i++) { + output[i] = inputData[i]; + } + + return output; +} + + + +std::vector makeSortedUnique(const np_array& input_np) { + const auto input_vec = numpy2realtype(input_np); + return makeSortedUnique(input_vec.begin(), input_vec.end()); +} diff --git a/src/pybammsolvers/idaklu_source/common.hpp b/src/pybammsolvers/idaklu_source/common.hpp new file mode 100644 index 0000000..9067208 --- /dev/null +++ b/src/pybammsolvers/idaklu_source/common.hpp @@ -0,0 +1,170 @@ +#ifndef PYBAMM_IDAKLU_COMMON_HPP +#define PYBAMM_IDAKLU_COMMON_HPP + +#include + +#include /* prototypes for IDAS fcts., consts. */ +#include /* access to IDABBDPRE preconditioner */ + +#include /* access to serial N_Vector */ +#include /* access to openmp N_Vector */ +#include /* defs. of SUNRabs, SUNRexp, etc. */ +#include /* defs. of SUNRabs, SUNRexp, etc. */ +#include /* defs. of realtype, sunindextype */ + +#if SUNDIALS_VERSION_MAJOR >= 6 + #include +#endif + +#include /* access to KLU linear solver */ +#include /* access to dense linear solver */ +#include /* access to dense linear solver */ +#include /* access to spbcgs iterative linear solver */ +#include +#include +#include + +#include /* access to sparse SUNMatrix */ +#include /* access to dense SUNMatrix */ + +#include +#include + +namespace py = pybind11; +// note: we rely on c_style ordering for numpy arrays so don't change this! +using np_array = py::array_t; +using np_array_realtype = py::array_t; +using np_array_int = py::array_t; + +/** + * Utility function to convert compressed-sparse-column (CSC) to/from + * compressed-sparse-row (CSR) matrix representation. Conversion is symmetric / + * invertible using this function. + * @brief Utility function to convert to/from CSC/CSR matrix representations. + * @param f Data vector containing the sparse matrix elements + * @param c Index pointer to column starts + * @param r Array of row indices + * @param nf New data vector that will contain the transformed sparse matrix + * @param nc New array of column indices + * @param nr New index pointer to row starts + */ +template +void csc_csr(const realtype f[], const T1 c[], const T1 r[], realtype nf[], T2 nc[], T2 nr[], int N, int cols) { + std::vector nn(cols+1); + std::vector rr(N); + for (int i=0; i + */ +std::vector numpy2realtype(const np_array& input_np); + +/** + * @brief Utility function to compute the set difference of two vectors + */ +template +std::vector setDiff(const T1 a_begin, const T1 a_end, const T2 b_begin, const T2 b_end) { + std::vector result; + if (std::distance(a_begin, a_end) > 0) { + std::set_difference(a_begin, a_end, b_begin, b_end, std::back_inserter(result)); + } + return result; +} + +/** + * @brief Utility function to make a sorted and unique vector + */ +template +std::vector makeSortedUnique(const T input_begin, const T input_end) { + std::unordered_set uniqueSet(input_begin, input_end); // Remove duplicates + std::vector uniqueVector(uniqueSet.begin(), uniqueSet.end()); // Convert to vector + std::sort(uniqueVector.begin(), uniqueVector.end()); // Sort the vector + return uniqueVector; +} + +std::vector makeSortedUnique(const np_array& input_np); + +#ifdef NDEBUG +#define DEBUG_VECTOR(vector) +#define DEBUG_VECTORn(vector, N) +#define DEBUG_v(v, N) +#define DEBUG(x) +#define DEBUG_n(x) +#define ASSERT(x) +#else + +#define DEBUG_VECTORn(vector, L) {\ + auto M = N_VGetLength(vector); \ + auto N = (M < L) ? M : L; \ + std::cout << #vector << "[" << N << " of " << M << "] = ["; \ + auto array_ptr = N_VGetArrayPointer(vector); \ + for (int i = 0; i < N; i++) { \ + std::cout << array_ptr[i]; \ + if (i < N-1) { \ + std::cout << ", "; \ + } \ + } \ + std::cout << "]" << std::endl; } + +#define DEBUG_VECTOR(vector) {\ + std::cout << #vector << " = ["; \ + auto array_ptr = N_VGetArrayPointer(vector); \ + auto N = N_VGetLength(vector); \ + for (int i = 0; i < N; i++) { \ + std::cout << array_ptr[i]; \ + if (i < N-1) { \ + std::cout << ", "; \ + } \ + } \ + std::cout << "]" << std::endl; } + +#define DEBUG_v(v, N) {\ + std::cout << #v << "[n=" << N << "] = ["; \ + for (int i = 0; i < N; i++) { \ + std::cout << v[i]; \ + if (i < N-1) { \ + std::cout << ", "; \ + } \ + } \ + std::cout << "]" << std::endl; } + +#define DEBUG(x) { \ + std::cerr << __FILE__ << ":" << __LINE__ << " " << x << std::endl; \ + } + +#define DEBUG_n(x) { \ + std::cerr << __FILE__ << ":" << __LINE__ << "," << #x << " = " << x << std::endl; \ + } + +#define ASSERT(x) { \ + if (!(x)) { \ + std::cerr << __FILE__ << ":" << __LINE__ << " Assertion failed: " << #x << std::endl; \ + throw std::runtime_error("Assertion failed: " #x); \ + } \ + } + +#endif + +#endif // PYBAMM_IDAKLU_COMMON_HPP diff --git a/src/pybammsolvers/idaklu_source/idaklu_solver.hpp b/src/pybammsolvers/idaklu_source/idaklu_solver.hpp new file mode 100644 index 0000000..dcc1e4f --- /dev/null +++ b/src/pybammsolvers/idaklu_source/idaklu_solver.hpp @@ -0,0 +1,246 @@ +#ifndef PYBAMM_CREATE_IDAKLU_SOLVER_HPP +#define PYBAMM_CREATE_IDAKLU_SOLVER_HPP + +#include "IDAKLUSolverOpenMP_solvers.hpp" +#include "IDAKLUSolverGroup.hpp" +#include +#include + +/** + * Creates a concrete solver given a linear solver, as specified in + * options_cpp.linear_solver. + * @brief Create a concrete solver given a linear solver + */ +template +IDAKLUSolver *create_idaklu_solver( + std::unique_ptr functions, + int number_of_parameters, + const np_array_int &jac_times_cjmass_colptrs, + const np_array_int &jac_times_cjmass_rowvals, + const int jac_times_cjmass_nnz, + const int jac_bandwidth_lower, + const int jac_bandwidth_upper, + const int number_of_events, + np_array rhs_alg_id, + np_array atol_np, + double rel_tol, + int inputs_length, + SolverOptions solver_opts, + SetupOptions setup_opts +) { + + IDAKLUSolver *idakluSolver = nullptr; + + // Instantiate solver class + if (setup_opts.linear_solver == "SUNLinSol_Dense") + { + DEBUG("\tsetting SUNLinSol_Dense linear solver"); + idakluSolver = new IDAKLUSolverOpenMP_Dense( + atol_np, + rel_tol, + rhs_alg_id, + number_of_parameters, + number_of_events, + jac_times_cjmass_nnz, + jac_bandwidth_lower, + jac_bandwidth_upper, + std::move(functions), + setup_opts, + solver_opts + ); + } + else if (setup_opts.linear_solver == "SUNLinSol_KLU") + { + DEBUG("\tsetting SUNLinSol_KLU linear solver"); + idakluSolver = new IDAKLUSolverOpenMP_KLU( + atol_np, + rel_tol, + rhs_alg_id, + number_of_parameters, + number_of_events, + jac_times_cjmass_nnz, + jac_bandwidth_lower, + jac_bandwidth_upper, + std::move(functions), + setup_opts, + solver_opts + ); + } + else if (setup_opts.linear_solver == "SUNLinSol_Band") + { + DEBUG("\tsetting SUNLinSol_Band linear solver"); + idakluSolver = new IDAKLUSolverOpenMP_Band( + atol_np, + rel_tol, + rhs_alg_id, + number_of_parameters, + number_of_events, + jac_times_cjmass_nnz, + jac_bandwidth_lower, + jac_bandwidth_upper, + std::move(functions), + setup_opts, + solver_opts + ); + } + else if (setup_opts.linear_solver == "SUNLinSol_SPBCGS") + { + DEBUG("\tsetting SUNLinSol_SPBCGS_linear solver"); + idakluSolver = new IDAKLUSolverOpenMP_SPBCGS( + atol_np, + rel_tol, + rhs_alg_id, + number_of_parameters, + number_of_events, + jac_times_cjmass_nnz, + jac_bandwidth_lower, + jac_bandwidth_upper, + std::move(functions), + setup_opts, + solver_opts + ); + } + else if (setup_opts.linear_solver == "SUNLinSol_SPFGMR") + { + DEBUG("\tsetting SUNLinSol_SPFGMR_linear solver"); + idakluSolver = new IDAKLUSolverOpenMP_SPFGMR( + atol_np, + rel_tol, + rhs_alg_id, + number_of_parameters, + number_of_events, + jac_times_cjmass_nnz, + jac_bandwidth_lower, + jac_bandwidth_upper, + std::move(functions), + setup_opts, + solver_opts + ); + } + else if (setup_opts.linear_solver == "SUNLinSol_SPGMR") + { + DEBUG("\tsetting SUNLinSol_SPGMR solver"); + idakluSolver = new IDAKLUSolverOpenMP_SPGMR( + atol_np, + rel_tol, + rhs_alg_id, + number_of_parameters, + number_of_events, + jac_times_cjmass_nnz, + jac_bandwidth_lower, + jac_bandwidth_upper, + std::move(functions), + setup_opts, + solver_opts + ); + } + else if (setup_opts.linear_solver == "SUNLinSol_SPTFQMR") + { + DEBUG("\tsetting SUNLinSol_SPGMR solver"); + idakluSolver = new IDAKLUSolverOpenMP_SPTFQMR( + atol_np, + rel_tol, + rhs_alg_id, + number_of_parameters, + number_of_events, + jac_times_cjmass_nnz, + jac_bandwidth_lower, + jac_bandwidth_upper, + std::move(functions), + setup_opts, + solver_opts + ); + } + + if (idakluSolver == nullptr) { + throw std::invalid_argument("Unsupported solver requested"); + } + + return idakluSolver; +} + +/** + * @brief Create a group of solvers using create_idaklu_solver + */ +template +IDAKLUSolverGroup *create_idaklu_solver_group( + int number_of_states, + int number_of_parameters, + const typename ExprSet::BaseFunctionType &rhs_alg, + const typename ExprSet::BaseFunctionType &jac_times_cjmass, + const np_array_int &jac_times_cjmass_colptrs, + const np_array_int &jac_times_cjmass_rowvals, + const int jac_times_cjmass_nnz, + const int jac_bandwidth_lower, + const int jac_bandwidth_upper, + const typename ExprSet::BaseFunctionType &jac_action, + const typename ExprSet::BaseFunctionType &mass_action, + const typename ExprSet::BaseFunctionType &sens, + const typename ExprSet::BaseFunctionType &events, + const int number_of_events, + np_array rhs_alg_id, + np_array atol_np, + double rel_tol, + int inputs_length, + const std::vector& var_fcns, + const std::vector& dvar_dy_fcns, + const std::vector& dvar_dp_fcns, + py::dict py_opts +) { + auto setup_opts = SetupOptions(py_opts); + auto solver_opts = SolverOptions(py_opts); + + + std::vector> solvers; + for (int i = 0; i < setup_opts.num_solvers; i++) { + // Note: we can't copy an ExprSet as it contains raw pointers to the functions + // So we create it in the loop + auto functions = std::make_unique( + rhs_alg, + jac_times_cjmass, + jac_times_cjmass_nnz, + jac_bandwidth_lower, + jac_bandwidth_upper, + jac_times_cjmass_rowvals, + jac_times_cjmass_colptrs, + inputs_length, + jac_action, + mass_action, + sens, + events, + number_of_states, + number_of_events, + number_of_parameters, + var_fcns, + dvar_dy_fcns, + dvar_dp_fcns, + setup_opts + ); + solvers.emplace_back( + std::unique_ptr( + create_idaklu_solver( + std::move(functions), + number_of_parameters, + jac_times_cjmass_colptrs, + jac_times_cjmass_rowvals, + jac_times_cjmass_nnz, + jac_bandwidth_lower, + jac_bandwidth_upper, + number_of_events, + rhs_alg_id, + atol_np, + rel_tol, + inputs_length, + solver_opts, + setup_opts + ) + ) + ); + } + + return new IDAKLUSolverGroup(std::move(solvers), number_of_states, number_of_parameters); +} + + + +#endif // PYBAMM_CREATE_IDAKLU_SOLVER_HPP diff --git a/src/pybammsolvers/idaklu_source/observe.cpp b/src/pybammsolvers/idaklu_source/observe.cpp new file mode 100644 index 0000000..8f1d90e --- /dev/null +++ b/src/pybammsolvers/idaklu_source/observe.cpp @@ -0,0 +1,343 @@ +#include "observe.hpp" + +int _setup_len_spatial(const std::vector& shape) { + // Calculate the product of all dimensions except the last (spatial dimensions) + int size_spatial = 1; + for (size_t i = 0; i < shape.size() - 1; ++i) { + size_spatial *= shape[i]; + } + + if (size_spatial == 0 || shape.back() == 0) { + throw std::invalid_argument("output array must have at least one element"); + } + + return size_spatial; +} + +// Coupled observe and Hermite interpolation of variables +class HermiteInterpolator { +public: + HermiteInterpolator(const py::detail::unchecked_reference& t, + const py::detail::unchecked_reference& y, + const py::detail::unchecked_reference& yp) + : t(t), y(y), yp(yp) {} + + void compute_knots(const size_t j, vector& c, vector& d) const { + // Called at the start of each interval + const realtype h_full = t(j + 1) - t(j); + const realtype inv_h = 1.0 / h_full; + const realtype inv_h2 = inv_h * inv_h; + const realtype inv_h3 = inv_h2 * inv_h; + + for (size_t i = 0; i < y.shape(0); ++i) { + realtype y_ij = y(i, j); + realtype yp_ij = yp(i, j); + realtype y_ijp1 = y(i, j + 1); + realtype yp_ijp1 = yp(i, j + 1); + + c[i] = 3.0 * (y_ijp1 - y_ij) * inv_h2 - (2.0 * yp_ij + yp_ijp1) * inv_h; + d[i] = 2.0 * (y_ij - y_ijp1) * inv_h3 + (yp_ij + yp_ijp1) * inv_h2; + } + } + + void interpolate(vector& entries, + realtype t_interp, + const size_t j, + vector& c, + vector& d) const { + // Must be called after compute_knots + const realtype h = t_interp - t(j); + const realtype h2 = h * h; + const realtype h3 = h2 * h; + + for (size_t i = 0; i < entries.size(); ++i) { + realtype y_ij = y(i, j); + realtype yp_ij = yp(i, j); + entries[i] = y_ij + yp_ij * h + c[i] * h2 + d[i] * h3; + } + } + +private: + const py::detail::unchecked_reference& t; + const py::detail::unchecked_reference& y; + const py::detail::unchecked_reference& yp; +}; + +class TimeSeriesInterpolator { +public: + TimeSeriesInterpolator(const np_array_realtype& _t_interp, + const vector& _ts_data, + const vector& _ys_data, + const vector& _yps_data, + const vector& _inputs, + const vector>& _funcs, + realtype* _entries, + const int _size_spatial) + : t_interp_np(_t_interp), ts_data_np(_ts_data), ys_data_np(_ys_data), + yps_data_np(_yps_data), inputs_np(_inputs), funcs(_funcs), + entries(_entries), size_spatial(_size_spatial) {} + + void process() { + auto t_interp = t_interp_np.unchecked<1>(); + ssize_t i_interp = 0; + int i_entries = 0; + const ssize_t N_interp = t_interp.size(); + + // Main processing within bounds + process_within_bounds(i_interp, i_entries, t_interp, N_interp); + + // Extrapolation for remaining points + if (i_interp < N_interp) { + extrapolate_remaining(i_interp, i_entries, t_interp, N_interp); + } + } + + void process_within_bounds( + ssize_t& i_interp, + int& i_entries, + const py::detail::unchecked_reference& t_interp, + const ssize_t N_interp + ) { + for (size_t i = 0; i < ts_data_np.size(); i++) { + const auto& t_data = ts_data_np[i].unchecked<1>(); + const realtype t_data_final = t_data(t_data.size() - 1); + realtype t_interp_next = t_interp(i_interp); + // Continue if the next interpolation point is beyond the final data point + if (t_interp_next > t_data_final) { + continue; + } + + const auto& y_data = ys_data_np[i].unchecked<2>(); + const auto& yp_data = yps_data_np[i].unchecked<2>(); + const auto input = inputs_np[i].data(); + const auto func = *funcs[i]; + + resize_arrays(y_data.shape(0), funcs[i]); + args[1] = y_buffer.data(); + args[2] = input; + + ssize_t j = 0; + ssize_t j_prev = -1; + const auto itp = HermiteInterpolator(t_data, y_data, yp_data); + while (t_interp_next <= t_data_final) { + for (; j < t_data.size() - 2; ++j) { + if (t_data(j) <= t_interp_next && t_interp_next <= t_data(j + 1)) { + break; + } + } + + if (j != j_prev) { + // Compute c and d for the new interval + itp.compute_knots(j, c, d); + } + + itp.interpolate(y_buffer, t_interp(i_interp), j, c, d); + + args[0] = &t_interp(i_interp); + results[0] = &entries[i_entries]; + func(args.data(), results.data(), iw.data(), w.data(), 0); + + ++i_interp; + if (i_interp == N_interp) { + return; + } + t_interp_next = t_interp(i_interp); + i_entries += size_spatial; + j_prev = j; + } + } + } + + void extrapolate_remaining( + ssize_t& i_interp, + int& i_entries, + const py::detail::unchecked_reference& t_interp, + const ssize_t N_interp + ) { + const auto& t_data = ts_data_np.back().unchecked<1>(); + const auto& y_data = ys_data_np.back().unchecked<2>(); + const auto& yp_data = yps_data_np.back().unchecked<2>(); + const auto input = inputs_np.back().data(); + const auto func = *funcs.back(); + const ssize_t j = t_data.size() - 2; + + resize_arrays(y_data.shape(0), funcs.back()); + args[1] = y_buffer.data(); + args[2] = input; + + const auto itp = HermiteInterpolator(t_data, y_data, yp_data); + itp.compute_knots(j, c, d); + + for (; i_interp < N_interp; ++i_interp) { + const realtype t_interp_next = t_interp(i_interp); + itp.interpolate(y_buffer, t_interp_next, j, c, d); + + args[0] = &t_interp_next; + results[0] = &entries[i_entries]; + func(args.data(), results.data(), iw.data(), w.data(), 0); + + i_entries += size_spatial; + } + } + + void resize_arrays(const int M, std::shared_ptr func) { + args.resize(func->sz_arg()); + results.resize(func->sz_res()); + iw.resize(func->sz_iw()); + w.resize(func->sz_w()); + if (y_buffer.size() < M) { + y_buffer.resize(M); + c.resize(M); + d.resize(M); + } + } + +private: + const np_array_realtype& t_interp_np; + const vector& ts_data_np; + const vector& ys_data_np; + const vector& yps_data_np; + const vector& inputs_np; + const vector>& funcs; + realtype* entries; + const int size_spatial; + vector c; + vector d; + vector y_buffer; + vector args; + vector results; + vector iw; + vector w; +}; + +// Observe the raw data +class TimeSeriesProcessor { +public: + TimeSeriesProcessor(const vector& _ts, + const vector& _ys, + const vector& _inputs, + const vector>& _funcs, + realtype* _entries, + const bool _is_f_contiguous, + const int _size_spatial) + : ts(_ts), ys(_ys), inputs(_inputs), funcs(_funcs), + entries(_entries), is_f_contiguous(_is_f_contiguous), size_spatial(_size_spatial) {} + + void process() { + int i_entries = 0; + for (size_t i = 0; i < ts.size(); i++) { + const auto& t = ts[i].unchecked<1>(); + const auto& y = ys[i].unchecked<2>(); + const auto input = inputs[i].data(); + const auto func = *funcs[i]; + + resize_arrays(y.shape(0), funcs[i]); + args[2] = input; + + for (size_t j = 0; j < t.size(); j++) { + const realtype t_val = t(j); + const realtype* y_val = is_f_contiguous ? &y(0, j) : copy_to_buffer(y_buffer, y, j); + + args[0] = &t_val; + args[1] = y_val; + results[0] = &entries[i_entries]; + + func(args.data(), results.data(), iw.data(), w.data(), 0); + + i_entries += size_spatial; + } + } + } + +private: + const realtype* copy_to_buffer( + vector& entries, + const py::detail::unchecked_reference& y, + size_t j) { + for (size_t i = 0; i < entries.size(); ++i) { + entries[i] = y(i, j); + } + + return entries.data(); + } + + void resize_arrays(const int M, std::shared_ptr func) { + args.resize(func->sz_arg()); + results.resize(func->sz_res()); + iw.resize(func->sz_iw()); + w.resize(func->sz_w()); + if (!is_f_contiguous && y_buffer.size() < M) { + y_buffer.resize(M); + } + } + + const vector& ts; + const vector& ys; + const vector& inputs; + const vector>& funcs; + realtype* entries; + const bool is_f_contiguous; + int size_spatial; + vector y_buffer; + vector args; + vector results; + vector iw; + vector w; +}; + +const np_array_realtype observe_hermite_interp( + const np_array_realtype& t_interp_np, + const vector& ts_np, + const vector& ys_np, + const vector& yps_np, + const vector& inputs_np, + const vector& strings, + const vector& shape +) { + const int size_spatial = _setup_len_spatial(shape); + const auto& funcs = setup_casadi_funcs(strings); + py::array_t out_array(shape); + auto entries = out_array.mutable_data(); + + TimeSeriesInterpolator(t_interp_np, ts_np, ys_np, yps_np, inputs_np, funcs, entries, size_spatial).process(); + + return out_array; +} + +const np_array_realtype observe( + const vector& ts_np, + const vector& ys_np, + const vector& inputs_np, + const vector& strings, + const bool is_f_contiguous, + const vector& shape +) { + const int size_spatial = _setup_len_spatial(shape); + const auto& funcs = setup_casadi_funcs(strings); + py::array_t out_array(shape); + auto entries = out_array.mutable_data(); + + TimeSeriesProcessor(ts_np, ys_np, inputs_np, funcs, entries, is_f_contiguous, size_spatial).process(); + + return out_array; +} + +const vector> setup_casadi_funcs(const vector& strings) { + std::unordered_map> function_cache; + vector> funcs(strings.size()); + + for (size_t i = 0; i < strings.size(); ++i) { + const std::string& str = strings[i]; + + // Check if function is already in the local cache + if (function_cache.find(str) == function_cache.end()) { + // If not in the cache, create a new casadi::Function::deserialize and store it + function_cache[str] = std::make_shared(casadi::Function::deserialize(str)); + } + + // Retrieve the function from the cache as a shared pointer + funcs[i] = function_cache[str]; + } + + return funcs; +} diff --git a/src/pybammsolvers/idaklu_source/observe.hpp b/src/pybammsolvers/idaklu_source/observe.hpp new file mode 100644 index 0000000..52a0cfd --- /dev/null +++ b/src/pybammsolvers/idaklu_source/observe.hpp @@ -0,0 +1,47 @@ +#ifndef PYBAMM_CREATE_OBSERVE_HPP +#define PYBAMM_CREATE_OBSERVE_HPP + +#include +#include +#include +#include "common.hpp" +#include +#include +using std::vector; + +#if defined(_MSC_VER) + #include + typedef SSIZE_T ssize_t; +#endif + +/** + * @brief Observe and Hermite interpolate ND variables + */ +const np_array_realtype observe_hermite_interp( + const np_array_realtype& t_interp, + const vector& ts, + const vector& ys, + const vector& yps, + const vector& inputs, + const vector& strings, + const vector& shape +); + + +/** + * @brief Observe ND variables + */ +const np_array_realtype observe( + const vector& ts_np, + const vector& ys_np, + const vector& inputs_np, + const vector& strings, + const bool is_f_contiguous, + const vector& shape +); + +const vector> setup_casadi_funcs(const vector& strings); + +int _setup_len_spatial(const vector& shape); + +#endif // PYBAMM_CREATE_OBSERVE_HPP diff --git a/src/pybammsolvers/idaklu_source/sundials_functions.hpp b/src/pybammsolvers/idaklu_source/sundials_functions.hpp new file mode 100644 index 0000000..c4024bc --- /dev/null +++ b/src/pybammsolvers/idaklu_source/sundials_functions.hpp @@ -0,0 +1,36 @@ +#ifndef PYBAMM_SUNDIALS_FUNCTIONS_HPP +#define PYBAMM_SUNDIALS_FUNCTIONS_HPP + +#include "common.hpp" + +template +void axpy(int n, T alpha, const T* x, T* y) { + if (!x || !y) return; + for (int i=0; i + +#define NV_DATA NV_DATA_OMP // Serial: NV_DATA_S + +template +int residual_eval(realtype tres, N_Vector yy, N_Vector yp, N_Vector rr, void *user_data) +{ + DEBUG("residual_eval"); + ExpressionSet *p_python_functions = + static_cast *>(user_data); + + DEBUG_VECTORn(yy, 100); + DEBUG_VECTORn(yp, 100); + + p_python_functions->rhs_alg->m_arg[0] = &tres; + p_python_functions->rhs_alg->m_arg[1] = NV_DATA(yy); + p_python_functions->rhs_alg->m_arg[2] = p_python_functions->inputs.data(); + p_python_functions->rhs_alg->m_res[0] = NV_DATA(rr); + (*p_python_functions->rhs_alg)(); + + DEBUG_VECTORn(rr, 100); + + realtype *tmp = p_python_functions->get_tmp_state_vector(); + p_python_functions->mass_action->m_arg[0] = NV_DATA(yp); + p_python_functions->mass_action->m_res[0] = tmp; + (*p_python_functions->mass_action)(); + + // AXPY: y <- a*x + y + const int ns = p_python_functions->number_of_states; + axpy(ns, -1., tmp, NV_DATA(rr)); + + DEBUG("mass - rhs"); + DEBUG_VECTORn(rr, 100); + + // now rr has rhs_alg(t, y) - mass_matrix * yp + return 0; +} + +// This Gres function computes G(t, y, yp). It loads the vector gval as a +// function of tt, yy, and yp. +// +// Arguments: +// Nlocal – is the local vector length. +// +// tt – is the value of the independent variable. +// +// yy – is the dependent variable. +// +// yp – is the derivative of the dependent variable. +// +// gval – is the output vector. +// +// user_data – is a pointer to user data, the same as the user_data parameter +// passed to IDASetUserData(). +// +// Return value: +// +// An IDABBDLocalFn function type should return 0 to indicate success, 1 for a +// recoverable error, or -1 for a non-recoverable error. +// +// Notes: +// +// This function must assume that all inter-processor communication of data +// needed to calculate gval has already been done, and this data is accessible +// within user_data. +// +// The case where G is mathematically identical to F is allowed. +template +int residual_eval_approx(sunindextype Nlocal, realtype tt, N_Vector yy, + N_Vector yp, N_Vector gval, void *user_data) +{ + DEBUG("residual_eval_approx"); + + // Just use true residual for now + int result = residual_eval(tt, yy, yp, gval, user_data); + return result; +} + +// Purpose This function computes the product Jv of the DAE system Jacobian J +// (or an approximation to it) and a given vector v, where J is defined by Eq. +// (2.6). +// J = ∂F/∂y + cj ∂F/∂y˙ +// Arguments tt is the current value of the independent variable. +// yy is the current value of the dependent variable vector, y(t). +// yp is the current value of ˙y(t). +// rr is the current value of the residual vector F(t, y, y˙). +// v is the vector by which the Jacobian must be multiplied to the right. +// Jv is the computed output vector. +// cj is the scalar in the system Jacobian, proportional to the inverse of +// the step +// size (α in Eq. (2.6) ). +// user data is a pointer to user data, the same as the user data parameter +// passed to +// IDASetUserData. +// tmp1 +// tmp2 are pointers to memory allocated for variables of type N Vector +// which can +// be used by IDALsJacTimesVecFn as temporary storage or work space. +template +int jtimes_eval(realtype tt, N_Vector yy, N_Vector yp, N_Vector rr, + N_Vector v, N_Vector Jv, realtype cj, void *user_data, + N_Vector tmp1, N_Vector tmp2) +{ + DEBUG("jtimes_eval"); + T *p_python_functions = + static_cast(user_data); + + // Jv has ∂F/∂y v + p_python_functions->jac_action->m_arg[0] = &tt; + p_python_functions->jac_action->m_arg[1] = NV_DATA(yy); + p_python_functions->jac_action->m_arg[2] = p_python_functions->inputs.data(); + p_python_functions->jac_action->m_arg[3] = NV_DATA(v); + p_python_functions->jac_action->m_res[0] = NV_DATA(Jv); + (*p_python_functions->jac_action)(); + + // tmp has -∂F/∂y˙ v + realtype *tmp = p_python_functions->get_tmp_state_vector(); + p_python_functions->mass_action->m_arg[0] = NV_DATA(v); + p_python_functions->mass_action->m_res[0] = tmp; + (*p_python_functions->mass_action)(); + + // AXPY: y <- a*x + y + // Jv has ∂F/∂y v + cj ∂F/∂y˙ v + const int ns = p_python_functions->number_of_states; + axpy(ns, -cj, tmp, NV_DATA(Jv)); + + DEBUG_VECTORn(Jv, 10); + + return 0; +} + +// Arguments tt is the current value of the independent variable t. +// cj is the scalar in the system Jacobian, proportional to the inverse of the +// step +// size (α in Eq. (2.6) ). +// yy is the current value of the dependent variable vector, y(t). +// yp is the current value of ˙y(t). +// rr is the current value of the residual vector F(t, y, y˙). +// Jac is the output (approximate) Jacobian matrix (of type SUNMatrix), J = +// ∂F/∂y + cj ∂F/∂y˙. +// user data is a pointer to user data, the same as the user data parameter +// passed to +// IDASetUserData. +// tmp1 +// tmp2 +// tmp3 are pointers to memory allocated for variables of type N Vector which +// can +// be used by IDALsJacFn function as temporary storage or work space. +template +int jacobian_eval(realtype tt, realtype cj, N_Vector yy, N_Vector yp, + N_Vector resvec, SUNMatrix JJ, void *user_data, + N_Vector tempv1, N_Vector tempv2, N_Vector tempv3) +{ + DEBUG("jacobian_eval"); + + T *p_python_functions = + static_cast(user_data); + + // create pointer to jac data, column pointers, and row values + realtype *jac_data; + if (p_python_functions->setup_opts.using_sparse_matrix) + { + jac_data = SUNSparseMatrix_Data(JJ); + } + else if (p_python_functions->setup_opts.using_banded_matrix) { + jac_data = p_python_functions->get_tmp_sparse_jacobian_data(); + } + else + { + jac_data = SUNDenseMatrix_Data(JJ); + } + + DEBUG_VECTORn(yy, 100); + + // args are t, y, cj, put result in jacobian data matrix + p_python_functions->jac_times_cjmass->m_arg[0] = &tt; + p_python_functions->jac_times_cjmass->m_arg[1] = NV_DATA(yy); + p_python_functions->jac_times_cjmass->m_arg[2] = + p_python_functions->inputs.data(); + p_python_functions->jac_times_cjmass->m_arg[3] = &cj; + p_python_functions->jac_times_cjmass->m_res[0] = jac_data; + (*p_python_functions->jac_times_cjmass)(); + + DEBUG("jac_times_cjmass [" << sizeof(jac_data) << "]"); + DEBUG("t = " << tt); + DEBUG_VECTORn(yy, 100); + DEBUG("inputs = " << p_python_functions->inputs); + DEBUG("cj = " << cj); + DEBUG_v(jac_data, 100); + + if (p_python_functions->setup_opts.using_banded_matrix) + { + // copy data from temporary matrix to the banded matrix + auto jac_colptrs = p_python_functions->jac_times_cjmass_colptrs.data(); + auto jac_rowvals = p_python_functions->jac_times_cjmass_rowvals.data(); + int ncols = p_python_functions->number_of_states; + for (int col_ij = 0; col_ij < ncols; col_ij++) { + realtype *banded_col = SM_COLUMN_B(JJ, col_ij); + for (auto data_i = jac_colptrs[col_ij]; data_i < jac_colptrs[col_ij+1]; data_i++) { + auto row_ij = jac_rowvals[data_i]; + const realtype value_ij = jac_data[data_i]; + DEBUG("(" << row_ij << ", " << col_ij << ") = " << value_ij); + SM_COLUMN_ELEMENT_B(banded_col, row_ij, col_ij) = value_ij; + } + } + } + else if (p_python_functions->setup_opts.using_sparse_matrix) + { + if (SUNSparseMatrix_SparseType(JJ) == CSC_MAT) + { + sunindextype *jac_colptrs = SUNSparseMatrix_IndexPointers(JJ); + sunindextype *jac_rowvals = SUNSparseMatrix_IndexValues(JJ); + // row vals and col ptrs + const int n_row_vals = p_python_functions->jac_times_cjmass_rowvals.size(); + auto p_jac_times_cjmass_rowvals = + p_python_functions->jac_times_cjmass_rowvals.data(); + + // just copy across row vals (do I need to do this every time?) + // (or just in the setup?) + for (int i = 0; i < n_row_vals; i++) + { + jac_rowvals[i] = p_jac_times_cjmass_rowvals[i]; + } + + const int n_col_ptrs = p_python_functions->jac_times_cjmass_colptrs.size(); + auto p_jac_times_cjmass_colptrs = + p_python_functions->jac_times_cjmass_colptrs.data(); + + // just copy across col ptrs (do I need to do this every time?) + for (int i = 0; i < n_col_ptrs; i++) + { + jac_colptrs[i] = p_jac_times_cjmass_colptrs[i]; + } + } else if (SUNSparseMatrix_SparseType(JJ) == CSR_MAT) { + // make a copy so that we can overwrite jac_data as CSR + std::vector newjac(&jac_data[0], &jac_data[SUNSparseMatrix_NNZ(JJ)]); + sunindextype *jac_ptrs = SUNSparseMatrix_IndexPointers(JJ); + sunindextype *jac_vals = SUNSparseMatrix_IndexValues(JJ); + + // convert CSC format to CSR + csc_csr< + std::remove_pointer_tjac_times_cjmass_rowvals.data())>, + std::remove_pointer_t + >( + newjac.data(), + p_python_functions->jac_times_cjmass_rowvals.data(), + p_python_functions->jac_times_cjmass_colptrs.data(), + jac_data, + jac_ptrs, + jac_vals, + SUNSparseMatrix_NNZ(JJ), + SUNSparseMatrix_NP(JJ) + ); + } else + throw std::runtime_error("Unknown matrix format detected (Expected CSC or CSR)"); + } + + return (0); +} + +template +int events_eval(realtype t, N_Vector yy, N_Vector yp, realtype *events_ptr, + void *user_data) +{ + DEBUG("events_eval"); + T *p_python_functions = + static_cast(user_data); + + // args are t, y, put result in events_ptr + p_python_functions->events->m_arg[0] = &t; + p_python_functions->events->m_arg[1] = NV_DATA(yy); + p_python_functions->events->m_arg[2] = p_python_functions->inputs.data(); + p_python_functions->events->m_res[0] = events_ptr; + (*p_python_functions->events)(); + + return (0); +} + +// This function computes the sensitivity residual for all sensitivity +// equations. It must compute the vectors +// (∂F/∂y)s i (t)+(∂F/∂ ẏ) ṡ i (t)+(∂F/∂p i ) and store them in resvalS[i]. +// Ns is the number of sensitivities. +// t is the current value of the independent variable. +// yy is the current value of the state vector, y(t). +// yp is the current value of ẏ(t). +// resval contains the current value F of the original DAE residual. +// yS contains the current values of the sensitivities s i . +// ypS contains the current values of the sensitivity derivatives ṡ i . +// resvalS contains the output sensitivity residual vectors. +// Memory allocation for resvalS is handled within idas. +// user data is a pointer to user data. +// tmp1, tmp2, tmp3 are N Vectors of length N which can be used as +// temporary storage. +// +// Return value An IDASensResFn should return 0 if successful, +// a positive value if a recoverable error +// occurred (in which case idas will attempt to correct), +// or a negative value if it failed unrecoverably (in which case the integration +// is halted and IDA SRES FAIL is returned) +template +int sensitivities_eval(int Ns, realtype t, N_Vector yy, N_Vector yp, + N_Vector resval, N_Vector *yS, N_Vector *ypS, + N_Vector *resvalS, void *user_data, N_Vector tmp1, + N_Vector tmp2, N_Vector tmp3) +{ + + DEBUG("sensitivities_eval"); + T *p_python_functions = + static_cast(user_data); + + const int np = p_python_functions->number_of_parameters; + + // args are t, y put result in rr + p_python_functions->sens->m_arg[0] = &t; + p_python_functions->sens->m_arg[1] = NV_DATA(yy); + p_python_functions->sens->m_arg[2] = p_python_functions->inputs.data(); + for (int i = 0; i < np; i++) + { + p_python_functions->sens->m_res[i] = NV_DATA(resvalS[i]); + } + // resvalsS now has (∂F/∂p i ) + (*p_python_functions->sens)(); + + for (int i = 0; i < np; i++) + { + // put (∂F/∂y)s i (t) in tmp + realtype *tmp = p_python_functions->get_tmp_state_vector(); + p_python_functions->jac_action->m_arg[0] = &t; + p_python_functions->jac_action->m_arg[1] = NV_DATA(yy); + p_python_functions->jac_action->m_arg[2] = p_python_functions->inputs.data(); + p_python_functions->jac_action->m_arg[3] = NV_DATA(yS[i]); + p_python_functions->jac_action->m_res[0] = tmp; + (*p_python_functions->jac_action)(); + + const int ns = p_python_functions->number_of_states; + axpy(ns, 1., tmp, NV_DATA(resvalS[i])); + + // put -(∂F/∂ ẏ) ṡ i (t) in tmp2 + p_python_functions->mass_action->m_arg[0] = NV_DATA(ypS[i]); + p_python_functions->mass_action->m_res[0] = tmp; + (*p_python_functions->mass_action)(); + + // (∂F/∂y)s i (t)+(∂F/∂ ẏ) ṡ i (t)+(∂F/∂p i ) + // AXPY: y <- a*x + y + axpy(ns, -1., tmp, NV_DATA(resvalS[i])); + } + + return 0; +} diff --git a/src/pybammsolvers/idaklu_source/sundials_legacy_wrapper.hpp b/src/pybammsolvers/idaklu_source/sundials_legacy_wrapper.hpp new file mode 100644 index 0000000..f4855b1 --- /dev/null +++ b/src/pybammsolvers/idaklu_source/sundials_legacy_wrapper.hpp @@ -0,0 +1,94 @@ + +#if SUNDIALS_VERSION_MAJOR < 6 + + #define SUN_PREC_NONE PREC_NONE + #define SUN_PREC_LEFT PREC_LEFT + + // Compatibility layer - wrap older sundials functions in new-style calls + void SUNContext_Create(void *comm, SUNContext *ctx) + { + // Function not available + return; + } + + int SUNContext_Free(SUNContext *ctx) + { + // Function not available + return; + } + + void* IDACreate(SUNContext sunctx) + { + return IDACreate(); + } + + N_Vector N_VNew_Serial(sunindextype vec_length, SUNContext sunctx) + { + return N_VNew_Serial(vec_length); + } + + N_Vector N_VNew_OpenMP(sunindextype vec_length, SUNContext sunctx) + { + return N_VNew_OpenMP(vec_length); + } + + N_Vector N_VNew_Cuda(sunindextype vec_length, SUNContext sunctx) + { + return N_VNew_Cuda(vec_length); + } + + SUNMatrix SUNSparseMatrix(sunindextype M, sunindextype N, sunindextype NNZ, int sparsetype, SUNContext sunctx) + { + return SUNMatrix SUNSparseMatrix(M, N, NNZ, sparsetype); + } + + SUNMatrix SUNMatrix_cuSparse_NewCSR(int M, int N, int NNZ, cusparseHandle_t cusp, SUNContext sunctx) + { + return SUNMatrix_cuSparse_NewCSR(M, N, NNZ, cusp); + } + + SUNMatrix SUNBandMatrix(sunindextype N, sunindextype mu, sunindextype ml, SUNContext sunctx) + { + return SUNMatrix SUNBandMatrix(N, mu, ml); + } + + SUNMatrix SUNDenseMatrix(sunindextype M, sunindextype N, SUNContext sunctx) + { + return SUNDenseMatrix(M, N, sunctx); + } + + SUNLinearSolver SUNLinSol_Dense(N_Vector y, SUNMatrix A, SUNContext sunctx) + { + return SUNLinSol_Dense(y, A, sunctx); + } + + SUNLinearSolver SUNLinSol_KLU(N_Vector y, SUNMatrix A, SUNContext sunctx) + { + return SUNLinSol_KLU(y, A, sunctx); + } + + SUNLinearSolver SUNLinSol_Band(N_Vector y, SUNMatrix A, SUNContext sunctx) + { + return SUNLinSol_Band(y, A, sunctx); + } + + SUNLinearSolver SUNLinSol_SPBCGS(N_Vector y, int pretype, int maxl, SUNContext sunctx) + { + return SUNLinSol_SPBCGS(y, pretype, maxl); + } + + SUNLinearSolver SUNLinSol_SPFGMR(N_Vector y, int pretype, int maxl, SUNContext sunctx) + { + return SUNLinSol_SPFGMR(y, pretype, maxl); + } + + SUNLinearSolver SUNLinSol_SPGMR(N_Vector y, int pretype, int maxl, SUNContext sunctx) + { + return SUNLinSol_SPGMR(y, pretype, maxl); + } + + SUNLinearSolver SUNLinSol_SPTFQMR(N_Vector y, int pretype, int maxl, SUNContext sunctx) + { + return SUNLinSol_SPTFQMR(y, pretype, maxl); + } +#endif diff --git a/src/pyidaklu/version.py b/src/pybammsolvers/version.py similarity index 100% rename from src/pyidaklu/version.py rename to src/pybammsolvers/version.py diff --git a/src/pyidaklu/__init__.py b/src/pyidaklu/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/tests/test_dummy.py b/tests/test_dummy.py deleted file mode 100644 index 5b99e6b..0000000 --- a/tests/test_dummy.py +++ /dev/null @@ -1,3 +0,0 @@ - -def tests_run(): - assert True diff --git a/tests/test_imports.py b/tests/test_imports.py new file mode 100644 index 0000000..1b34d08 --- /dev/null +++ b/tests/test_imports.py @@ -0,0 +1,5 @@ +import pybammsolvers + + +def test_import(): + assert hasattr(pybammsolvers, "idaklu") diff --git a/vcpkg-configuration.json b/vcpkg-configuration.json new file mode 100644 index 0000000..8d5c64a --- /dev/null +++ b/vcpkg-configuration.json @@ -0,0 +1,20 @@ +{ + "default-registry": { + "kind": "builtin", + "baseline": "" + }, + "registries": [ + { + "kind": "git", + "repository": "https://github.com/pybamm-team/sundials-vcpkg-registry.git", + "baseline": "13d432fcf5da8591bb6cb2d46be9d6acf39cd02b", + "packages": ["sundials"] + }, + { + "kind": "git", + "repository": "https://github.com/pybamm-team/casadi-vcpkg-registry.git", + "baseline": "1cb93f2fb71be26c874db724940ef8e604ee558e", + "packages": ["casadi"] + } + ] +} diff --git a/vcpkg.json b/vcpkg.json new file mode 100644 index 0000000..250c9c1 --- /dev/null +++ b/vcpkg.json @@ -0,0 +1,12 @@ +{ + "name": "pybammsolvers", + "version-string": "0.0.1", + "dependencies": [ + "casadi", + { + "name": "sundials", + "default-features": false, + "features": ["klu"] + } + ] +}