Skip to content

Commit

Permalink
Big Bang
Browse files Browse the repository at this point in the history
  • Loading branch information
Ceyron committed Feb 23, 2024
1 parent 8217b2a commit 00b94f4
Show file tree
Hide file tree
Showing 42 changed files with 4,969 additions and 2 deletions.
19 changes: 19 additions & 0 deletions .github/workflows/pre_commit.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
name: Code linting

on:
pull_request:

push:
branches:
- main

jobs:
pre-commit:
runs-on: ubuntu-latest

steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v4
with:
python-version: '3.12'
- uses: pre-commit/action@v3.0.0
160 changes: 160 additions & 0 deletions .gitignore copy
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
17 changes: 17 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
repos:
- repo: https://github.com/ambv/black
rev: 23.12.1
hooks:
- id: black-jupyter
language_version: python3

- repo: https://github.com/pycqa/isort
rev: 5.12.0
hooks:
- id: isort
args: ["--profile", "black", "--filter-files"]

- repo: https://github.com/pycqa/flake8
rev: 7.0.0
hooks:
- id: flake8
146 changes: 144 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,2 +1,144 @@
# exponax
the all new exponax in multiple dimensions with multiple channels
# Exponax

A suite of simple solvers for 1d PDEs on periodic domains based on exponential
time differencing algorithms, built on top of
[JAX](https://github.com/google/jax). **Efficient**, **Elegant**,
**Vectorizable**, and **Differentiable**.

### Quickstart - 1d Kuramoto-Sivashinsky equation

```python
import jax
import exponax as ex
import matplotlib.pyplot as plt

ks_stepper = ex.KuramotoSivashinskyConservative(
num_spatial_dims=1, domain_extent=100.0,
num_points=200, dt=0.1,
)

u_0 = ex.RandomTruncatedFourierSeries(
num_spatial_dims=1, cutoff=5
)(num_points=200, key=jax.random.PRNGKey(0))

trajectory = ex.rollout(ks_stepper, 500, include_init=True)(u_0)

plt.imshow(trajectory[:, 0, :].T, aspect='auto', cmap='RdBu', vmin=-2, vmax=2, origin="lower")
plt.xlabel("Time"); plt.ylabel("Space"); plt.show()
```

![](ks_rollout.png)

See also the *examples* folder for more. It is best to start with
`simple_advection_example.ipynb` to get familiar with the ideoms of the package,
especially if not too familiar with JAX. Then, continue with the
`solver_showcase.ipynb`. To see the solvers in action to solve a supervised
learning problem, see `learning_burgers_autoregressive_neural_operator.ipynb`. A
tutorial notebook that requires the differentiability of the solvers is in the
works.

### Features

Using JAX as the computational backend gives:

1. **Backend agnotistic code** - run on CPU, GPU, or TPU, in both single and double
precision.
2. **Automatic differentiation** over the timesteppers - compute gradients of
solutions with respect to initial conditions, parameters, etc.
3. Also helpful for **tight integration with Deep Learning** since each
timestepper is also just an [Equinox](https://github.com/patrick-kidger/equinox) Module.
4. **Automatic Vectorization** using `jax.vmap` (or `equinox.filter_vmap`)
allowing to advance multiple states in time or instantiate multiple solvers at a time that operate efficiently in batch.

Exponax strives to be lightweight and without custom types; there is no `grid` or `state` object. Everything is based on `jax.numpy` arrays.

### Background

Exponax supports the efficient solution of 1d (semi-linear) partial differential equations on periodic domains. Those are PDEs of the form

$$ \partial u/ \partial t = Lu + N(u) $$

where $L$ is a linear differential operator and $N$ is a nonlinear differential
operator. The linear part can be exactly solved using a (matrix) exponential,
and the nonlinear part is approximated using Runge-Kutta methods of various
orders. These methods have been known in various disciplines in science for a
long time and have been unified for a first time by [Cox & Matthews](https://doi.org/10.1006/jcph.2002.6995) [1]. In particular, this package uses the complex contour integral method of [Kassam & Trefethen](https://doi.org/10.1137/S1064827502410633) [2] for numerical stability. The package is restricted to original first, second, third and fourth order method. Since the package of [1] many extensions have been developed. A recent study by [Montanelli & Bootland](https://doi.org/10.1016/j.matcom.2020.06.008) [3] showed that the original *ETDRK4* method is still one of the most efficient methods for these types of PDEs.

### Built-In solvers

This package comes with the following solvers:

* Linear PDEs:
* Advection equation
* Diffusion equation
* Advection-Diffusion equation
* Dispersion equation
* Hyper-Diffusion equation
* General linear equation containing zeroth, first, second, third, and fourth order derivatives
* Nonlinear PDEs:
* Burgers equation
* Kuramoto-Sivashinsky equation
* Korteweg-de Vries equation

Other equations can easily be implemented by subclassing from the `BaseStepper`
module.

### Other functionality

Next to the timesteppers operating on JAX array states, it also comes with:

* Initial Conditions:
* Random sine waves
* Diffused Noise
* Random Discontinuities
* Gaussian Random Fields
* Utilities:
* Mesh creation
* Rollout functions
* Spectral derivatives
* Initial condition set creation
* Poisson solver
* Modification to make solvers take an additional forcing argument
* Modification to make solvers perform substeps for more accurate simulation

### Similar projects and motivation for this package

This package is greatly inspired by the [chebfun](https://www.chebfun.org/)
package in *MATLAB*, in particular the
[`spinX`](https://www.chebfun.org/docs/guide/guide19.html) module within it. It
has been used extensively as a data generator in early works for supervised
physics-informed ML, e.g., the
[DeepHiddenPhysics](https://github.com/maziarraissi/DeepHPMs/tree/7b579dbdcf5be4969ebefd32e65f709a8b20ec44/Matlab)
and [Fourier Neural
Operators](https://github.com/neuraloperator/neuraloperator/tree/af93f781d5e013f8ba5c52baa547f2ada304ffb0/data_generation)
(the links show where in their public repos they use the `spinX` module). The
approach of pre-sampling the solvers, writing out the trajectories, and then
using them for supervised training worked for these problems, but of course
limits to purely supervised problem. Modern research ideas like correcting
coarse solvers (see for instance the [Solver-in-the-Loop
paper](https://arxiv.org/abs/2007.00016) or the [ML-accelerated CFD
paper](https://arxiv.org/abs/2102.01010)) requires the coarse solvers to be
[differentiable](https://physicsbaseddeeplearning.org/diffphys.html). Some ideas
of diverted chain training also requires the fine solver to be differentiable!
Even for applications without differentiable solvers, we still have the
**interface problem** with legacy solvers (like the MATLAB ones). Hence, we
cannot easily query them "on-the-fly" for sth like active learning tasks, nor do
they run efficiently on hardward accelerators (GPUs, TPUs, etc.). Additionally,
they were not designed with batch execution (in the sense of vectorized
application) in mind which we get more or less for free by `jax.vmap`. With the
reproducible randomness of `JAX` we might not even have to ever write out a
dataset and can re-create it in seconds!

This package took much inspiration from the
[FourierFlows.jl](https://github.com/FourierFlows/FourierFlows.jl) in the
*Julia* ecosystem, especially for checking the implementation of the contout
integral method of [2] and how to handle (de)aliasing.


### References

[1] Cox, Steven M., and Paul C. Matthews. "Exponential time differencing for stiff systems." Journal of Computational Physics 176.2 (2002): 430-455.

[2] Kassam, A.K. and Trefethen, L.N., 2005. Fourth-order time-stepping for stiff PDEs. SIAM Journal on Scientific Computing, 26(4), pp.1214-1233.

[3] Montanelli, Hadrien, and Niall Bootland. "Solving periodic semilinear stiff PDEs in 1D, 2D and 3D with exponential integrators." Mathematics and Computers in Simulation 178 (2020): 307-327.
Loading

0 comments on commit 00b94f4

Please sign in to comment.