diff --git a/.github/workflows/pre_commit.yml b/.github/workflows/pre_commit.yml new file mode 100644 index 0000000..37f41e9 --- /dev/null +++ b/.github/workflows/pre_commit.yml @@ -0,0 +1,19 @@ +name: Code linting + +on: + pull_request: + + push: + branches: + - main + +jobs: + pre-commit: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v3 + - uses: actions/setup-python@v4 + with: + python-version: '3.12' + - uses: pre-commit/action@v3.0.0 \ No newline at end of file diff --git a/.gitignore copy b/.gitignore copy new file mode 100644 index 0000000..68bc17f --- /dev/null +++ b/.gitignore copy @@ -0,0 +1,160 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..c56e09d --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,17 @@ +repos: + - repo: https://github.com/ambv/black + rev: 23.12.1 + hooks: + - id: black-jupyter + language_version: python3 + + - repo: https://github.com/pycqa/isort + rev: 5.12.0 + hooks: + - id: isort + args: ["--profile", "black", "--filter-files"] + + - repo: https://github.com/pycqa/flake8 + rev: 7.0.0 + hooks: + - id: flake8 diff --git a/README.md b/README.md index 1e32121..2e6e767 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,144 @@ -# exponax -the all new exponax in multiple dimensions with multiple channels +# Exponax + +A suite of simple solvers for 1d PDEs on periodic domains based on exponential +time differencing algorithms, built on top of +[JAX](https://github.com/google/jax). **Efficient**, **Elegant**, +**Vectorizable**, and **Differentiable**. + +### Quickstart - 1d Kuramoto-Sivashinsky equation + +```python +import jax +import exponax as ex +import matplotlib.pyplot as plt + +ks_stepper = ex.KuramotoSivashinskyConservative( + num_spatial_dims=1, domain_extent=100.0, + num_points=200, dt=0.1, +) + +u_0 = ex.RandomTruncatedFourierSeries( + num_spatial_dims=1, cutoff=5 +)(num_points=200, key=jax.random.PRNGKey(0)) + +trajectory = ex.rollout(ks_stepper, 500, include_init=True)(u_0) + +plt.imshow(trajectory[:, 0, :].T, aspect='auto', cmap='RdBu', vmin=-2, vmax=2, origin="lower") +plt.xlabel("Time"); plt.ylabel("Space"); plt.show() +``` + +![](ks_rollout.png) + +See also the *examples* folder for more. It is best to start with +`simple_advection_example.ipynb` to get familiar with the idioms of the package, +especially if you are not yet familiar with JAX. Then, continue with the +`solver_showcase.ipynb`. To see the solvers in action on a supervised +learning problem, see `learning_burgers_autoregressive_neural_operator.ipynb`. A +tutorial notebook that requires the differentiability of the solvers is in the +works. + +### Features + +Using JAX as the computational backend gives: + +1. **Backend agnostic code** - run on CPU, GPU, or TPU, in both single and double + precision. +2. **Automatic differentiation** over the timesteppers - compute gradients of + solutions with respect to initial conditions, parameters, etc. +3. Also helpful for **tight integration with Deep Learning** since each + timestepper is also just an [Equinox](https://github.com/patrick-kidger/equinox) Module. +4.
**Automatic Vectorization** using `jax.vmap` (or `equinox.filter_vmap`) + allowing one to advance multiple states in time, or to instantiate multiple solvers at once, all operating efficiently in batch. + +Exponax strives to be lightweight and without custom types; there is no `grid` or `state` object. Everything is based on `jax.numpy` arrays. + +### Background + +Exponax supports the efficient solution of 1d (semi-linear) partial differential equations on periodic domains. Those are PDEs of the form + +$$ \partial u / \partial t = Lu + N(u) $$ + +where $L$ is a linear differential operator and $N$ is a nonlinear differential +operator. The linear part can be solved exactly using a (matrix) exponential, +and the nonlinear part is approximated using Runge-Kutta methods of various +orders. These methods have been known in various scientific disciplines for a +long time and were first unified by [Cox & Matthews](https://doi.org/10.1006/jcph.2002.6995) [1]. In particular, this package uses the complex contour integral method of [Kassam & Trefethen](https://doi.org/10.1137/S1064827502410633) [2] for numerical stability. The package is restricted to the original first-, second-, third-, and fourth-order methods. Since the publication of [1], many extensions have been developed. A recent study by [Montanelli & Bootland](https://doi.org/10.1016/j.matcom.2020.06.008) [3] showed that the original *ETDRK4* method is still one of the most efficient methods for these types of PDEs. + +### Built-in solvers + +This package comes with the following solvers: + +* Linear PDEs: + * Advection equation + * Diffusion equation + * Advection-Diffusion equation + * Dispersion equation + * Hyper-Diffusion equation + * General linear equation containing zeroth, first, second, third, and fourth order derivatives +* Nonlinear PDEs: + * Burgers equation + * Kuramoto-Sivashinsky equation + * Korteweg-de Vries equation + +Other equations can easily be implemented by subclassing `BaseStepper`. + +### Other functionality + +In addition to the timesteppers operating on JAX array states, the package also comes with: + +* Initial Conditions: + * Random sine waves + * Diffused Noise + * Random Discontinuities + * Gaussian Random Fields +* Utilities: + * Mesh creation + * Rollout functions + * Spectral derivatives + * Initial condition set creation +* Poisson solver +* Modification to make solvers take an additional forcing argument +* Modification to make solvers perform substeps for more accurate simulation + +### Similar projects and motivation for this package + +This package is greatly inspired by the [chebfun](https://www.chebfun.org/) +package in *MATLAB*, in particular the +[`spinX`](https://www.chebfun.org/docs/guide/guide19.html) module within it. It +has been used extensively as a data generator in early works for supervised +physics-informed ML, e.g., the +[DeepHiddenPhysics](https://github.com/maziarraissi/DeepHPMs/tree/7b579dbdcf5be4969ebefd32e65f709a8b20ec44/Matlab) +and [Fourier Neural +Operators](https://github.com/neuraloperator/neuraloperator/tree/af93f781d5e013f8ba5c52baa547f2ada304ffb0/data_generation) +(the links show where in their public repos they use the `spinX` module). The +approach of pre-sampling the solvers, writing out the trajectories, and then +using them for supervised training worked for these problems, but of course +limits them to purely supervised problems.
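+The snippet below sketches such a pre-sampling workflow in Exponax (a minimal sketch reusing the KS stepper and initial-condition generator from the quickstart above; the batch size of 16 and the 500 steps are arbitrary illustration values):
+
+```python
+import jax
+import exponax as ex
+
+# Same stepper and IC distribution as in the quickstart
+ks_stepper = ex.KuramotoSivashinskyConservative(
+    num_spatial_dims=1, domain_extent=100.0,
+    num_points=200, dt=0.1,
+)
+ic_gen = ex.RandomTruncatedFourierSeries(num_spatial_dims=1, cutoff=5)
+
+# One initial state per PRNG key -> shape (16, 1, 200)
+keys = jax.random.split(jax.random.PRNGKey(0), 16)
+u_0_set = jax.vmap(lambda k: ic_gen(num_points=200, key=k))(keys)
+
+# Batched trajectories of 500 steps each -> shape (16, 501, 1, 200)
+train_data = jax.vmap(ex.rollout(ks_stepper, 500, include_init=True))(u_0_set)
+```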
Modern research ideas like correcting +coarse solvers (see for instance the [Solver-in-the-Loop +paper](https://arxiv.org/abs/2007.00016) or the [ML-accelerated CFD +paper](https://arxiv.org/abs/2102.01010)) require the coarse solvers to be +[differentiable](https://physicsbaseddeeplearning.org/diffphys.html). Some ideas +of diverted chain training also require the fine solver to be differentiable! +Even for applications without differentiable solvers, we still have the +**interface problem** with legacy solvers (like the MATLAB ones). Hence, we +cannot easily query them "on-the-fly" for something like active learning tasks, nor do +they run efficiently on hardware accelerators (GPUs, TPUs, etc.). Additionally, +they were not designed with batch execution (in the sense of vectorized +application) in mind, which we get more or less for free via `jax.vmap`. With the +reproducible randomness of `JAX` we might not even have to ever write out a +dataset and can re-create it in seconds! + +This package took much inspiration from the +[FourierFlows.jl](https://github.com/FourierFlows/FourierFlows.jl) package in the +*Julia* ecosystem, especially for checking the implementation of the contour +integral method of [2] and how to handle (de)aliasing. + + +### References + +[1] Cox, Steven M., and Paul C. Matthews. "Exponential time differencing for stiff systems." Journal of Computational Physics 176.2 (2002): 430-455. + +[2] Kassam, A.K. and Trefethen, L.N., 2005. Fourth-order time-stepping for stiff PDEs. SIAM Journal on Scientific Computing, 26(4), pp.1214-1233. + +[3] Montanelli, Hadrien, and Niall Bootland. "Solving periodic semilinear stiff PDEs in 1D, 2D and 3D with exponential integrators." Mathematics and Computers in Simulation 178 (2020): 307-327. \ No newline at end of file diff --git a/exponax/__init__.py b/exponax/__init__.py new file mode 100644 index 0000000..f34baf8 --- /dev/null +++ b/exponax/__init__.py @@ -0,0 +1,94 @@ +from .forced_stepper import ForcedStepper +from .initial_conditions import ( + MultiChannelIC, + RandomMultiChannelICGenerator, + RandomTruncatedFourierSeries, + DiffusedNoise, + GaussianRandomField, +) +from .poisson import Poisson +from .repeated_stepper import RepeatedStepper +from .sample_stepper import ( + Advection, + Diffusion, + AdvectionDiffusion, + Dispersion, + HyperDiffusion, + GeneralLinearStepper, + Burgers, + KuramotoSivashinsky, + KuramotoSivashinskyConservative, + Nikolaevskiy, + NikolaevskiyConservative, + GeneralConvectionStepper, + GeneralGradientNormStepper, + NavierStokesVorticity2d, + KolmogorovFlowVorticity2d, + SwiftHohenberg, + GrayScott, + KortevegDeVries, + FisherKPP, + AllenCahn, + CahnHilliard, + BelousovZhabotinsky, +) +from .normalized_stepper import ( + NormalizedLinearStepper, + NormalizedConvectionStepper, + NormalizedGradientNormStepper, + normalize_coefficients, + denormalize_coefficients, + normalize_convection_scale, + denormalize_convection_scale, + normalize_gradient_norm_scale, + denormalize_gradient_norm_scale, +) +from .utils import ( + get_grid, + get_animation, + get_grouped_animation, + rollout, + repeat, + stack_sub_trajectories, + build_ic_set, +) +from .spectral import ( + derivative, +) + +__all__ = [ + "ForcedStepper", + "SineWaves", + "RandomSineWaves", + "DiffusedNoise", + "RandomDiffusedNoise", + "Poisson", + "RepeatedStepper", + "Advection", + "Advection1d", + "Advection2d", + "Advection3d", + "Diffusion", + "Diffusion1d", + "Diffusion2d", + "Diffusion3d", + "AdvectionDiffusion", + "Dispersion", + "HyperDiffusion",
"Burgers", + "Burgers1d", + "Burgers2d", + "Burgers3d", + "KuramotoSivashinsky", + "KuramotoSivashinsky1d", + "KuramotoSivashinsky2d", + "KuramotoSivashinsky3d", + "NavierStokesVorticity2d", + "get_grid", + "get_animation", + "get_grouped_animation", + "rollout", + "repeat", + "stack_sub_trajectories", + "build_ic_set", +] diff --git a/exponax/base_stepper.py b/exponax/base_stepper.py new file mode 100644 index 0000000..d9f2551 --- /dev/null +++ b/exponax/base_stepper.py @@ -0,0 +1,191 @@ +import jax +import jax.numpy as jnp +import equinox as eqx +from jaxtyping import Array, Float, Complex + +from .exponential_integrators import BaseETDRK, ETDRK0, ETDRK1, ETDRK2, ETDRK3, ETDRK4 +from .spectral import ( + build_derivative_operator, + space_indices, + spatial_shape, + wavenumber_shape, +) + +from .nonlinear_functions import BaseNonlinearFun + + +class BaseStepper(eqx.Module): + num_spatial_dims: int + domain_extent: float + num_points: int + num_channels: int + dt: float + dx: float + + _integrator: BaseETDRK + + def __init__( + self, + num_spatial_dims: int, + domain_extent: float, + num_points: int, + dt: float, + *, + num_channels: int, + order: int, + n_circle_points: int = 16, + circle_radius: float = 1.0, + ): + self.num_spatial_dims = num_spatial_dims + self.domain_extent = domain_extent + self.num_points = num_points + self.dt = dt + self.num_channels = num_channels + + # Uses the convention that N does **not** include the right boundary + # point + self.dx = domain_extent / num_points + + derivative_operator = build_derivative_operator( + num_spatial_dims, domain_extent, num_points + ) + + linear_operator = self._build_linear_operator(derivative_operator) + single_channel_shape = (1,) + wavenumber_shape( + self.num_spatial_dims, self.num_points + ) # Same operator for each channel (i.e., we broadcast) + multi_channel_shape = (self.num_channels,) + wavenumber_shape( + self.num_spatial_dims, self.num_points + ) # Different operator for each channel + if linear_operator.shape not in (single_channel_shape, multi_channel_shape): + raise ValueError( + f"Expected linear operator to have shape {single_channel_shape} or {multi_channel_shape}, got {linear_operator.shape}." + ) + nonlinear_fun = self._build_nonlinear_fun(derivative_operator) + + if order == 0: + self._integrator = ETDRK0( + dt, + linear_operator, + ) + elif order == 1: + self._integrator = ETDRK1( + dt, + linear_operator, + nonlinear_fun, + n_circle_points=n_circle_points, + circle_radius=circle_radius, + ) + elif order == 2: + self._integrator = ETDRK2( + dt, + linear_operator, + nonlinear_fun, + n_circle_points=n_circle_points, + circle_radius=circle_radius, + ) + elif order == 3: + self._integrator = ETDRK3( + dt, + linear_operator, + nonlinear_fun, + n_circle_points=n_circle_points, + circle_radius=circle_radius, + ) + elif order == 4: + self._integrator = ETDRK4( + dt, + linear_operator, + nonlinear_fun, + n_circle_points=n_circle_points, + circle_radius=circle_radius, + ) + else: + raise NotImplementedError(f"Order {order} not implemented.") + + def _build_linear_operator( + self, + derivative_operator: Complex[Array, "D ... (N//2)+1"], + ) -> Complex[Array, "D ... (N//2)+1"]: + """ + Assemble the L operator in Fourier space. + + **Arguments:** + - `derivative_operator`: The derivative operator, shape `( D, ..., + N//2+1 )`. The ellipsis are (D-1) axis of size N (**not** of size + N//2+1). + + **Returns:** + - `L`: The linear operator, shape `( D, ..., N//2+1 )`. 
+ """ + raise NotImplementedError("Must be implemented in subclass.") + + def _build_nonlinear_fun( + self, + derivative_operator: Complex[Array, "D ... (N//2)+1"], + ) -> BaseNonlinearFun: + """ + Build the function that evaluates nonlinearity in physical space, + transforms to Fourier space, and evaluates derivatives there. + + **Arguments:** + - `derivative_operator`: The derivative operator, shape `( D, ..., N//2+1 )`. + + **Returns:** + - `nonlinear_fun`: A function that evaluates the nonlinearities in + time space, transforms to Fourier space, and evaluates the + derivatives there. Should be a subclass of `BaseNonlinearFun`. + """ + raise NotImplementedError("Must be implemented in subclass.") + + def step(self, u: Float[Array, "C ... N"]) -> Float[Array, "C ... N"]: + """ + Perform one step of the time integration. + + **Arguments:** + - `u`: The state vector, shape `(C, ..., N,)`. + + **Returns:** + - `u_next`: The state vector after one step, shape `(C, ..., N,)`. + """ + u_hat = jnp.fft.rfftn(u, axes=space_indices(self.num_spatial_dims)) + u_next_hat = self.step_fourier(u_hat) + u_next = jnp.fft.irfftn( + u_next_hat, + s=spatial_shape(self.num_spatial_dims, self.num_points), + axes=space_indices(self.num_spatial_dims), + ) + return u_next + + def step_fourier( + self, u_hat: Complex[Array, "C ... (N//2)+1"] + ) -> Complex[Array, "C ... (N//2)+1"]: + """ + Perform one step of the time integration in Fourier space. Oftentimes, + this is more efficient than `step` since it avoids back and forth + transforms. + + **Arguments:** + - `u_hat`: The (real) Fourier transform of the state vector + + **Returns:** + - `u_next_hat`: The (real) Fourier transform of the state vector + after one step + """ + return self._integrator.step_fourier(u_hat) + + def __call__( + self, + u: Float[Array, "C ... N"], + ) -> Float[Array, "C ... N"]: + """ + Performs a check + """ + expected_shape = (self.num_channels,) + spatial_shape( + self.num_spatial_dims, self.num_points + ) + if u.shape != expected_shape: + raise ValueError( + f"Expected shape {expected_shape}, got {u.shape}. For batched operation use `jax.vmap` on this function." + ) + return self.step(u) diff --git a/exponax/exponential_integrators.py b/exponax/exponential_integrators.py new file mode 100644 index 0000000..efbb1c9 --- /dev/null +++ b/exponax/exponential_integrators.py @@ -0,0 +1,281 @@ +import jax.numpy as jnp +import equinox as eqx +from jaxtyping import Complex, Array, Float +from typing import Callable + +from .nonlinear_functions import BaseNonlinearFun, ZeroNonlinearFun + +# E can either be 1 (single channel) or num_channels (multi-channel) for either +# the same linear operator for each channel or a different linear operator for +# each channel, respectively. +# +# So far, we do **not** support channel mixing via the linear operator (for +# example if we solved the wave equation or the sine-Gordon equation). + + +class BaseETDRK(eqx.Module): + dt: float + _exp_term: Complex[Array, "E ... (N//2)+1"] + + def __init__( + self, + dt: float, + linear_operator: Complex[Array, "E ... (N//2)+1"], + ): + self.dt = dt + self._exp_term = jnp.exp(self.dt * linear_operator) + + def step_fourier( + self, + u_hat: Complex[Array, "C ... (N//2)+1"], + ) -> Complex[Array, "C ... (N//2)+1"]: + """ + Advance the state in Fourier space. 
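+        For the purely linear case this is exact, `u_next_hat = exp(dt * L) * u_hat`
+        (see `ETDRK0` below); the higher-order subclasses additionally apply
+        precomputed ETDRK coefficients obtained via the complex contour integral
+        trick of Kassam & Trefethen (see `roots_of_unity` below).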
+ """ + raise NotImplementedError("Must be implemented by subclass") + + +class ETDRK0(BaseETDRK): + """ + Exactly solve a linear PDE in Fourier space + """ + + def step_fourier( + self, + u_hat: Complex[Array, "C ... (N//2)+1"], + ) -> Complex[Array, "C ... (N//2)+1"]: + return self._exp_term * u_hat + + +def roots_of_unity(M: int) -> Complex[Array, "M"]: + """ + Return (complex-valued) array with M roots of unity. + """ + # return jnp.exp(1j * jnp.pi * (jnp.arange(1, M+1) - 0.5) / M) + return jnp.exp(2j * jnp.pi * (jnp.arange(1, M + 1) - 0.5) / M) + + +class ETDRK1(BaseETDRK): + _nonlinear_fun: BaseNonlinearFun + _coef_1: Complex[Array, "E ... (N//2)+1"] + + def __init__( + self, + dt: float, + linear_operator: Complex[Array, "E ... (N//2)+1"], + nonlinear_fun: BaseNonlinearFun, + *, + n_circle_points: int = 16, + circle_radius: float = 1.0, + ): + super().__init__(dt, linear_operator) + self._nonlinear_fun = nonlinear_fun + + LR = ( + circle_radius * roots_of_unity(n_circle_points) + + linear_operator[..., jnp.newaxis] * dt + ) + + self._coef_1 = dt * jnp.mean((jnp.exp(LR) - 1) / LR, axis=-1).real + + def step_fourier( + self, + u_hat: Complex[Array, "C ... (N//2)+1"], + ) -> Complex[Array, "C ... (N//2)+1"]: + return self._exp_term * u_hat + self._coef_1 * self._nonlinear_fun(u_hat) + + +class ETDRK2(BaseETDRK): + _nonlinear_fun: BaseNonlinearFun + _coef_1: Complex[Array, "E ... (N//2)+1"] + _coef_2: Complex[Array, "E ... (N//2)+1"] + + def __init__( + self, + dt: float, + linear_operator: Complex[Array, "E ... (N//2)+1"], + nonlinear_fun: BaseNonlinearFun, + *, + n_circle_points: int = 16, + circle_radius: float = 1.0, + ): + super().__init__(dt, linear_operator) + self._nonlinear_fun = nonlinear_fun + + LR = ( + circle_radius * roots_of_unity(n_circle_points) + + linear_operator[..., jnp.newaxis] * dt + ) + + self._coef_1 = dt * jnp.mean((jnp.exp(LR) - 1) / LR, axis=-1).real + + self._coef_2 = dt * jnp.mean((jnp.exp(LR) - 1 - LR) / LR**2, axis=-1).real + + def step_fourier( + self, + u_hat: Complex[Array, "C ... (N//2)+1"], + ) -> Complex[Array, "C ... (N//2)+1"]: + u_nonlin_hat = self._nonlinear_fun(u_hat) + u_stage_1_hat = self._exp_term * u_hat + self._coef_1 * u_nonlin_hat + + u_stage_1_nonlin_hat = self._nonlinear_fun(u_stage_1_hat) + u_next_hat = u_stage_1_hat + self._coef_2 * ( + u_stage_1_nonlin_hat - u_nonlin_hat + ) + return u_next_hat + + +class ETDRK3(BaseETDRK): + _nonlinear_fun: BaseNonlinearFun + _half_exp_term: Complex[Array, "E ... (N//2)+1"] + _coef_1: Complex[Array, "E ... (N//2)+1"] + _coef_2: Complex[Array, "E ... (N//2)+1"] + _coef_3: Complex[Array, "E ... (N//2)+1"] + _coef_4: Complex[Array, "E ... (N//2)+1"] + _coef_5: Complex[Array, "E ... (N//2)+1"] + + def __init__( + self, + dt: float, + linear_operator: Complex[Array, "E ... 
(N//2)+1"], + nonlinear_fun: BaseNonlinearFun, + *, + n_circle_points: int = 16, + circle_radius: float = 1.0, + ): + super().__init__(dt, linear_operator) + self._nonlinear_fun = nonlinear_fun + self._half_exp_term = jnp.exp(0.5 * dt * linear_operator) + + LR = ( + circle_radius * roots_of_unity(n_circle_points) + + linear_operator[..., jnp.newaxis] * dt + ) + + self._coef_1 = dt * jnp.mean((jnp.exp(LR / 2) - 1) / LR, axis=-1).real + + self._coef_2 = dt * jnp.mean((jnp.exp(LR) - 1) / LR, axis=-1).real + + self._coef_3 = ( + dt + * jnp.mean( + (-4 - LR + jnp.exp(LR) * (4 - 3 * LR + LR**2)) / (LR**3), axis=-1 + ).real + ) + + self._coef_4 = ( + dt + * jnp.mean( + (4.0 * (2.0 + LR + jnp.exp(LR) * (-2 + LR))) / (LR**3), axis=-1 + ).real + ) + + self._coef_5 = ( + dt + * jnp.mean( + (-4 - 3 * LR - LR**2 + jnp.exp(LR) * (4 - LR)) / (LR**3), axis=-1 + ).real + ) + + def step_fourier( + self, + u_hat: Complex[Array, "E ... (N//2)+1"], + ) -> Complex[Array, "E ... (N//2)+1"]: + u_nonlin_hat = self._nonlinear_fun(u_hat) + u_stage_1_hat = self._half_exp_term * u_hat + self._coef_1 * u_nonlin_hat + + u_stage_1_nonlin_hat = self._nonlinear_fun(u_stage_1_hat) + u_stage_2_hat = self._exp_term * u_hat + self._coef_2 * ( + 2 * u_stage_1_nonlin_hat - u_nonlin_hat + ) + + u_stage_2_nonlin_hat = self._nonlinear_fun(u_stage_2_hat) + + u_next_hat = ( + self._exp_term * u_hat + + self._coef_3 * u_nonlin_hat + + self._coef_4 * u_stage_1_nonlin_hat + + self._coef_5 * u_stage_2_nonlin_hat + ) + + return u_next_hat + + +class ETDRK4(BaseETDRK): + _nonlinear_fun: BaseNonlinearFun + _half_exp_term: Complex[Array, "E ... (N//2)+1"] + _coef_1: Complex[Array, "E ... (N//2)+1"] + _coef_2: Complex[Array, "E ... (N//2)+1"] + _coef_3: Complex[Array, "E ... (N//2)+1"] + _coef_4: Complex[Array, "E ... (N//2)+1"] + _coef_5: Complex[Array, "E ... (N//2)+1"] + _coef_6: Complex[Array, "E ... (N//2)+1"] + + def __init__( + self, + dt: float, + linear_operator: Complex[Array, "E ... (N//2)+1"], + nonlinear_fun: BaseNonlinearFun, + *, + n_circle_points: int = 16, + circle_radius: float = 1.0, + ): + super().__init__(dt, linear_operator) + self._nonlinear_fun = nonlinear_fun + self._half_exp_term = jnp.exp(0.5 * dt * linear_operator) + + LR = ( + circle_radius * roots_of_unity(n_circle_points) + + linear_operator[..., jnp.newaxis] * dt + ) + + self._coef_1 = dt * jnp.mean((jnp.exp(LR / 2) - 1) / LR, axis=-1).real + + self._coef_2 = self._coef_1 + self._coef_3 = self._coef_1 + + self._coef_4 = ( + dt + * jnp.mean( + (-4 - LR + jnp.exp(LR) * (4 - 3 * LR + LR**2)) / (LR**3), axis=-1 + ).real + ) + + self._coef_5 = ( + dt * jnp.mean((2 + LR + jnp.exp(LR) * (-2 + LR)) / (LR**3), axis=-1).real + ) + + self._coef_6 = ( + dt + * jnp.mean( + (-4 - 3 * LR - LR**2 + jnp.exp(LR) * (4 - LR)) / (LR**3), axis=-1 + ).real + ) + + def step_fourier( + self, + u_hat: Complex[Array, "C ... (N//2)+1"], + ) -> Complex[Array, "C ... 
(N//2)+1"]: + u_nonlin_hat = self._nonlinear_fun(u_hat) + u_stage_1_hat = self._half_exp_term * u_hat + self._coef_1 * u_nonlin_hat + + u_stage_1_nonlin_hat = self._nonlinear_fun(u_stage_1_hat) + u_stage_2_hat = ( + self._half_exp_term * u_hat + self._coef_2 * u_stage_1_nonlin_hat + ) + + u_stage_2_nonlin_hat = self._nonlinear_fun(u_stage_2_hat) + u_stage_3_hat = self._half_exp_term * u_stage_1_hat + self._coef_3 * ( + 2 * u_stage_2_nonlin_hat - u_nonlin_hat + ) + + u_stage_3_nonlin_hat = self._nonlinear_fun(u_stage_3_hat) + + u_next_hat = ( + self._exp_term * u_hat + + self._coef_4 * u_nonlin_hat + + self._coef_5 * 2 * (u_stage_1_nonlin_hat + u_stage_2_nonlin_hat) + + self._coef_6 * u_stage_3_nonlin_hat + ) + + return u_next_hat diff --git a/exponax/forced_stepper.py b/exponax/forced_stepper.py new file mode 100644 index 0000000..ff2d421 --- /dev/null +++ b/exponax/forced_stepper.py @@ -0,0 +1,103 @@ +from typing import Any +import equinox as eqx +from .base_stepper import BaseStepper + +from jaxtyping import Array, Float, Complex + + +class ForcedStepper(eqx.Module): + stepper: BaseStepper + + def __init__( + self, + stepper: BaseStepper, + ): + """ + Transform a stepper of signature `(u,) -> u_next` into a stepper of + signature `(u, f) -> u_next` that also accepts a forcing vector `f`. + + Transforms a stepper for a PDE of the form u_t = Lu + N(u) into a stepper + for a PDE of the form u_t = Lu + N(u) + f, where f is a forcing term. For + this, we split by operators + + v_t = f + + u_t = Lv + N(v) + + Since we assume to only have access to the forcing function evaluated at one + time level (but on the same grid as the state), we use a forward Euler + scheme to integrate the first equation. The second equation is integrated + using the original stepper. + + Note: This operator splitting makes the total scheme only first order + accurate in time. It is a quick hack to extend the other sophisticated + transient integrators to forced problems. + + **Arguments**: + - `stepper`: The stepper to be transformed. + """ + self.stepper = stepper + + def step( + self, + u: Float[Array, "C ... N"], + f: Float[Array, "C ... N"], + ) -> Float[Array, "C ... N"]: + """ + Step the PDE forward in time by one time step given the current state + `u` and the forcing term `f`. + + The forcing term `f` is assumed to be evaluated on the same grid as `u`. + + **Arguments**: + - `u`: The current state. + - `f`: The forcing term. + + **Returns**: + - `u_next`: The state after one time step. + """ + u_with_force = u + self.stepper.dt * f + return self.stepper.step(u_with_force) + + def step_fourier( + self, + u_hat: Complex[Array, "C ... (N//2)+1"], + f_hat: Complex[Array, "C ... (N//2)+1"], + ) -> Complex[Array, "C ... (N//2)+1"]: + """ + Step the PDE forward in time by one time step given the current state + `u_hat` in Fourier space and the forcing term `f_hat` in Fourier space. + + The forcing term `f_hat` is assumed to be evaluated on the same grid as + `u_hat`. + + **Arguments**: + - `u_hat`: The current state in Fourier space. + - `f_hat`: The forcing term in Fourier space. + + **Returns**: + - `u_next_hat`: The state after one time step in Fourier space. + """ + u_hat_with_force = u_hat + self.stepper.dt * f_hat + return self.stepper.step_fourier(u_hat_with_force) + + def __call__( + self, + u: Float[Array, "C ... N"], + f: Float[Array, "C ... N"], + ) -> Float[Array, "C ... N"]: + """ + Step the PDE forward in time by one time step given the current state + `u` and the forcing term `f`. 
+ + The forcing term `f` is assumed to be evaluated on the same grid as `u`. + + **Arguments**: + - `u`: The current state. + - `f`: The forcing term. + + **Returns**: + - `u_next`: The state after one time step. + """ + + return self.step(u, f) diff --git a/exponax/initial_conditions.py b/exponax/initial_conditions.py new file mode 100644 index 0000000..371e296 --- /dev/null +++ b/exponax/initial_conditions.py @@ -0,0 +1,472 @@ +import jax.numpy as jnp +import jax.random as jr +from typing import List +import equinox as eqx +from jaxtyping import Complex, Array, Float, PRNGKeyArray + +from abc import ABC, abstractmethod +from typing import Optional + +from .sample_stepper import Diffusion +from .spectral import ( + build_scaled_wavenumbers, + spatial_shape, + wavenumber_shape, + low_pass_filter_mask, + space_indices, + build_scaling_array, +) +from .utils import get_grid + +### --- Base classes --- ### + + +class BaseIC(eqx.Module, ABC): + + @abstractmethod + def __call__(self, x: Float[Array, "D ... N"]) -> Float[Array, "1 ... N"]: + """ + Evaluate the initial condition. + + **Arguments**: + - `x`: The grid points. + + **Returns**: + - `u`: The initial condition evaluated at the grid points. + """ + pass + + +class BaseRandomICGenerator(eqx.Module): + num_spatial_dims: int + domain_extent: float + indexing: str = "ij" + + def gen_ic_fun(self, num_points: int, *, key: PRNGKeyArray) -> BaseIC: + """ + Generate an initial condition function. + + **Arguments**: + - `num_points`: The number of grid points in each dimension. + - `key`: A jax random key. + + **Returns**: + - `ic`: An initial condition function that can be evaluated at + degree of freedom locations. + """ + raise NotImplementedError( + "This random ic generator cannot represent its initial condition as a function. Directly evaluate it." + ) + + def __call__( + self, + num_points: int, + *, + key: PRNGKeyArray, + ) -> Float[Array, "1 ... N"]: + """ + Generate a random initial condition. + + **Arguments**: + - `num_points`: The number of grid points in each dimension. + - `key`: A jax random key. + - `indexing`: The indexing convention for the grid. + + **Returns**: + - `u`: The initial condition evaluated at the grid points. + """ + ic_fun = self.gen_ic_fun(num_points, key=key) + grid = get_grid( + self.num_spatial_dims, + self.domain_extent, + num_points, + indexing=self.indexing, + ) + return ic_fun(grid) + + +### Utilities to create ICs for multi-channel fields + + +class MultiChannelIC(eqx.Module): + initial_conditions: List[BaseIC] + + def __call__(self, x: Float[Array, "D ... N"]) -> Float[Array, "C ... N"]: + """ + Evaluate the initial condition. + + **Arguments**: + - `x`: The grid points. + + **Returns**: + - `u`: The initial condition evaluated at the grid points. + """ + return jnp.concatenate([ic(x) for ic in self.initial_conditions], axis=0) + + +class RandomMultiChannelICGenerator(eqx.Module): + ic_generators: List[BaseRandomICGenerator] + + def gen_ic_fun(self, num_points: int, *, key: PRNGKeyArray) -> MultiChannelIC: + ic_funs = [ + ic_gen.gen_ic_fun(num_points, key=key) for ic_gen in self.ic_generators + ] + return MultiChannelIC(ic_funs) + + def __call__( + self, num_points: int, *, key: PRNGKeyArray + ) -> Float[Array, "C ... N"]: + u_list = [ic_gen(num_points, key=key) for ic_gen in self.ic_generators] + return jnp.concatenate(u_list, axis=0) + + +### New version + +# class TruncatedFourierSeries(BaseIC): +# coefficient_array: Complex[Array, "1 ... 
(N//2)+1"] + +# def __init__( +# self, +# D: int, +# L: float, # unused +# N: int, +# *, +# coefficient_array: Complex[Array, "1 ... N"], +# ): +# super().__init__(D, N) +# self.coefficient_array = coefficient_array + +# def evaluate(self, x: Float[Array, "D ... N"]) -> Float[Array, "1 ... N"]: +# return jnp.fft.irfftn( +# self.coefficient_array, +# s=spatial_shape(self.D, self.N), +# axes=space_indices(self.D), +# ) + + +class RandomTruncatedFourierSeries(BaseRandomICGenerator): + num_spatial_dims: int + domain_extent: float + cutoff: int + amplitude_range: tuple[int, int] + angle_range: tuple[int, int] + offset_range: tuple[int, int] + + def __init__( + self, + num_spatial_dims: int, + domain_extent: float = 1.0, + *, + cutoff: int = 10, + amplitude_range: tuple[int, int] = (-1.0, 1.0), + angle_range: tuple[int, int] = (0.0, 2.0 * jnp.pi), + offset_range: tuple[int, int] = (0.0, 0.0), # no offset by default + ): + self.num_spatial_dims = num_spatial_dims + self.domain_extent = domain_extent + + self.cutoff = cutoff + self.amplitude_range = amplitude_range + self.angle_range = angle_range + self.offset_range = offset_range + + def __call__( + self, num_points: int, *, key: PRNGKeyArray + ) -> Float[Array, "1 ... N"]: + fourier_noise_shape = (1,) + wavenumber_shape(self.num_spatial_dims, num_points) + amplitude_key, angle_key, offset_key = jr.split(key, 3) + + amplitude = jr.uniform( + amplitude_key, + shape=fourier_noise_shape, + minval=self.amplitude_range[0], + maxval=self.amplitude_range[1], + ) + angle = jr.uniform( + angle_key, + shape=fourier_noise_shape, + minval=self.angle_range[0], + maxval=self.angle_range[1], + ) + + fourier_noise = amplitude * jnp.exp(1j * angle) + + low_pass_filter = low_pass_filter_mask( + self.num_spatial_dims, num_points, cutoff=self.cutoff, axis_separate=True + ) + + fourier_noise = fourier_noise * low_pass_filter + + offset = jr.uniform( + offset_key, + shape=(1,), + minval=self.offset_range[0], + maxval=self.offset_range[1], + )[0] + fourier_noise = ( + fourier_noise.flatten().at[0].set(offset).reshape(fourier_noise_shape) + ) + + fourier_noise = fourier_noise * build_scaling_array( + self.num_spatial_dims, num_points + ) + + u = jnp.fft.irfftn( + fourier_noise, + s=spatial_shape(self.num_spatial_dims, num_points), + axes=space_indices(self.num_spatial_dims), + ) + + return u + + +### --- Legacy Sine Waves (truncated Fourier series) --- ### + +# class SineWaves(BaseIC): +# L: float +# filter_mask: Float[Array, "1 ... (N//2)+1"] +# zero_mean: bool +# key: PRNGKeyArray + + +# def __init__( +# self, +# D: int, +# L: float, +# N: int, +# *, +# cutoff: int, +# zero_mean: bool, +# axis_separate: bool = True, +# key: PRNGKeyArray, +# ): +# super().__init__(D, N) +# self.L = L +# self.filter_mask = low_pass_filter_mask(D, N, cutoff=cutoff, axis_separate=axis_separate) +# self.zero_mean = zero_mean +# self.key = key + +# def evaluate(self, x: Float[Array, "D ... N"]) -> Float[Array, "1 ... 
N"]: +# noise_shape = (1,) + spatial_shape(self.D, self.N) + +# noise = jr.normal(self.key, shape=noise_shape) +# noise_hat = jnp.fft.rfftn(noise, axes=space_indices(self.D)) +# noise_hat = noise_hat * self.filter_mask + +# noise = jnp.fft.irfftn(noise_hat, s=spatial_shape(self.D, self.N), axes=space_indices(self.D)) + +# if self.zero_mean: +# noise = noise - jnp.mean(noise) + +# return noise + +# class RandomSineWaves(BaseRandomICGenerator): +# D: int +# L: float +# N: int +# cutoff: int +# zero_mean: bool +# axis_separate: bool + +# def __init__( +# self, +# D: int, +# L: float, +# N: int, +# *, +# cutoff: int, +# zero_mean: bool, +# axis_separate: bool = True, +# ): +# """ +# Randomly generated initial condition consisting of a truncated Fourier series. + +# Arguments are drawn from uniform distributions. + +# **Arguments**: +# - `D`: The dimension of the domain. +# - `N`: The number of grid points in each dimension. +# - `L`: The length of the domain. +# - `cutoff`: The cutoff wavenumber. +# - `zero_mean`: Whether to subtract the mean. +# - `axis_separate`: Whether to draw the wavenumber cutoffs for each +# axis separately. +# """ +# self.D = D +# self.N = N +# self.L = L +# self.cutoff = cutoff +# self.zero_mean = zero_mean +# self.axis_separate = axis_separate + +# def __call__(self, key: PRNGKeyArray) -> SineWaves: +# return SineWaves( +# self.D, +# self.L, +# self.N, +# cutoff=self.cutoff, +# zero_mean=self.zero_mean, +# axis_separate=self.axis_separate, +# key=key, +# ) + + +# --- Diffused Noise --- ### + +# class DiffusedNoise(BaseIC): +# L: float +# intensity: float +# zero_mean: bool +# key: PRNGKeyArray + +# def __init__( +# self, +# D: int, +# L: float, +# N: int, +# *, +# intensity: float, +# zero_mean: bool, +# key: PRNGKeyArray, +# ): +# super().__init__(D, N) +# self.L = L +# self.intensity = intensity +# self.zero_mean = zero_mean +# self.key = key + +# def evaluate(self, x: Float[Array, "D ... N"]) -> Float[Array, "1 ... N"]: +# noise_shape = (1,) + spatial_shape(self.D, self.N) +# noise = jr.normal(self.key, shape=noise_shape) + +# diffusion_stepper = Diffusion(self.D, self.L, self.N, 1.0, diffusivity=self.intensity) +# ic = diffusion_stepper(noise) + +# if self.zero_mean: +# ic = ic - jnp.mean(ic) + +# return ic + + +class DiffusedNoise(BaseRandomICGenerator): + num_spatial_dims: int + domain_extent: float + intensity: float + zero_mean: bool + + def __init__( + self, + num_spatial_dims: int, + domain_extent: float = 1.0, + *, + intensity=0.001, + zero_mean: bool = False, + ): + """ + Randomly generated initial condition consisting of a diffused noise field. + + Arguments are drawn from uniform distributions. + + **Arguments**: + - `D`: The dimension of the domain. + - `L`: The length of the domain. + - `N`: The number of grid points in each dimension. + - `intensity`: The diffusivity. + - `zero_mean`: Whether to subtract the mean. + """ + self.num_spatial_dims = num_spatial_dims + self.domain_extent = domain_extent + self.intensity = intensity + self.zero_mean = zero_mean + + def __call__( + self, num_points: int, *, key: PRNGKeyArray + ) -> Float[Array, "1 ... 
N"]: + noise_shape = (1,) + spatial_shape(self.num_spatial_dims, num_points) + noise = jr.normal(key, shape=noise_shape) + + diffusion_stepper = Diffusion( + self.num_spatial_dims, + self.domain_extent, + num_points, + 1.0, + diffusivity=self.intensity, + ) + ic = diffusion_stepper(noise) + + if self.zero_mean: + ic = ic - jnp.mean(ic) + + return ic + + +### Gausian Random Field ### + + +class GaussianRandomField(BaseRandomICGenerator): + num_spatial_dims: int + domain_extent: float + powerlaw_exponent: float + normalize: bool + + def __init__( + self, + num_spatial_dims: int, + domain_extent: float = 1.0, + *, + powerlaw_exponent: float = 3.0, + normalize: bool = True, + ): + """ + Randomly generated initial condition consisting of a Gaussian random field. + """ + self.num_spatial_dims = num_spatial_dims + self.domain_extent = domain_extent + self.powerlaw_exponent = powerlaw_exponent + self.normalize = normalize + + def __call__( + self, num_points: int, *, key: PRNGKeyArray + ) -> Float[Array, "1 ... N"]: + wavenumber_grid = build_scaled_wavenumbers( + self.num_spatial_dims, self.domain_extent, num_points + ) + wavenumer_norm_grid = jnp.linalg.norm(wavenumber_grid, axis=0, keepdims=True) + amplitude = jnp.power(wavenumer_norm_grid, -self.powerlaw_exponent / 2.0) + amplitude = ( + amplitude.flatten().at[0].set(0.0).reshape(wavenumer_norm_grid.shape) + ) + + real_key, imag_key = jr.split(key, 2) + noise = jr.normal( + real_key, + shape=(1,) + wavenumber_shape(self.num_spatial_dims, num_points), + ) + 1j * jr.normal( + imag_key, + shape=(1,) + wavenumber_shape(self.num_spatial_dims, num_points), + ) + + noise = noise * amplitude + + ic = jnp.fft.irfftn( + noise, + s=spatial_shape(self.num_spatial_dims, num_points), + axes=space_indices(self.num_spatial_dims), + ) + + if self.normalize: + ic = ic - jnp.mean(ic) + ic = ic / jnp.std(ic) + + return ic + + +### Discontinuities ### + + +class Discontinuities(BaseIC): + pass + + +class RandomDiscontinuities(BaseRandomICGenerator): + pass diff --git a/exponax/nonlinear_functions/__init__.py b/exponax/nonlinear_functions/__init__.py new file mode 100644 index 0000000..28c821b --- /dev/null +++ b/exponax/nonlinear_functions/__init__.py @@ -0,0 +1,14 @@ +from .base import BaseNonlinearFun +from .convection import ConvectionNonlinearFun +from .gradient_norm import GradientNormNonlinearFun +from .polynomial import PolynomialNonlinearFun +from .reaction import ( + GrayScottNonlinearFun, + CahnHilliardNonlinearFun, + BelousovZhabotinskyNonlinearFun, +) +from .vorticity_convection import ( + VorticityConvection2d, + VorticityConvection2dKolmogorov, +) +from .zero import ZeroNonlinearFun diff --git a/exponax/nonlinear_functions/base.py b/exponax/nonlinear_functions/base.py new file mode 100644 index 0000000..cc51496 --- /dev/null +++ b/exponax/nonlinear_functions/base.py @@ -0,0 +1,69 @@ +import jax +import jax.numpy as jnp +import equinox as eqx +from jaxtyping import Complex, Array, Float, Bool +from ..spectral import ( + wavenumber_shape, + low_pass_filter_mask, +) +from abc import ABC, abstractmethod + + +class BaseNonlinearFun(eqx.Module, ABC): + num_spatial_dims: int + num_points: int + num_channels: int + derivative_operator: Complex[Array, "D ... (N//2)+1"] + dealiasing_mask: Bool[Array, "1 ... (N//2)+1"] + + def __init__( + self, + num_spatial_dims: int, + num_points: int, + num_channels: int, + *, + derivative_operator: Complex[Array, "D ... 
(N//2)+1"], + dealiasing_fraction: float, + ): + self.num_spatial_dims = num_spatial_dims + self.num_points = num_points + self.num_channels = num_channels + self.derivative_operator = derivative_operator + + # Can be done because num_points is identical in all spatial dimensions + nyquist_mode = (num_points // 2) + 1 + highest_resolved_mode = nyquist_mode - 1 + start_of_aliased_modes = dealiasing_fraction * highest_resolved_mode + + self.dealiasing_mask = low_pass_filter_mask( + num_spatial_dims, + num_points, + cutoff=start_of_aliased_modes - 1, + ) + + @abstractmethod + def evaluate( + self, + u_hat: Complex[Array, "C ... (N//2)+1"], + ) -> Complex[Array, "C ... (N//2)+1"]: + """ + Evaluate all potential nonlinearities "pseudo-spectrally", account for dealiasing. + """ + raise NotImplementedError("Must be implemented by subclass") + + def __call__( + self, + u_hat: Complex[Array, "C ... (N//2)+1"], + ) -> Complex[Array, "C ... (N//2)+1"]: + """ + Perform check + """ + expected_shape = (self.num_channels,) + wavenumber_shape( + self.num_spatial_dims, self.num_points + ) + if u_hat.shape != expected_shape: + raise ValueError( + f"Expected shape {expected_shape}, got {u_hat.shape}. For batched operation use `jax.vmap` on this function." + ) + + return self.evaluate(u_hat) diff --git a/exponax/nonlinear_functions/convection.py b/exponax/nonlinear_functions/convection.py new file mode 100644 index 0000000..a4854d8 --- /dev/null +++ b/exponax/nonlinear_functions/convection.py @@ -0,0 +1,68 @@ +import jax +import jax.numpy as jnp +import equinox as eqx +from jaxtyping import Complex, Array, Float, Bool +from ..spectral import ( + space_indices, + spatial_shape, +) + +from .base import BaseNonlinearFun + + +class ConvectionNonlinearFun(BaseNonlinearFun): + convection_scale: float + zero_mode_fix: bool + + def __init__( + self, + num_spatial_dims: int, + num_points: int, + num_channels: int, + *, + derivative_operator: Complex[Array, "D ... (N//2)+1"], + dealiasing_fraction: float, + convection_scale: float = 0.5, + zero_mode_fix: bool = False, + ): + self.convection_scale = convection_scale + self.zero_mode_fix = zero_mode_fix + super().__init__( + num_spatial_dims, + num_points, + num_channels, + derivative_operator=derivative_operator, + dealiasing_fraction=dealiasing_fraction, + ) + + def zero_fix( + self, + f: Float[Array, "... N"], + ): + return f - jnp.mean(f) + + def evaluate( + self, + u_hat: Complex[Array, "C ... (N//2)+1"], + ) -> Complex[Array, "C ... 
(N//2)+1"]: + u_hat_dealiased = self.dealiasing_mask * u_hat + u = jnp.fft.irfftn( + u_hat_dealiased, + s=spatial_shape(self.num_spatial_dims, self.num_points), + axes=space_indices(self.num_spatial_dims), + ) + u_outer_product = u[:, None] * u[None, :] + + if self.zero_mode_fix: + # Maybe there is more efficient way + u_outer_product = jax.vmap(self.zero_fix)(u_outer_product) + + u_outer_product_hat = jnp.fft.rfftn( + u_outer_product, axes=space_indices(self.num_spatial_dims) + ) + u_divergence_on_outer_product_hat = jnp.sum( + self.derivative_operator[None, :] * u_outer_product_hat, + axis=1, + ) + # Requires minus to move term to the rhs + return -self.convection_scale * u_divergence_on_outer_product_hat diff --git a/exponax/nonlinear_functions/gradient_norm.py b/exponax/nonlinear_functions/gradient_norm.py new file mode 100644 index 0000000..957dedf --- /dev/null +++ b/exponax/nonlinear_functions/gradient_norm.py @@ -0,0 +1,73 @@ +import jax +import jax.numpy as jnp +import equinox as eqx +from jaxtyping import Complex, Array, Float, Bool +from ..spectral import ( + space_indices, + spatial_shape, +) + +from .base import BaseNonlinearFun + + +class GradientNormNonlinearFun(BaseNonlinearFun): + scale: float + zero_mode_fix: bool + + def __init__( + self, + num_spatial_dims: int, + num_points: int, + num_channels: int, + *, + derivative_operator: Complex[Array, "D ... (N//2)+1"], + dealiasing_fraction: float, + zero_mode_fix: bool = True, + scale: float = 0.5, + ): + super().__init__( + num_spatial_dims, + num_points, + num_channels, + derivative_operator=derivative_operator, + dealiasing_fraction=dealiasing_fraction, + ) + self.zero_mode_fix = zero_mode_fix + self.scale = scale + + def zero_fix( + self, + f: Float[Array, "... N"], + ): + return f - jnp.mean(f) + + def evaluate( + self, + u_hat: Complex[Array, "C ... (N//2)+1"], + ) -> Complex[Array, "C ... (N//2)+1"]: + u_gradient_hat = self.derivative_operator[None, :] * u_hat[:, None] + u_gradient_dealiased_hat = self.dealiasing_mask * u_gradient_hat + u_gradient = jnp.fft.irfftn( + u_gradient_dealiased_hat, + s=spatial_shape(self.num_spatial_dims, self.num_points), + axes=space_indices(self.num_spatial_dims), + ) + + # Reduces the axis introduced by the gradient + u_gradient_norm_squared = jnp.sum(u_gradient**2, axis=1) + + if self.zero_mode_fix: + # Maybe there is more efficient way + u_gradient_norm_squared = jax.vmap(self.zero_fix)(u_gradient_norm_squared) + + u_gradient_norm_squared_hat = jnp.fft.rfftn( + u_gradient_norm_squared, axes=space_indices(self.num_spatial_dims) + ) + # if self.zero_mode_fix: + # # Fix the mean mode + # u_gradient_norm_squared_hat = u_gradient_norm_squared_hat.at[..., 0].set( + # u_hat[..., 0] + # ) + + # Requires minus to move term to the rhs + return -self.scale * u_gradient_norm_squared_hat diff --git a/exponax/nonlinear_functions/polynomial.py b/exponax/nonlinear_functions/polynomial.py new file mode 100644 index 0000000..46bd1d5 --- /dev/null +++ b/exponax/nonlinear_functions/polynomial.py @@ -0,0 +1,61 @@ +import jax +import jax.numpy as jnp +import equinox as eqx +from jaxtyping import Complex, Array, Float, Bool +from ..spectral import ( + space_indices, + spatial_shape, +) + +from .base import BaseNonlinearFun + + +class PolynomialNonlinearFun(BaseNonlinearFun): + """ + Channel-separate evaluation; and no mixed terms. 
+ """ + + coefficients: list[float] # Starting from order 0 + + def __init__( + self, + num_spatial_dims: int, + num_points: int, + num_channels: int, + *, + derivative_operator: Complex[Array, "D ... (N//2)+1"], + dealiasing_fraction: float, + coefficients: list[float], + ): + """ + Coefficient list starts from order 0. + """ + super().__init__( + num_spatial_dims, + num_points, + num_channels, + derivative_operator=derivative_operator, + dealiasing_fraction=dealiasing_fraction, + ) + self.coefficients = coefficients + + def evaluate( + self, + u_hat: Complex[Array, "C ... (N//2)+1"], + ) -> Complex[Array, "C ... (N//2)+1"]: + u_hat_dealiased = self.dealiasing_mask * u_hat + u = jnp.fft.irfftn( + u_hat_dealiased, + s=spatial_shape(self.num_spatial_dims, self.num_points), + axes=space_indices(self.num_spatial_dims), + ) + u_power = 1.0 + u_nonlin = 0.0 + for coeff in self.coefficients: + u_nonlin += coeff * u_power + u_power = u_power * u + + u_nonlin_hat = jnp.fft.rfftn( + u_nonlin, axes=space_indices(self.num_spatial_dims) + ) + return u_nonlin_hat diff --git a/exponax/nonlinear_functions/reaction.py b/exponax/nonlinear_functions/reaction.py new file mode 100644 index 0000000..ad44a98 --- /dev/null +++ b/exponax/nonlinear_functions/reaction.py @@ -0,0 +1,146 @@ +""" +Nonlinear terms as they are found in reaction-diffusion(-advection) equations. +""" + +import jax +import jax.numpy as jnp +import equinox as eqx +from jaxtyping import Complex, Array, Float, Bool +from typing import Callable +from ..spectral import ( + space_indices, + spatial_shape, + build_laplace_operator, +) + +from .base import BaseNonlinearFun + + +class GrayScottNonlinearFun(BaseNonlinearFun): + b: float + d: float + + def __init__( + self, + num_spatial_dims: int, + num_points: int, + num_channels: int, + *, + derivative_operator: Complex[Array, "D ... (N//2)+1"], + dealiasing_fraction: float, + b: float, + d: float, + ): + if num_channels != 2: + raise ValueError(f"Expected num_channels = 2, got {num_channels}.") + super().__init__( + num_spatial_dims, + num_points, + num_channels, + derivative_operator=derivative_operator, + dealiasing_fraction=dealiasing_fraction, + ) + self.b = b + self.d = d + + def evaluate( + self, + u_hat: Complex[Array, "C ... (N//2)+1"], + ) -> Complex[Array, "C ... (N//2)+1"]: + u_hat_dealiased = self.dealiasing_mask * u_hat + u = jnp.fft.irfftn( + u_hat_dealiased, + s=spatial_shape(self.num_spatial_dims, self.num_points), + axes=space_indices(self.num_spatial_dims), + ) + u_power = jnp.stack( + [ + self.b * (1 - u[0]) - u[0] * u[1] ** 2, + -self.d * u[1] + u[0] * u[1] ** 2, + ] + ) + u_power_hat = jnp.fft.rfftn(u_power, axes=space_indices(self.num_spatial_dims)) + return u_power_hat + + +class CahnHilliardNonlinearFun(BaseNonlinearFun): + def __init__( + self, + num_spatial_dims: int, + num_points: int, + num_channels: int, + *, + derivative_operator: Complex[Array, "D ... (N//2)+1"], + dealiasing_fraction: float, + ): + if num_channels != 1: + raise ValueError(f"Expected num_channels = 1, got {num_channels}.") + super().__init__( + num_spatial_dims, + num_points, + num_channels, + derivative_operator=derivative_operator, + dealiasing_fraction=dealiasing_fraction, + ) + + def evaluate( + self, + u_hat: Complex[Array, "C ... (N//2)+1"], + ) -> Complex[Array, "C ... 
(N//2)+1"]: + u_hat_dealiased = self.dealiasing_mask * u_hat + u = jnp.fft.irfftn( + u_hat_dealiased, + s=spatial_shape(self.num_spatial_dims, self.num_points), + axes=space_indices(self.num_spatial_dims), + ) + u_power = u[0] ** 3 + u_power_hat = jnp.fft.rfftn(u_power, axes=space_indices(self.num_spatial_dims)) + u_power_laplace_hat = ( + build_laplace_operator(self.derivative_operator, order=2) * u_power_hat + ) + return u_power_laplace_hat + + +class BelousovZhabotinskyNonlinearFun(BaseNonlinearFun): + """ + Taken from: https://github.com/chebfun/chebfun/blob/db207bc9f48278ca4def15bf90591bfa44d0801d/spin.m#L73 + """ + + def __init__( + self, + num_spatial_dims: int, + num_points: int, + num_channels: int, + *, + derivative_operator: Complex[Array, "D ... (N//2)+1"], + dealiasing_fraction: float, + ): + if num_channels != 3: + raise ValueError(f"Expected num_channels = 3, got {num_channels}.") + super().__init__( + num_spatial_dims, + num_points, + num_channels, + derivative_operator=derivative_operator, + dealiasing_fraction=dealiasing_fraction, + ) + + def evaluate( + self, + u_hat: Complex[Array, "C ... (N//2)+1"], + ) -> Complex[Array, "C ... (N//2)+1"]: + u_hat_dealiased = self.dealiasing_mask * u_hat + u = jnp.fft.irfftn( + u_hat_dealiased, + s=spatial_shape(self.num_spatial_dims, self.num_points), + axes=space_indices(self.num_spatial_dims), + ) + u_power = jnp.stack( + [ + u[0] + u[1] - u[0] * u[1] - u[0] ** 2, + u[2] - u[1] - u[0] * u[1], + u[0] - u[2], + ] + ) + u_power_hat = jnp.fft.rfftn(u_power, axes=space_indices(self.num_spatial_dims)) + return u_power_hat diff --git a/exponax/nonlinear_functions/vorticity_convection.py b/exponax/nonlinear_functions/vorticity_convection.py new file mode 100644 index 0000000..847f3f6 --- /dev/null +++ b/exponax/nonlinear_functions/vorticity_convection.py @@ -0,0 +1,116 @@ +import jax +import jax.numpy as jnp +import equinox as eqx +from jaxtyping import Complex, Array, Float, Bool +from ..spectral import ( + build_laplace_operator, + build_wavenumbers, + build_scaling_array, +) + +from .base import BaseNonlinearFun + + +class VorticityConvection2d(BaseNonlinearFun): + inv_laplacian: Complex[Array, "1 ... (N//2)+1"] + + def __init__( + self, + num_spatial_dims: int, + num_points: int, + num_channels: int, + *, + derivative_operator: Complex[Array, "D ... (N//2)+1"], + dealiasing_fraction: float, + ): + if num_spatial_dims != 2: + raise ValueError(f"Expected num_spatial_dims = 2, got {num_spatial_dims}.") + if num_channels != 1: + raise ValueError(f"Expected num_channels = 1, got {num_channels}.") + + super().__init__( + num_spatial_dims, + num_points, + num_channels, + derivative_operator=derivative_operator, + dealiasing_fraction=dealiasing_fraction, + ) + + laplacian = build_laplace_operator(derivative_operator, order=2) + + # Uses the UNCHANGED mean solution to the Poisson equation (hence, the + # mean of the "right-hand side" will be the mean of the solution) + self.inv_laplacian = jnp.where(laplacian == 0, 1.0, 1 / laplacian) + + def evaluate( + self, u_hat: Complex[Array, "1 ... (N//2)+1"] + ) -> Complex[Array, "1 ... 
(N//2)+1"]: + vorticity_hat = u_hat + stream_function_hat = self.inv_laplacian * vorticity_hat + + u_hat = +self.derivative_operator[1:2] * stream_function_hat + v_hat = -self.derivative_operator[0:1] * stream_function_hat + del_vorticity_del_x_hat = self.derivative_operator[0:1] * vorticity_hat + del_vorticity_del_y_hat = self.derivative_operator[1:2] * vorticity_hat + + u = jnp.fft.irfft2( + u_hat * self.dealiasing_mask, s=(self.num_points, self.num_points) + ) + v = jnp.fft.irfft2( + v_hat * self.dealiasing_mask, s=(self.num_points, self.num_points) + ) + del_vorticity_del_x = jnp.fft.irfft2( + del_vorticity_del_x_hat * self.dealiasing_mask, + s=(self.num_points, self.num_points), + ) + del_vorticity_del_y = jnp.fft.irfft2( + del_vorticity_del_y_hat * self.dealiasing_mask, + s=(self.num_points, self.num_points), + ) + + convection = u * del_vorticity_del_x + v * del_vorticity_del_y + + convection_hat = jnp.fft.rfft2(convection) + + # Do we need another dealiasing mask here? + # convection_hat = self.dealiasing_mask * convection_hat + + # Requires minus to move term to the rhs + return -convection_hat + + +class VorticityConvection2dKolmogorov(VorticityConvection2d): + injection: Complex[Array, "1 ... (N//2)+1"] + + def __init__( + self, + num_spatial_dims: int, + num_points: int, + num_channels: int, + *, + injection_mode: int = 4, + injection_scale: float = 1.0, + derivative_operator: Complex[Array, "D ... (N//2)+1"], + dealiasing_fraction: float, + ): + super().__init__( + num_spatial_dims, + num_points, + num_channels, + derivative_operator=derivative_operator, + dealiasing_fraction=dealiasing_fraction, + ) + + wavenumbers = build_wavenumbers(num_spatial_dims, num_points) + injection_mask = (wavenumbers[0] == 0) & (wavenumbers[1] == injection_mode) + self.injection = jnp.where( + injection_mask, + injection_scale * build_scaling_array(num_spatial_dims, num_points), + 0.0, + ) + + def evaluate( + self, u_hat: Complex[Array, "1 ... (N//2)+1"] + ) -> Complex[Array, "1 ... (N//2)+1"]: + neg_convection_hat = super().evaluate(u_hat) + return neg_convection_hat + self.injection diff --git a/exponax/nonlinear_functions/zero.py b/exponax/nonlinear_functions/zero.py new file mode 100644 index 0000000..91e6033 --- /dev/null +++ b/exponax/nonlinear_functions/zero.py @@ -0,0 +1,31 @@ +import jax +import jax.numpy as jnp +import equinox as eqx +from jaxtyping import Complex, Array, Float, Bool + +from .base import BaseNonlinearFun + + +class ZeroNonlinearFun(BaseNonlinearFun): + def __init__( + self, + num_spatial_dims: int, + num_points: int, + num_channels: int, + *, + derivative_operator: Complex[Array, "D ... (N//2)+1"], + dealiasing_fraction: float = 1.0, + ): + super().__init__( + num_spatial_dims, + num_points, + num_channels, + derivative_operator=derivative_operator, + dealiasing_fraction=dealiasing_fraction, + ) + + def evaluate( + self, + u_hat: Complex[Array, "C ... (N//2)+1"], + ) -> Complex[Array, "C ... 
(N//2)+1"]: + return jnp.zeros_like(u_hat) diff --git a/exponax/normalized_stepper/__init__.py b/exponax/normalized_stepper/__init__.py new file mode 100644 index 0000000..9f61198 --- /dev/null +++ b/exponax/normalized_stepper/__init__.py @@ -0,0 +1,11 @@ +from .convection import NormalizedConvectionStepper +from .gradient_norm import NormalizedGradientNormStepper +from .linear import NormalizedLinearStepper +from .utils import ( + denormalize_coefficients, + denormalize_convection_scale, + denormalize_gradient_norm_scale, + normalize_coefficients, + normalize_convection_scale, + normalize_gradient_norm_scale, +) diff --git a/exponax/normalized_stepper/convection.py b/exponax/normalized_stepper/convection.py new file mode 100644 index 0000000..04ae0fa --- /dev/null +++ b/exponax/normalized_stepper/convection.py @@ -0,0 +1,110 @@ +import jax.numpy as jnp + +from jax import Array + +from ..base_stepper import BaseStepper +from ..nonlinear_functions import ConvectionNonlinearFun +from jaxtyping import Complex, Float, Array + + +class NormalizedConvectionStepper(BaseStepper): + normalized_coefficients: list[float] + normalized_convection_scale: float + dealiasing_fraction: float + + def __init__( + self, + num_spatial_dims: int, + num_points: int, + *, + dt: float = 0.1, + normalized_coefficients: list[float] = [0.0, 0.0, 0.01 * 0.1], + normalized_convection_scale: float = 0.5, + order: int = 2, + dealiasing_fraction: float = 2 / 3, + n_circle_points: int = 16, + circle_radius: float = 1.0, + ): + """ + By default: Behaves like a Burgers with + + ``` Burgers( + D=D, L=1, N=N, dt=dt, diffusivity=0.01, + ) + ``` + + If you set `L=2 * jnp.pi` of your unnormalized scenario, then you have + to set your coefficients to `alpha_i * dt` (make sure to use the same dt + as is used here as the keyword based argument). + + If you set `L=1` of your unnormalized scenario, then you have to set + your coefficients to `alpha_i * dt / (2 * jnp.pi)^s` (make sure to use + the same dt as is used here as the keyword based argument) **and** set + your convection scale to whatever you had prior multiplied by 2 * + jnp.pi. + + If you set `L=L` of your unnormalized scenario, then you have to set + your coefficients to `alpha_i * dt * (L / (2 * jnp.pi))^s` (make sure to + use the same dt as is used here as the keyword based argument) **and** + set your convection scale to whatever you had prior multiplied by 2 * + jnp.pi / L. 
+ + number of channels grow with number of spatial dimensions + + **Arguments:** + + - `num_spatial_dims`: number of spatial dimensions + - `num_points`: number of points in each spatial dimension + - `dt`: time step (default: 0.1) + - `normalized_coefficients`: coefficients for the linear operator, + `normalized_coefficients[i]` is the coefficient for the `i`-th + derivative (default: [0.0, 0.0, 0.01 * 0.1] refers to a diffusion + (2nd) order term) + - `normalized_convection_scale`: convection scale for the nonlinear + function (default: 0.5) + - `order`: order of exponential time differencing Runge Kutta method, + can be 1, 2, 3, 4 (default: 2) + - `dealiasing_fraction`: fraction of the wavenumbers being kept before + applying any nonlinearity (default: 2/3) + - `n_circle_points`: number of points to use for the complex contour + integral when computing coefficients for the exponential time + differencing Runge Kutta method (default: 16) + - `circle_radius`: radius of the complex contour integral when computing + coefficients for the exponential time differencing Runge Kutta method + (default: 1.0) + """ + self.normalized_coefficients = normalized_coefficients + self.normalized_convection_scale = normalized_convection_scale + self.dealiasing_fraction = dealiasing_fraction + super().__init__( + num_spatial_dims=num_spatial_dims, + domain_extent=1.0, # Derivative operator is just scaled with 2 * jnp.pi + num_points=num_points, + dt=dt, + num_channels=num_spatial_dims, + order=order, + n_circle_points=n_circle_points, + circle_radius=circle_radius, + ) + + def _build_linear_operator(self, derivative_operator: Array) -> Array: + # Now the linear operator is unscaled + linear_operator = sum( + jnp.sum( + c / self.dt * (derivative_operator) ** i, + axis=0, + keepdims=True, + ) + for i, c in enumerate(self.normalized_coefficients) + ) + return linear_operator + + def _build_nonlinear_fun(self, derivative_operator: Array): + return ConvectionNonlinearFun( + num_spatial_dims=self.num_spatial_dims, + num_points=self.num_points, + num_channels=self.num_channels, + derivative_operator=derivative_operator, + dealiasing_fraction=self.dealiasing_fraction, + convection_scale=self.normalized_convection_scale, + ) diff --git a/exponax/normalized_stepper/gradient_norm.py b/exponax/normalized_stepper/gradient_norm.py new file mode 100644 index 0000000..40f6a0b --- /dev/null +++ b/exponax/normalized_stepper/gradient_norm.py @@ -0,0 +1,86 @@ +import jax.numpy as jnp + +from jax import Array + +from ..base_stepper import BaseStepper +from ..nonlinear_functions import GradientNormNonlinearFun +from jaxtyping import Complex, Float, Array + + +class NormalizedGradientNormStepper(BaseStepper): + normalized_coefficients: list[float] + normalized_gradient_norm_scale: float + dealiasing_fraction: float + + def __init__( + self, + num_spatial_dims: int, + num_points: int, + *, + dt: float = 0.1, + normalized_coefficients: list[float] = [0.0, 0.0, 0.01 * 0.1], + normalized_gradient_norm_scale: float = 0.5, + order: int = 2, + dealiasing_fraction: float = 2 / 3, + n_circle_points: int = 16, + circle_radius: float = 1.0, + ): + """ + the number of channels do **not** grow with the number of spatial + dimensions. They are always 1. 
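For the gradient-norm variant, a similar sketch (again with illustrative numbers) normalizes the KS-style coefficient tuple that `GeneralGradientNormStepper` uses by default; note that the state stays single-channel even in two spatial dimensions.

```python
from exponax.normalized_stepper import (
    NormalizedGradientNormStepper,
    normalize_coefficients,
    normalize_gradient_norm_scale,
)

# Illustrative physical setup: KS coefficients in combustion form on a
# domain of extent L, advanced with time step dt.
domain_extent = 60.0
dt = 0.5
physical_coefficients = (0.0, 0.0, -1.0, 0.0, -1.0)
gradient_norm_scale = 1.0

ks_like_stepper = NormalizedGradientNormStepper(
    num_spatial_dims=2,  # still a single channel, regardless of dimension
    num_points=64,
    dt=dt,
    normalized_coefficients=list(
        normalize_coefficients(domain_extent, dt, physical_coefficients)
    ),
    normalized_gradient_norm_scale=float(
        normalize_gradient_norm_scale(domain_extent, gradient_norm_scale)
    ),
)
```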
+ + **Arguments:** + - `num_spatial_dims`: number of spatial dimensions + - `num_points`: number of points in each spatial dimension + - `dt`: time step (default: 0.1) + - `normalized_coefficients`: coefficients for the linear operator, + `normalized_coefficients[i]` is the coefficient for the `i`-th + derivative (default: [0.0, 0.0, 0.01 * 0.1] refers to a diffusion + operator) + - `normalized_gradient_norm_scale`: scale for the gradient norm + (default: 0.5) + - `order`: order of the derivative operator (default: 2) + - `dealiasing_fraction`: fraction of the wavenumbers being kept before + applying any nonlinearity (default: 2/3) + - `n_circle_points`: number of points to use for the complex contour + integral when computing coefficients for the exponential time + differencing Runge Kutta method (default: 16) + - `circle_radius`: radius of the complex contour integral when computing + coefficients for the exponential time differencing Runge Kutta method + (default: 1.0) + """ + self.normalized_coefficients = normalized_coefficients + self.normalized_gradient_norm_scale = normalized_gradient_norm_scale + self.dealiasing_fraction = dealiasing_fraction + super().__init__( + num_spatial_dims=num_spatial_dims, + domain_extent=1.0, # Derivative operator is just scaled with 2 * jnp.pi + num_points=num_points, + dt=dt, + num_channels=1, + order=order, + n_circle_points=n_circle_points, + circle_radius=circle_radius, + ) + + def _build_linear_operator(self, derivative_operator: Array) -> Array: + linear_operator = sum( + jnp.sum( + c / self.dt * (derivative_operator) ** i, + axis=0, + keepdims=True, + ) + for i, c in enumerate(self.normalized_coefficients) + ) + return linear_operator + + def _build_nonlinear_fun(self, derivative_operator: Array): + return GradientNormNonlinearFun( + num_spatial_dims=self.num_spatial_dims, + num_points=self.num_points, + num_channels=self.num_channels, + derivative_operator=derivative_operator, + dealiasing_fraction=self.dealiasing_fraction, + scale=self.normalized_gradient_norm_scale, + zero_mode_fix=True, + ) diff --git a/exponax/normalized_stepper/linear.py b/exponax/normalized_stepper/linear.py new file mode 100644 index 0000000..e39d804 --- /dev/null +++ b/exponax/normalized_stepper/linear.py @@ -0,0 +1,53 @@ +import jax.numpy as jnp + +from jax import Array + +from ..base_stepper import BaseStepper +from ..nonlinear_functions import ZeroNonlinearFun +from jaxtyping import Complex, Float, Array + + +class NormalizedLinearStepper(BaseStepper): + normalized_coefficients: list[float] + + def __init__( + self, + num_spatial_dims: int, + num_points: int, + *, + normalized_coefficients: list[float] = [0.0, -0.5, 0.01], + ): + """ + By default: advection-diffusion with normalized advection of 0.5, and + normalized diffusion of 0.01. + + Take care of the signs! 
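A brief sketch of the sign convention: the entries act as coefficients a_i on the i-th spatial derivative of the right-hand side, so an advection velocity of +0.5 enters as -0.5 on the first derivative, while diffusion keeps a positive sign on the second derivative. The domain extent and time step passed to the denormalization helper below are illustrative.

```python
from exponax.normalized_stepper import (
    NormalizedLinearStepper,
    denormalize_coefficients,
)

# Default: normalized advection of 0.5 (hence the minus sign on the first
# derivative) and normalized diffusion of 0.01 on the second derivative.
stepper = NormalizedLinearStepper(
    num_spatial_dims=1,
    num_points=128,
    normalized_coefficients=[0.0, -0.5, 0.01],
)

# Recover the physical coefficients for an (illustrative) extent and dt
physical_coefficients = denormalize_coefficients(
    3.0,   # domain_extent
    0.1,   # dt
    (0.0, -0.5, 0.01),
)
```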
+ """ + self.normalized_coefficients = normalized_coefficients + super().__init__( + num_spatial_dims=num_spatial_dims, + domain_extent=1.0, # Derivative operator is just scaled with 2 * jnp.pi + num_points=num_points, + dt=1.0, + num_channels=1, + order=0, + ) + + def _build_linear_operator(self, derivative_operator: Array) -> Array: + linear_operator = sum( + jnp.sum( + c * (derivative_operator) ** i, + axis=0, + keepdims=True, + ) + for i, c in enumerate(self.normalized_coefficients) + ) + return linear_operator + + def _build_nonlinear_fun(self, derivative_operator: Array): + return ZeroNonlinearFun( + num_spatial_dims=self.num_spatial_dims, + num_points=self.num_points, + num_channels=self.num_channels, + derivative_operator=derivative_operator, + ) diff --git a/exponax/normalized_stepper/utils.py b/exponax/normalized_stepper/utils.py new file mode 100644 index 0000000..f0e20e9 --- /dev/null +++ b/exponax/normalized_stepper/utils.py @@ -0,0 +1,76 @@ +import jax.numpy as jnp + + +def normalize_coefficients( + domain_extent: float, + dt: float, + coefficients: tuple[float], +) -> tuple[float]: + """ + Normalize the coefficients to a linear time stepper to be used with the + normalized linear stepper. + + **Arguments:** + - `domain_extent`: extent of the domain + - `dt`: time step + - `coefficients`: coefficients for the linear operator, `coefficients[i]` is + the coefficient for the `i`-th derivative + """ + normalized_coefficients = tuple( + c * dt / (domain_extent**i) for i, c in enumerate(coefficients) + ) + return normalized_coefficients + + +def denormalize_coefficients( + domain_extent: float, + dt: float, + normalized_coefficients: tuple[float], +) -> tuple[float]: + """ + Denormalize the coefficients as they were used in the normalized linear to + then be used again in a regular linear stepper. + + **Arguments:** + - `domain_extent`: extent of the domain + - `dt`: time step + - `normalized_coefficients`: coefficients for the linear operator, + `normalized_coefficients[i]` is the coefficient for the `i`-th + derivative + """ + coefficients = tuple( + c_n / dt * domain_extent**i for i, c_n in enumerate(normalized_coefficients) + ) + return coefficients + + +def normalize_convection_scale( + domain_extent: float, + convection_scale: float, +) -> float: + normalized_convection_scale = convection_scale / domain_extent + return normalized_convection_scale + + +def denormalize_convection_scale( + domain_extent: float, + normalized_convection_scale: float, +) -> float: + convection_scale = normalized_convection_scale * domain_extent + return convection_scale + + +def normalize_gradient_norm_scale( + domain_extent: float, + gradient_norm_scale: float, +): + normalized_gradient_norm_scale = gradient_norm_scale / jnp.square(domain_extent) + return normalized_gradient_norm_scale + + +def denormalize_gradient_norm_scale( + domain_extent: float, + normalized_gradient_norm_scale: float, +): + gradient_norm_scale = normalized_gradient_norm_scale * jnp.square(domain_extent) + return gradient_norm_scale diff --git a/exponax/poisson.py b/exponax/poisson.py new file mode 100644 index 0000000..e5a870f --- /dev/null +++ b/exponax/poisson.py @@ -0,0 +1,102 @@ +import jax +import jax.numpy as jnp +import equinox as eqx +from jaxtyping import Array, Float, Complex + +from .spectral import build_derivative_operator, build_laplace_operator, spatial_shape + + +class Poisson(eqx.Module): + num_spatial_dims: int + domain_extent: float + num_points: int + dx: float + + _inv_operator: Complex[Array, "1 ... 
N"] + + def __init__( + self, + num_spatial_dims: int, + domain_extent: float, + num_points: int, + *, + order=2, + ): + """ + Exactly solve the Poisson equation with periodic boundary conditions. + + This "stepper" is different from all other steppers in this package in + that it does not solve a time-dependent PDE. Instead, it solves the + Poisson equation + + $$ u_{xx} = - f $$ + + for a given right hand side $f$. + + It is included for completion. + + **Arguments:** + - `num_spatial_dims`: The number of spatial dimensions. + - `domain_extent`: The extent of the domain. + - `num_points`: The number of points in each spatial dimension. + - `order`: The order of the Poisson equation. Defaults to 2. You can + also set `order=4` for the biharmonic equation. + """ + self.num_spatial_dims = num_spatial_dims + self.domain_extent = domain_extent + self.num_points = num_points + + # Uses the convention that N does **not** include the right boundary + # point + self.dx = domain_extent / num_points + + derivative_operator = build_derivative_operator( + num_spatial_dims, domain_extent, num_points + ) + operator = build_laplace_operator(derivative_operator, order=order) + + # Uses mean zero solution + self._inv_operator = jnp.where(operator == 0, 0.0, 1 / operator) + + def step_fourier( + self, + f_hat: Complex[Array, "C ... (N//2)+1"], + ) -> Complex[Array, "C ... (N//2)+1"]: + """ + Solve the Poisson equation in Fourier space. + + **Arguments:** + - `f_hat`: The Fourier transform of the right hand side. + + **Returns:** + - `u_hat`: The Fourier transform of the solution. + """ + return -self._inv_operator * f_hat + + def step( + self, + f: Float[Array, "C ... N"], + ) -> Float[Array, "C ... N"]: + """ + Solve the Poisson equation in real space. + + **Arguments:** + - `f`: The right hand side. + + **Returns:** + - `u`: The solution. + """ + f_hat = jnp.fft.rfft(f) + u_hat = self.step_fourier(f_hat) + u = jnp.fft.irfft(u_hat, self.num_points) + return u + + def __call__( + self, + f: Float[Array, "C ... N"], + ) -> Float[Array, "C ... N"]: + if f.shape[1:] != spatial_shape(self.num_spatial_dims, self.num_points): + raise ValueError( + f"Shape of f[1:] is {f.shape[1:]} but should be {spatial_shape(self.num_spatial_dims, self.num_points)}" + ) + return self.step(f) diff --git a/exponax/repeated_stepper.py b/exponax/repeated_stepper.py new file mode 100644 index 0000000..f7567ce --- /dev/null +++ b/exponax/repeated_stepper.py @@ -0,0 +1,58 @@ +import equinox as eqx + +from .base_stepper import BaseStepper + +from .utils import repeat + +from jaxtyping import Array, Float, Complex + + +class RepeatedStepper(eqx.Module): + """ + Sugarcoat the utility function `repeat` in a callable PyTree for easy + composition with other equinox modules. + + One intended usage is to get "more accurate" or "more stable" time steppers + that perform substeps. + + The effective time step is `self.stepper.dt * self.n_sub_steps`. In order to + get a time step of X with Y substeps, first instantiate a stepper with a + time step of X/Y and then wrap it in a RepeatedStepper with n_sub_steps=Y. + + **Arguments:** + - `stepper`: The stepper to repeat. + - `n_sub_steps`: The number of substeps to perform. + """ + + stepper: BaseStepper + n_sub_steps: int + + def step( + self, + u: Float[Array, "C ... N"], + ) -> Float[Array, "C ... N"]: + """ + Step the PDE forward in time by self.n_sub_steps time steps given the + current state `u`. 
+ """ + return repeat(self.stepper.step, self.n_sub_steps)(u) + + def step_fourier( + self, + u_hat: Complex[Array, "C ... (N//2)+1"], + ) -> Complex[Array, "C ... (N//2)+1"]: + """ + Step the PDE forward in time by self.n_sub_steps time steps given the + current state `u_hat` in real-valued Fourier space. + """ + return repeat(self.stepper.step_fourier, self.n_sub_steps)(u_hat) + + def __call__( + self, + u: Float[Array, "C ... N"], + ) -> Float[Array, "C ... N"]: + """ + Step the PDE forward in time by self.n_sub_steps time steps given the + current state `u`. + """ + return repeat(self.stepper, self.n_sub_steps)(u) diff --git a/exponax/sample_stepper/__init__.py b/exponax/sample_stepper/__init__.py new file mode 100644 index 0000000..63f198d --- /dev/null +++ b/exponax/sample_stepper/__init__.py @@ -0,0 +1,32 @@ +from .burgers import Burgers +from .convection import GeneralConvectionStepper +from .gradient_norm import GeneralGradientNormStepper +from .korteveg_de_vries import KortevegDeVries +from .kuramoto_sivashinsky import ( + KuramotoSivashinsky, + KuramotoSivashinskyConservative, +) +from .linear import ( + Advection, + Diffusion, + AdvectionDiffusion, + Dispersion, + HyperDiffusion, + GeneralLinearStepper, +) +from .navier_stokes import ( + NavierStokesVorticity2d, + KolmogorovFlowVorticity2d, +) +from .nikolaevskiy import ( + Nikolaevskiy, + NikolaevskiyConservative, +) +from .reaction import ( + SwiftHohenberg, + GrayScott, + FisherKPP, + AllenCahn, + CahnHilliard, + BelousovZhabotinsky, +) diff --git a/exponax/sample_stepper/burgers.py b/exponax/sample_stepper/burgers.py new file mode 100644 index 0000000..c97a8f0 --- /dev/null +++ b/exponax/sample_stepper/burgers.py @@ -0,0 +1,63 @@ +import jax.numpy as jnp + +from jax import Array + +from ..base_stepper import BaseStepper +from ..nonlinear_functions import ConvectionNonlinearFun +from jaxtyping import Complex, Float, Array +from ..spectral import build_laplace_operator, build_gradient_inner_product_operator + + +class Burgers(BaseStepper): + diffusivity: float + convection_scale: float + dealiasing_fraction: float + + def __init__( + self, + num_spatial_dims: int, + domain_extent: float, + num_points: int, + dt: float, + *, + diffusivity: float = 0.1, + convection_scale: float = 0.5, + order=2, + dealiasing_fraction: float = 2 / 3, + n_circle_points: int = 16, + circle_radius: float = 1.0, + ): + self.diffusivity = diffusivity + self.convection_scale = convection_scale + self.dealiasing_fraction = dealiasing_fraction + super().__init__( + num_spatial_dims=num_spatial_dims, + domain_extent=domain_extent, + num_points=num_points, + dt=dt, + num_channels=num_spatial_dims, # Number of channels grows with dimension + order=order, + n_circle_points=n_circle_points, + circle_radius=circle_radius, + ) + + def _build_linear_operator( + self, + derivative_operator: Complex[Array, "D ... (N//2)+1"], + ) -> Complex[Array, "1 ... (N//2)+1"]: + + # The linear operator is the same for all D channels + return self.diffusivity * build_laplace_operator(derivative_operator) + + def _build_nonlinear_fun( + self, + derivative_operator: Complex[Array, "D ... 
(N//2)+1"], + ) -> ConvectionNonlinearFun: + return ConvectionNonlinearFun( + num_spatial_dims=self.num_spatial_dims, + num_points=self.num_points, + num_channels=self.num_channels, + derivative_operator=derivative_operator, + dealiasing_fraction=self.dealiasing_fraction, + convection_scale=self.convection_scale, + ) diff --git a/exponax/sample_stepper/convection.py b/exponax/sample_stepper/convection.py new file mode 100644 index 0000000..1a31bae --- /dev/null +++ b/exponax/sample_stepper/convection.py @@ -0,0 +1,75 @@ +import jax.numpy as jnp + +from jax import Array + +from ..base_stepper import BaseStepper +from ..nonlinear_functions import ConvectionNonlinearFun +from jaxtyping import Complex, Float, Array + + +class GeneralConvectionStepper(BaseStepper): + coefficients: list[float] + convection_scale: float + dealiasing_fraction: float + + def __init__( + self, + num_spatial_dims: int, + domain_extent: float, + num_points: int, + dt: float, + *, + coefficients: list[float] = [0.0, 0.0, 0.01], + convection_scale: float = 0.5, + order=2, + dealiasing_fraction: float = 2 / 3, + n_circle_points: int = 16, + circle_radius: float = 1.0, + ): + """ + Isotropic linear operators! + + By default Burgers equation with diffusivity of 0.01 + + """ + self.coefficients = coefficients + self.convection_scale = convection_scale + self.dealiasing_fraction = dealiasing_fraction + super().__init__( + num_spatial_dims=num_spatial_dims, + domain_extent=domain_extent, + num_points=num_points, + dt=dt, + num_channels=num_spatial_dims, + order=order, + n_circle_points=n_circle_points, + circle_radius=circle_radius, + ) + + def _build_linear_operator( + self, + derivative_operator: Complex[Array, "D ... (N//2)+1"], + ) -> Complex[Array, "1 ... (N//2)+1"]: + linear_operator = sum( + jnp.sum( + c * (derivative_operator) ** i, + axis=0, + keepdims=True, + ) + for i, c in enumerate(self.coefficients) + ) + return linear_operator + + def _build_nonlinear_fun( + self, + derivative_operator: Complex[Array, "D ... (N//2)+1"], + ) -> ConvectionNonlinearFun: + return ConvectionNonlinearFun( + num_spatial_dims=self.num_spatial_dims, + num_points=self.num_points, + num_channels=self.num_channels, + derivative_operator=derivative_operator, + dealiasing_fraction=self.dealiasing_fraction, + convection_scale=self.convection_scale, + zero_mode_fix=False, # Todo: check this + ) diff --git a/exponax/sample_stepper/gradient_norm.py b/exponax/sample_stepper/gradient_norm.py new file mode 100644 index 0000000..2952b39 --- /dev/null +++ b/exponax/sample_stepper/gradient_norm.py @@ -0,0 +1,75 @@ +import jax.numpy as jnp + +from jax import Array + +from ..base_stepper import BaseStepper +from ..nonlinear_functions import GradientNormNonlinearFun +from jaxtyping import Complex, Float, Array + + +class GeneralGradientNormStepper(BaseStepper): + coefficients: list[float] + gradient_norm_scale: float + dealiasing_fraction: float + + def __init__( + self, + num_spatial_dims: int, + domain_extent: float, + num_points: int, + dt: float, + *, + coefficients: list[float] = [0.0, 0.0, -1.0, 0.0, -1.0], + gradient_norm_scale: float = 0.5, + order=2, + dealiasing_fraction: float = 2 / 3, + n_circle_points: int = 16, + circle_radius: float = 1.0, + ): + """ + Isotropic linear operators! 
+ + By default KS equation (in combustion science format) + + """ + self.coefficients = coefficients + self.gradient_norm_scale = gradient_norm_scale + self.dealiasing_fraction = dealiasing_fraction + super().__init__( + num_spatial_dims=num_spatial_dims, + domain_extent=domain_extent, + num_points=num_points, + dt=dt, + num_channels=num_spatial_dims, + order=order, + n_circle_points=n_circle_points, + circle_radius=circle_radius, + ) + + def _build_linear_operator( + self, + derivative_operator: Complex[Array, "D ... (N//2)+1"], + ) -> Complex[Array, "1 ... (N//2)+1"]: + linear_operator = sum( + jnp.sum( + c * (derivative_operator) ** i, + axis=0, + keepdims=True, + ) + for i, c in enumerate(self.coefficients) + ) + return linear_operator + + def _build_nonlinear_fun( + self, + derivative_operator: Complex[Array, "D ... (N//2)+1"], + ) -> GradientNormNonlinearFun: + return GradientNormNonlinearFun( + num_spatial_dims=self.num_spatial_dims, + num_points=self.num_points, + num_channels=self.num_channels, + derivative_operator=derivative_operator, + dealiasing_fraction=self.dealiasing_fraction, + scale=self.gradient_norm_scale, + zero_mode_fix=False, # Todo: check this + ) diff --git a/exponax/sample_stepper/korteveg_de_vries.py b/exponax/sample_stepper/korteveg_de_vries.py new file mode 100644 index 0000000..24ef04b --- /dev/null +++ b/exponax/sample_stepper/korteveg_de_vries.py @@ -0,0 +1,86 @@ +from typing import Union + +import jax.numpy as jnp + +from jax import Array + +from ..base_stepper import BaseStepper +from ..nonlinear_functions import ConvectionNonlinearFun +from jaxtyping import Complex, Float, Array +from ..spectral import build_laplace_operator, build_gradient_inner_product_operator + + +class KortevegDeVries(BaseStepper): + convection_scale: float + pure_dispersivity: Float[Array, "D"] + advect_over_diffuse_dispersivity: Float[Array, "D"] + diffusivity: float + dealiasing_fraction: float + + def __init__( + self, + num_spatial_dims: int, + domain_extent: float, + num_points: int, + dt: float, + *, + convection_scale: float = -6 / 2, + pure_dispersivity: Union[Float[Array, "D"], float] = 1.0, + advect_over_diffuse_dispersivity: Union[Float[Array, "D"], float] = 0.0, + diffusivity: float = 0.0, + order: int = 2, + dealiasing_fraction: float = 2 / 3, + n_circle_points: int = 16, + circle_radius: float = 1.0, + ): + self.convection_scale = convection_scale + if isinstance(pure_dispersivity, float): + pure_dispersivity = jnp.ones(num_spatial_dims) * pure_dispersivity + if isinstance(advect_over_diffuse_dispersivity, float): + advect_over_diffuse_dispersivity = ( + jnp.ones(num_spatial_dims) * advect_over_diffuse_dispersivity + ) + self.pure_dispersivity = pure_dispersivity + self.advect_over_diffuse_dispersivity = advect_over_diffuse_dispersivity + self.diffusivity = diffusivity + self.dealiasing_fraction = dealiasing_fraction + super().__init__( + num_spatial_dims=num_spatial_dims, + domain_extent=domain_extent, + num_points=num_points, + dt=dt, + num_channels=num_spatial_dims, + order=order, + n_circle_points=n_circle_points, + circle_radius=circle_radius, + ) + + def _build_linear_operator( + self, + derivative_operator: Complex[Array, "D ... (N//2)+1"], + ) -> Complex[Array, "1 ... 
(N//2)+1"]: + laplace_operator = build_laplace_operator(derivative_operator, order=2) + linear_operator = ( + -build_gradient_inner_product_operator( + derivative_operator, self.pure_dispersivity, order=3 + ) + - build_gradient_inner_product_operator( + derivative_operator, self.advect_over_diffuse_dispersivity, order=1 + ) + * laplace_operator + + self.diffusivity * laplace_operator + ) + return linear_operator + + def _build_nonlinear_fun( + self, + derivative_operator: Complex[Array, "D ... (N//2)+1"], + ) -> ConvectionNonlinearFun: + return ConvectionNonlinearFun( + num_spatial_dims=self.num_spatial_dims, + num_points=self.num_points, + num_channels=self.num_channels, + derivative_operator=derivative_operator, + dealiasing_fraction=self.dealiasing_fraction, + convection_scale=self.convection_scale, + ) diff --git a/exponax/sample_stepper/kuramoto_sivashinsky.py b/exponax/sample_stepper/kuramoto_sivashinsky.py new file mode 100644 index 0000000..fe2c2ed --- /dev/null +++ b/exponax/sample_stepper/kuramoto_sivashinsky.py @@ -0,0 +1,140 @@ +import jax.numpy as jnp + +from jax import Array + +from ..base_stepper import BaseStepper +from ..nonlinear_functions import ( + GradientNormNonlinearFun, + ConvectionNonlinearFun, +) +from jaxtyping import Complex, Float, Array +from ..spectral import build_laplace_operator, build_gradient_inner_product_operator + + +class KuramotoSivashinsky(BaseStepper): + second_order_diffusivity: float + fourth_order_diffusivity: float + dealiasing_fraction: float + + def __init__( + self, + num_spatial_dims: int, + domain_extent: float, + num_points: int, + dt: float, + *, + second_order_diffusivity: float = 1.0, + fourth_order_diffusivity: float = 1.0, + dealiasing_fraction: float = 2 / 3, + order: int = 2, + n_circle_points: int = 16, + circle_radius: float = 1.0, + ): + """ + Implements the KS equations as used in the combustion community, i.e., + with a gradient-norm nonlinearity instead of the convection nonliearity. + The advantage is that the number of channels is always 1 no matter the + number of spatial dimensions. + """ + self.second_order_diffusivity = second_order_diffusivity + self.fourth_order_diffusivity = fourth_order_diffusivity + self.dealiasing_fraction = dealiasing_fraction + super().__init__( + num_spatial_dims=num_spatial_dims, + domain_extent=domain_extent, + num_points=num_points, + dt=dt, + num_channels=1, + order=order, + n_circle_points=n_circle_points, + circle_radius=circle_radius, + ) + + def _build_linear_operator( + self, + derivative_operator: Complex[Array, "D ... (N//2)+1"], + ) -> Complex[Array, "1 ... (N//2)+1"]: + linear_operator = -self.second_order_diffusivity * build_laplace_operator( + derivative_operator, order=2 + ) - self.fourth_order_diffusivity * build_laplace_operator( + derivative_operator, order=4 + ) + return linear_operator + + def _build_nonlinear_fun( + self, + derivative_operator: Complex[Array, "D ... 
(N//2)+1"], + ) -> GradientNormNonlinearFun: + return GradientNormNonlinearFun( + num_spatial_dims=self.num_spatial_dims, + num_points=self.num_points, + num_channels=self.num_channels, + derivative_operator=derivative_operator, + dealiasing_fraction=self.dealiasing_fraction, + zero_mode_fix=True, + scale=0.5, + ) + + +class KuramotoSivashinskyConservative(BaseStepper): + second_order_diffusivity: float + fourth_order_diffusivity: float + dealiasing_fraction: float + + def __init__( + self, + num_spatial_dims: int, + domain_extent: float, + num_points: int, + dt: float, + *, + second_order_diffusivity: float = 1.0, + fourth_order_diffusivity: float = 1.0, + dealiasing_fraction: float = 2 / 3, + order: int = 2, + n_circle_points: int = 16, + circle_radius: float = 1.0, + ): + """ + Using the fluid dynamics form of the KS equation (i.e. similar to the + Burgers equation). This also means that the number of channels grow with + the number of spatial dimensions. + """ + self.second_order_diffusivity = second_order_diffusivity + self.fourth_order_diffusivity = fourth_order_diffusivity + self.dealiasing_fraction = dealiasing_fraction + super().__init__( + num_spatial_dims=num_spatial_dims, + domain_extent=domain_extent, + num_points=num_points, + dt=dt, + num_channels=num_spatial_dims, + order=order, + n_circle_points=n_circle_points, + circle_radius=circle_radius, + ) + + def _build_linear_operator( + self, + derivative_operator: Complex[Array, "D ... (N//2)+1"], + ) -> Complex[Array, "1 ... (N//2)+1"]: + linear_operator = -self.second_order_diffusivity * build_laplace_operator( + derivative_operator, order=2 + ) - self.fourth_order_diffusivity * build_laplace_operator( + derivative_operator, order=4 + ) + return linear_operator + + def _build_nonlinear_fun( + self, + derivative_operator: Complex[Array, "D ... (N//2)+1"], + ) -> ConvectionNonlinearFun: + return ConvectionNonlinearFun( + num_spatial_dims=self.num_spatial_dims, + num_points=self.num_points, + num_channels=self.num_channels, + derivative_operator=derivative_operator, + dealiasing_fraction=self.dealiasing_fraction, + zero_mode_fix=True, + convection_scale=0.5, + ) diff --git a/exponax/sample_stepper/linear.py b/exponax/sample_stepper/linear.py new file mode 100644 index 0000000..7415cd1 --- /dev/null +++ b/exponax/sample_stepper/linear.py @@ -0,0 +1,314 @@ +from typing import Union + +import jax.numpy as jnp + +from jax import Array + +from ..base_stepper import BaseStepper +from ..nonlinear_functions import ZeroNonlinearFun +from jaxtyping import Complex, Float, Array +from ..spectral import build_laplace_operator, build_gradient_inner_product_operator + + +class Advection(BaseStepper): + velocity: Float[Array, "D"] + + def __init__( + self, + num_spatial_dims: int, + domain_extent: float, + num_points: int, + dt: float, + *, + velocity: Union[Float[Array, "D"], float] = 1.0, + ): + if isinstance(velocity, float): + velocity = jnp.ones(num_spatial_dims) * velocity + self.velocity = velocity + super().__init__( + num_spatial_dims=num_spatial_dims, + domain_extent=domain_extent, + num_points=num_points, + dt=dt, + num_channels=1, + order=0, + ) + + def _build_linear_operator( + self, + derivative_operator: Complex[Array, "D ... (N//2)+1"], + ) -> Complex[Array, "1 ... (N//2)+1"]: + # Requires minus to move term to the rhs + return -build_gradient_inner_product_operator( + derivative_operator, self.velocity, order=1 + ) + + def _build_nonlinear_fun( + self, + derivative_operator: Complex[Array, "D ... 
(N//2)+1"], + ) -> ZeroNonlinearFun: + return ZeroNonlinearFun( + num_spatial_dims=self.num_spatial_dims, + num_points=self.num_points, + num_channels=self.num_channels, + derivative_operator=derivative_operator, + dealiasing_fraction=1.0, + ) + + +class Diffusion(BaseStepper): + diffusivity: float + + def __init__( + self, + num_spatial_dims: int, + domain_extent: float, + num_points: int, + dt: float, + *, + diffusivity: float = 0.01, + ): + self.diffusivity = diffusivity + super().__init__( + num_spatial_dims=num_spatial_dims, + domain_extent=domain_extent, + num_points=num_points, + dt=dt, + num_channels=1, + order=0, + ) + + def _build_linear_operator( + self, + derivative_operator: Complex[Array, "D ... (N//2)+1"], + ) -> Complex[Array, "1 ... (N//2)+1"]: + return self.diffusivity * build_laplace_operator(derivative_operator) + + def _build_nonlinear_fun( + self, + derivative_operator: Complex[Array, "D ... (N//2)+1"], + ) -> ZeroNonlinearFun: + return ZeroNonlinearFun( + num_spatial_dims=self.num_spatial_dims, + num_points=self.num_points, + num_channels=self.num_channels, + derivative_operator=derivative_operator, + dealiasing_fraction=1.0, + ) + + +class AdvectionDiffusion(BaseStepper): + velocity: Float[Array, "D"] + diffusivity: float + + def __init__( + self, + num_spatial_dims: int, + domain_extent: float, + num_points: int, + dt: float, + *, + velocity: Union[Float[Array, "D"], float] = 1.0, + diffusivity: float = 0.01, + ): + if isinstance(velocity, float): + velocity = jnp.ones(num_spatial_dims) * velocity + self.velocity = velocity + self.diffusivity = diffusivity + super().__init__( + num_spatial_dims=num_spatial_dims, + domain_extent=domain_extent, + num_points=num_points, + dt=dt, + num_channels=1, + order=0, + ) + + def _build_linear_operator( + self, + derivative_operator: Complex[Array, "D ... (N//2)+1"], + ) -> Complex[Array, "1 ... (N//2)+1"]: + return -build_gradient_inner_product_operator( + derivative_operator, self.velocity, order=1 + ) + self.diffusivity * build_laplace_operator(derivative_operator) + + def _build_nonlinear_fun( + self, + derivative_operator: Complex[Array, "D ... (N//2)+1"], + ) -> ZeroNonlinearFun: + return ZeroNonlinearFun( + num_spatial_dims=self.num_spatial_dims, + num_points=self.num_points, + num_channels=self.num_channels, + derivative_operator=derivative_operator, + dealiasing_fraction=1.0, + ) + + +class Dispersion(BaseStepper): + dispersivity: Float[Array, "D"] + advect_on_diffusion: bool + + def __init__( + self, + num_spatial_dims: int, + domain_extent: float, + num_points: int, + dt: float, + *, + dispersivity: Union[Float[Array, "D"], float] = 1.0, + advect_on_diffusion: bool = False, + ): + if isinstance(dispersivity, float): + dispersivity = jnp.ones(num_spatial_dims) * dispersivity + self.dispersivity = dispersivity + self.advect_on_diffusion = advect_on_diffusion + super().__init__( + num_spatial_dims=num_spatial_dims, + domain_extent=domain_extent, + num_points=num_points, + dt=dt, + num_channels=1, + order=0, + ) + + def _build_linear_operator( + self, + derivative_operator: Complex[Array, "D ... (N//2)+1"], + ) -> Complex[Array, "1 ... 
(N//2)+1"]: + if self.advect_on_diffusion: + laplace_operator = build_laplace_operator(derivative_operator) + advection_operator = build_gradient_inner_product_operator( + derivative_operator, self.dispersivity, order=1 + ) + linear_operator = advection_operator * laplace_operator + else: + linear_operator = build_gradient_inner_product_operator( + derivative_operator, self.dispersivity, order=3 + ) + + return linear_operator + + def _build_nonlinear_fun( + self, + derivative_operator: Complex[Array, "D ... (N//2)+1"], + ) -> ZeroNonlinearFun: + return ZeroNonlinearFun( + num_spatial_dims=self.num_spatial_dims, + num_points=self.num_points, + num_channels=self.num_channels, + derivative_operator=derivative_operator, + dealiasing_fraction=1.0, + ) + + +class HyperDiffusion(BaseStepper): + hyper_diffusivity: float + diffuse_on_diffuse: bool + + def __init__( + self, + num_spatial_dims: int, + domain_extent: float, + num_points: int, + dt: float, + *, + hyper_diffusivity: float = 1.0, + diffuse_on_diffuse: bool = False, + ): + self.hyper_diffusivity = hyper_diffusivity + self.diffuse_on_diffuse = diffuse_on_diffuse + super().__init__( + num_spatial_dims=num_spatial_dims, + domain_extent=domain_extent, + num_points=num_points, + dt=dt, + num_channels=1, + order=0, + ) + + def _build_linear_operator( + self, + derivative_operator: Complex[Array, "D ... (N//2)+1"], + ) -> Complex[Array, "1 ... (N//2)+1"]: + # Use minus sign to have diffusion work in "correct direction" by default + if self.diffuse_on_diffuse: + laplace_operator = build_laplace_operator(derivative_operator) + linear_operator = ( + -self.hyper_diffusivity * laplace_operator * laplace_operator + ) + else: + linear_operator = -self.hyper_diffusivity * build_laplace_operator( + derivative_operator, order=4 + ) + + return linear_operator + + def _build_nonlinear_fun( + self, + derivative_operator: Complex[Array, "D ... (N//2)+1"], + ) -> ZeroNonlinearFun: + return ZeroNonlinearFun( + num_spatial_dims=self.num_spatial_dims, + num_points=self.num_points, + num_channels=self.num_channels, + derivative_operator=derivative_operator, + dealiasing_fraction=1.0, + ) + + +class GeneralLinearStepper(BaseStepper): + coefficients: list[float] + + def __init__( + self, + num_spatial_dims: int, + domain_extent: float, + num_points: int, + dt: float, + *, + coefficients: list[float] = [0.0, -0.1, 0.01], + ): + """ + Isotropic linear operators! + + By default: advection-diffusion with advection of 0.1 and diffusion of + 0.01. + + Take care of the signs! + """ + self.coefficients = coefficients + super().__init__( + num_spatial_dims=num_spatial_dims, + domain_extent=domain_extent, + num_points=num_points, + dt=dt, + num_channels=1, + order=0, + ) + + def _build_linear_operator( + self, + derivative_operator: Complex[Array, "D ... (N//2)+1"], + ) -> Complex[Array, "1 ... (N//2)+1"]: + linear_operator = sum( + jnp.sum( + c * (derivative_operator) ** i, + axis=0, + keepdims=True, + ) + for i, c in enumerate(self.coefficients) + ) + return linear_operator + + def _build_nonlinear_fun( + self, + derivative_operator: Complex[Array, "D ... 
(N//2)+1"], + ) -> ZeroNonlinearFun: + return ZeroNonlinearFun( + num_spatial_dims=self.num_spatial_dims, + num_points=self.num_points, + num_channels=self.num_channels, + derivative_operator=derivative_operator, + dealiasing_fraction=1.0, + ) diff --git a/exponax/sample_stepper/navier_stokes.py b/exponax/sample_stepper/navier_stokes.py new file mode 100644 index 0000000..48f360e --- /dev/null +++ b/exponax/sample_stepper/navier_stokes.py @@ -0,0 +1,127 @@ +import jax.numpy as jnp + +from jax import Array + +from ..base_stepper import BaseStepper +from ..nonlinear_functions import ( + VorticityConvection2d, + VorticityConvection2dKolmogorov, +) +from jaxtyping import Complex, Float, Array +from ..spectral import build_laplace_operator, build_gradient_inner_product_operator + + +class NavierStokesVorticity2d(BaseStepper): + diffusivity: float + drag: float + dealiasing_fraction: float + + def __init__( + self, + # Does not require D argument as it is fixed to 2 + domain_extent: float, + num_points: int, + dt: float, + *, + diffusivity: float = 0.01, + drag: float = 0.0, + order: int = 2, + dealiasing_fraction: float = 2 / 3, + n_circle_points: int = 16, + circle_radius: float = 1.0, + ): + self.diffusivity = diffusivity + self.drag = drag + self.dealiasing_fraction = dealiasing_fraction + super().__init__( + num_spatial_dims=2, + domain_extent=domain_extent, + num_points=num_points, + dt=dt, + num_channels=1, + order=order, + n_circle_points=n_circle_points, + circle_radius=circle_radius, + ) + + def _build_linear_operator( + self, + derivative_operator: Complex[Array, "D ... (N//2)+1"], + ) -> Complex[Array, "1 ... (N//2)+1"]: + return self.diffusivity * build_laplace_operator( + derivative_operator, order=2 + ) + self.drag * build_laplace_operator(derivative_operator, order=0) + + def _build_nonlinear_fun( + self, + derivative_operator: Complex[Array, "D ... (N//2)+1"], + ) -> VorticityConvection2d: + return VorticityConvection2d( + num_spatial_dims=self.num_spatial_dims, + num_points=self.num_points, + num_channels=self.num_channels, + derivative_operator=derivative_operator, + dealiasing_fraction=self.dealiasing_fraction, + ) + + +class KolmogorovFlowVorticity2d(BaseStepper): + diffusivity: float + drag: float + injection_mode: int + injection_scale: float + dealiasing_fraction: float + + def __init__( + self, + # Does not require D argument as it is fixed to 2 + domain_extent: float, + num_points: int, + dt: float, + *, + diffusivity: float = 0.001, + drag: float = -0.1, + injection_mode: int = 4, + injection_scale: float = 1.0, + order: int = 2, + dealiasing_fraction: float = 2 / 3, + n_circle_points: int = 16, + circle_radius: float = 1.0, + ): + self.diffusivity = diffusivity + self.drag = drag + self.injection_mode = injection_mode + self.injection_scale = injection_scale + self.dealiasing_fraction = dealiasing_fraction + super().__init__( + num_spatial_dims=2, + domain_extent=domain_extent, + num_points=num_points, + dt=dt, + num_channels=1, + order=order, + n_circle_points=n_circle_points, + circle_radius=circle_radius, + ) + + def _build_linear_operator( + self, + derivative_operator: Complex[Array, "D ... (N//2)+1"], + ) -> Complex[Array, "1 ... (N//2)+1"]: + return self.diffusivity * build_laplace_operator( + derivative_operator, order=2 + ) + self.drag * build_laplace_operator(derivative_operator, order=0) + + def _build_nonlinear_fun( + self, + derivative_operator: Complex[Array, "D ... 
(N//2)+1"], + ) -> VorticityConvection2dKolmogorov: + return VorticityConvection2dKolmogorov( + num_spatial_dims=self.num_spatial_dims, + num_points=self.num_points, + num_channels=self.num_channels, + injection_mode=self.injection_mode, + injection_scale=self.injection_scale, + derivative_operator=derivative_operator, + dealiasing_fraction=self.dealiasing_fraction, + ) diff --git a/exponax/sample_stepper/nikolaevskiy.py b/exponax/sample_stepper/nikolaevskiy.py new file mode 100644 index 0000000..a99ba17 --- /dev/null +++ b/exponax/sample_stepper/nikolaevskiy.py @@ -0,0 +1,141 @@ +import jax.numpy as jnp + +from jax import Array + +from ..base_stepper import BaseStepper +from ..nonlinear_functions import ( + GradientNormNonlinearFun, + ConvectionNonlinearFun, +) +from jaxtyping import Complex, Float, Array +from ..spectral import build_laplace_operator, build_gradient_inner_product_operator + + +class Nikolaevskiy(BaseStepper): + second_order_diffusivity: float + fourth_order_diffusivity: float + sixth_order_diffusivity: float + dealiasing_fraction: float + + def __init__( + self, + num_spatial_dims: int, + domain_extent: float, + num_points: int, + dt: float, + *, + second_order_diffusivity: float = 0.1, + fourth_order_diffusivity: float = 1.0, + sixth_order_diffusivity: float = 1.0, + dealiasing_fraction: float = 2 / 3, + order: int = 2, + n_circle_points: int = 16, + circle_radius: float = 1.0, + ): + self.second_order_diffusivity = second_order_diffusivity + self.fourth_order_diffusivity = fourth_order_diffusivity + self.sixth_order_diffusivity = sixth_order_diffusivity + self.dealiasing_fraction = dealiasing_fraction + super().__init__( + num_spatial_dims=num_spatial_dims, + domain_extent=domain_extent, + num_points=num_points, + dt=dt, + num_channels=1, + order=order, + n_circle_points=n_circle_points, + circle_radius=circle_radius, + ) + + def _build_linear_operator( + self, + derivative_operator: Complex[Array, "D ... (N//2)+1"], + ) -> Complex[Array, "1 ... (N//2)+1"]: + linear_operator = ( + self.second_order_diffusivity + * build_laplace_operator(derivative_operator, order=2) + + self.fourth_order_diffusivity + * build_laplace_operator(derivative_operator, order=4) + + self.sixth_order_diffusivity + * build_laplace_operator(derivative_operator, order=6) + ) + return linear_operator + + def _build_nonlinear_fun( + self, + derivative_operator: Complex[Array, "D ... 
(N//2)+1"], + ) -> GradientNormNonlinearFun: + return GradientNormNonlinearFun( + num_spatial_dims=self.num_spatial_dims, + num_points=self.num_points, + num_channels=self.num_channels, + derivative_operator=derivative_operator, + dealiasing_fraction=self.dealiasing_fraction, + zero_mode_fix=True, + scale=0.5, + ) + + +class NikolaevskiyConservative(BaseStepper): + second_order_diffusivity: float + fourth_order_diffusivity: float + sixth_order_diffusivity: float + dealiasing_fraction: float + + def __init__( + self, + num_spatial_dims: int, + domain_extent: float, + num_points: int, + dt: float, + *, + second_order_diffusivity: float = 0.1, + fourth_order_diffusivity: float = 1.0, + sixth_order_diffusivity: float = 1.0, + dealiasing_fraction: float = 2 / 3, + order: int = 2, + n_circle_points: int = 16, + circle_radius: float = 1.0, + ): + self.second_order_diffusivity = second_order_diffusivity + self.fourth_order_diffusivity = fourth_order_diffusivity + self.sixth_order_diffusivity = sixth_order_diffusivity + self.dealiasing_fraction = dealiasing_fraction + super().__init__( + num_spatial_dims=num_spatial_dims, + domain_extent=domain_extent, + num_points=num_points, + dt=dt, + num_channels=num_spatial_dims, + order=order, + n_circle_points=n_circle_points, + circle_radius=circle_radius, + ) + + def _build_linear_operator( + self, + derivative_operator: Complex[Array, "D ... (N//2)+1"], + ) -> Complex[Array, "1 ... (N//2)+1"]: + linear_operator = ( + self.second_order_diffusivity + * build_laplace_operator(derivative_operator, order=2) + + self.fourth_order_diffusivity + * build_laplace_operator(derivative_operator, order=4) + + self.sixth_order_diffusivity + * build_laplace_operator(derivative_operator, order=6) + ) + return linear_operator + + def _build_nonlinear_fun( + self, + derivative_operator: Complex[Array, "D ... (N//2)+1"], + ) -> ConvectionNonlinearFun: + return ConvectionNonlinearFun( + num_spatial_dims=self.num_spatial_dims, + num_points=self.num_points, + num_channels=self.num_channels, + derivative_operator=derivative_operator, + dealiasing_fraction=self.dealiasing_fraction, + zero_mode_fix=True, + convection_scale=0.5, + ) diff --git a/exponax/sample_stepper/reaction.py b/exponax/sample_stepper/reaction.py new file mode 100644 index 0000000..9ea8645 --- /dev/null +++ b/exponax/sample_stepper/reaction.py @@ -0,0 +1,350 @@ +import jax.numpy as jnp + +from jax import Array + +from ..base_stepper import BaseStepper +from ..nonlinear_functions import ( + PolynomialNonlinearFun, + GrayScottNonlinearFun, + CahnHilliardNonlinearFun, + BelousovZhabotinskyNonlinearFun, +) +from jaxtyping import Complex, Float, Array +from ..spectral import build_laplace_operator, build_gradient_inner_product_operator + + +class SwiftHohenberg(BaseStepper): + g: float + r: float + dealiasing_fraction: float + + def __init__( + self, + num_spatial_dims: int, + domain_extent: float, + num_points: int, + dt: float, + *, + g: float = 1.0, + r: float = 0.7, + order: int = 2, + dealiasing_fraction: float = 1 + / 2, # Needs lower value due to cubic nonlinearity + n_circle_points: int = 16, + circle_radius: float = 1.0, + ): + self.g = g + self.r = r + self.dealiasing_fraction = dealiasing_fraction + super().__init__( + num_spatial_dims=num_spatial_dims, + domain_extent=domain_extent, + num_points=num_points, + dt=dt, + num_channels=1, + order=order, + n_circle_points=n_circle_points, + circle_radius=circle_radius, + ) + + def _build_linear_operator( + self, + derivative_operator: Complex[Array, "D ... 
(N//2)+1"], + ) -> Complex[Array, "1 ... (N//2)+1"]: + laplace = build_laplace_operator(derivative_operator, order=2) + linear_operator = self.r - (1 + laplace) ** 2 + return linear_operator + + def _build_nonlinear_fun( + self, + derivative_operator: Complex[Array, "D ... (N//2)+1"], + ) -> PolynomialNonlinearFun: + return PolynomialNonlinearFun( + num_spatial_dims=self.num_spatial_dims, + num_points=self.num_points, + num_channels=self.num_channels, + derivative_operator=derivative_operator, + dealiasing_fraction=self.dealiasing_fraction, + coefficients=[0.0, 0.0, self.g, -1.0], + ) + + +class GrayScott(BaseStepper): + epsilon_1: float + epsilon_2: float + b: float + d: float + dealiasing_fraction: float + + def __init__( + self, + num_spatial_dims: int, + domain_extent: float, + num_points: int, + dt: float, + *, + epsilon_1: float = 0.00002, + epsilon_2: float = 0.00001, + b: float = 0.04, + d: float = 0.1, + order: int = 2, + dealiasing_fraction: float = 1 + / 2, # Needs lower value due to cubic nonlinearity + n_circle_points: int = 16, + circle_radius: float = 1.0, + ): + self.epsilon_1 = epsilon_1 + self.epsilon_2 = epsilon_2 + self.b = b + self.d = d + self.dealiasing_fraction = dealiasing_fraction + super().__init__( + num_spatial_dims=num_spatial_dims, + domain_extent=domain_extent, + num_points=num_points, + dt=dt, + num_channels=2, + order=order, + n_circle_points=n_circle_points, + circle_radius=circle_radius, + ) + + def _build_linear_operator( + self, + derivative_operator: Complex[Array, "D ... (N//2)+1"], + ) -> Complex[Array, "2 ... (N//2)+1"]: + laplace = build_laplace_operator(derivative_operator, order=2) + linear_operator = jnp.concatenate( + [ + self.epsilon_1 * laplace, + self.epsilon_2 * laplace, + ] + ) + return linear_operator + + def _build_nonlinear_fun( + self, + derivative_operator: Complex[Array, "D ... (N//2)+1"], + ) -> GrayScottNonlinearFun: + return GrayScottNonlinearFun( + num_spatial_dims=self.num_spatial_dims, + num_points=self.num_points, + num_channels=self.num_channels, + derivative_operator=derivative_operator, + b=self.b, + d=self.d, + dealiasing_fraction=self.dealiasing_fraction, + ) + + +### !!! Below models lack validation ### + + +class FisherKPP(BaseStepper): + dealiasing_fraction: float + + def __init__( + self, + num_spatial_dims: int, + domain_extent: float, + num_points: int, + dt: float, + *, + order: int = 2, + dealiasing_fraction: float = 2 / 3, + n_circle_points: int = 16, + circle_radius: float = 1.0, + ): + self.dealiasing_fraction = dealiasing_fraction + super().__init__( + num_spatial_dims=num_spatial_dims, + domain_extent=domain_extent, + num_points=num_points, + dt=dt, + num_channels=1, + order=order, + n_circle_points=n_circle_points, + circle_radius=circle_radius, + ) + + def _build_linear_operator( + self, + derivative_operator: Complex[Array, "D ... (N//2)+1"], + ) -> Complex[Array, "1 ... (N//2)+1"]: + laplace = build_laplace_operator(derivative_operator, order=2) + linear_operator = laplace + 1.0 + return linear_operator + + def _build_nonlinear_fun( + self, + derivative_operator: Complex[Array, "D ... 
(N//2)+1"], + ) -> PolynomialNonlinearFun: + return PolynomialNonlinearFun( + num_spatial_dims=self.num_spatial_dims, + num_points=self.num_points, + num_channels=self.num_channels, + derivative_operator=derivative_operator, + dealiasing_fraction=self.dealiasing_fraction, + coefficients=[0.0, 0.0, -1.0], + ) + + +class AllenCahn(BaseStepper): + dealiasing_fraction: float + + def __init__( + self, + num_spatial_dims: int, + domain_extent: float, + num_points: int, + dt: float, + *, + order: int = 2, + dealiasing_fraction: float = 1 + / 2, # Needs lower value due to cubic nonlinearity + n_circle_points: int = 16, + circle_radius: float = 1.0, + ): + self.dealiasing_fraction = dealiasing_fraction + super().__init__( + num_spatial_dims=num_spatial_dims, + domain_extent=domain_extent, + num_points=num_points, + dt=dt, + num_channels=1, + order=order, + n_circle_points=n_circle_points, + circle_radius=circle_radius, + ) + + def _build_linear_operator( + self, + derivative_operator: Complex[Array, "D ... (N//2)+1"], + ) -> Complex[Array, "1 ... (N//2)+1"]: + laplace = build_laplace_operator(derivative_operator, order=2) + linear_operator = laplace + 1.0 + return linear_operator + + def _build_nonlinear_fun( + self, + derivative_operator: Complex[Array, "D ... (N//2)+1"], + ) -> PolynomialNonlinearFun: + return PolynomialNonlinearFun( + num_spatial_dims=self.num_spatial_dims, + num_points=self.num_points, + num_channels=self.num_channels, + derivative_operator=derivative_operator, + dealiasing_fraction=self.dealiasing_fraction, + coefficients=[0.0, 0.0, 0.0, -1.0], + ) + + +class CahnHilliard(BaseStepper): + hyper_diffusivity: float + dealiasing_fraction: float + + def __init__( + self, + num_spatial_dims: int, + domain_extent: float, + num_points: int, + dt: float, + *, + hyper_diffusivity: float = 0.2, + order: int = 2, + dealiasing_fraction: float = 1 + / 2, # Needs lower value due to cubic nonlinearity + n_circle_points: int = 16, + circle_radius: float = 1.0, + ): + self.hyper_diffusivity = hyper_diffusivity + self.dealiasing_fraction = dealiasing_fraction + super().__init__( + num_spatial_dims=num_spatial_dims, + domain_extent=domain_extent, + num_points=num_points, + dt=dt, + num_channels=1, + order=order, + n_circle_points=n_circle_points, + circle_radius=circle_radius, + ) + + def _build_linear_operator( + self, + derivative_operator: Complex[Array, "D ... (N//2)+1"], + ) -> Complex[Array, "1 ... (N//2)+1"]: + laplace = build_laplace_operator(derivative_operator, order=2) + bi_laplace = build_laplace_operator(derivative_operator, order=4) + linear_operator = -self.hyper_diffusivity * bi_laplace - laplace + return linear_operator + + def _build_nonlinear_fun( + self, + derivative_operator: Complex[Array, "D ... 
(N//2)+1"], + ) -> CahnHilliardNonlinearFun: + return CahnHilliardNonlinearFun( + num_spatial_dims=self.num_spatial_dims, + num_points=self.num_points, + num_channels=self.num_channels, + derivative_operator=derivative_operator, + dealiasing_fraction=self.dealiasing_fraction, + ) + + +class BelousovZhabotinsky(BaseStepper): + diffusivities: list[float] + dealiasing_fraction: float + + def __init__( + self, + num_spatial_dims: int, + domain_extent: float, + num_points: int, + dt: float, + *, + diffusivities: list[float] = [1e-5, 2e-5, 1e-5], + order: int = 2, + dealiasing_fraction: float = 1 + / 2, # Needs lower value due to cubic nonlinearity + n_circle_points: int = 16, + circle_radius: float = 1.0, + ): + self.diffusivities = diffusivities + self.dealiasing_fraction = dealiasing_fraction + super().__init__( + num_spatial_dims=num_spatial_dims, + domain_extent=domain_extent, + num_points=num_points, + dt=dt, + num_channels=3, + order=order, + n_circle_points=n_circle_points, + circle_radius=circle_radius, + ) + + def _build_linear_operator( + self, + derivative_operator: Complex[Array, "D ... (N//2)+1"], + ) -> Complex[Array, "3 ... (N//2)+1"]: + laplace = build_laplace_operator(derivative_operator, order=2) + linear_operator = jnp.concatenate( + [ + self.diffusivities[0] * laplace, + self.diffusivities[1] * laplace, + self.diffusivities[2] * laplace, + ] + ) + return linear_operator + + def _build_nonlinear_fun( + self, + derivative_operator: Complex[Array, "D ... (N//2)+1"], + ) -> BelousovZhabotinskyNonlinearFun: + return BelousovZhabotinskyNonlinearFun( + num_spatial_dims=self.num_spatial_dims, + num_points=self.num_points, + num_channels=self.num_channels, + derivative_operator=derivative_operator, + dealiasing_fraction=self.dealiasing_fraction, + ) diff --git a/exponax/spectral.py b/exponax/spectral.py new file mode 100644 index 0000000..135b23a --- /dev/null +++ b/exponax/spectral.py @@ -0,0 +1,412 @@ +import jax.numpy as jnp +from jaxtyping import Array, Float, Complex, PyTree, PRNGKeyArray, Bool +from typing import Union + + +def build_wavenumbers( + num_spatial_dims: int, + num_points: int, + *, + indexing: str = "ij", +) -> Float[Array, "D ... (N//2)+1"]: + """ + Setup an array containing integer coordinates of wavenumbers associated with + a "num_spatial_dims"-dimensional rfft (real-valued FFT) + `jax.numpy.fft.rfftn`. + + **Arguments:** + - `num_spatial_dims`: The number of spatial dimensions. + - `num_points`: The number of points in each spatial dimension. + - `indexing`: The indexing scheme to use for `jax.numpy.meshgrid`. + Either `"ij"` or `"xy"`. Default is `"ij"`. + + **Returns:** + - `wavenumbers`: An array of wavenumber integer coordinates, shape + `(D, ..., (N//2)+1)`. + """ + right_most_wavenumbers = jnp.fft.rfftfreq(num_points, 1 / num_points) + other_wavenumbers = jnp.fft.fftfreq(num_points, 1 / num_points) + + wavenumber_list = [ + other_wavenumbers, + ] * (num_spatial_dims - 1) + [ + right_most_wavenumbers, + ] + + wavenumbers = jnp.stack( + jnp.meshgrid(*wavenumber_list, indexing=indexing), + ) + + return wavenumbers + + +def build_scaled_wavenumbers( + D: int, + L: float, + N: int, + *, + indexing: str = "ij", +) -> Float[Array, "D ... (N//2)+1"]: + """ + Setup an array containing scaled wavenumbers associated with a + "num_spatial_dims"-dimensional rfft (real-valued FFT) + `jax.numpy.fft.rfftn`. Scaling is done by `2 * pi / L`. + + **Arguments:** + - `D`: The number of spatial dimensions. + - `L`: The domain extent. 
+ - `N`: The number of points in each spatial dimension. + - `indexing`: The indexing scheme to use for `jax.numpy.meshgrid`. + Either `"ij"` or `"xy"`. Default is `"ij"`. + + **Returns:** + - `wavenumbers`: An array of wavenumber integer coordinates, shape + `(D, ..., (N//2)+1)`. + """ + scale = 2 * jnp.pi / L + wavenumbers = build_wavenumbers(D, N, indexing=indexing) + return scale * wavenumbers + + +def derivative( + field: Float[Array, "C ... N"], + domain_extent: float, + *, + order: int = 1, + indexing: str = "ij", +) -> Union[Float[Array, "C D ... (N//2)+1"], Float[Array, "D ... (N//2)+1"]]: + """ + Perform the spectral derivative of a field. In higher dimensions, this + defaults to the gradient (the collection of all partial derivatives). In 1d, + the resulting channel dimension holds the derivative. If the function is + called with an d-dimensional field which has 1 channel, the result will be a + d-dimensional field with d channels (one per partial derivative). If the + field originally had C channels, the result will be a matrix field with C + rows and d columns. + + Note that applying this operator twice will produce issues at the Nyquist if + the number of degrees of freedom N is even. For this, consider also using + the order option. + + **Arguments:** + - `field`: The field to differentiate, shape `(C, ..., N,)`. `C` can be + `1` for a scalar field or `D` for a vector field. + - `L`: The domain extent. + - `order`: The order of the derivative. Default is `1`. + - `indexing`: The indexing scheme to use for `jax.numpy.meshgrid`. + Either `"ij"` or `"xy"`. Default is `"ij"`. + + **Returns:** + - `field_der`: The derivative of the field, shape `(C, D, ..., + (N//2)+1)` or `(D, ..., (N//2)+1)`. + """ + channel_shape = field.shape[0] + spatial_shape = field.shape[1:] + D = len(spatial_shape) + N = spatial_shape[0] + derivative_operator = build_derivative_operator( + D, domain_extent, N, indexing=indexing + ) + ## I decided to not use this fix + + # # Required for even N, no effect for odd N + # derivative_operator_fixed = ( + # derivative_operator * nyquist_filter_mask(D, N) + # ) + derivative_operator_fixed = derivative_operator**order + + field_hat = jnp.fft.rfftn(field, axes=space_indices(D)) + if channel_shape == 1: + # Do not introduce another channel axis + field_der_hat = derivative_operator_fixed * field_hat + else: + # Create a "derivative axis" right after the channel axis + field_der_hat = field_hat[:, None] * derivative_operator_fixed[None, ...] + + field_der = jnp.fft.irfftn(field_der_hat, s=spatial_shape, axes=space_indices(D)) + + return field_der + + +def build_derivative_operator( + num_spatial_dims: int, + domain_extent: float, + num_points: int, + *, + indexing: str = "ij", +) -> Complex[Array, "D ... (N//2)+1"]: + """ + Setup the derivative operator in Fourier space. + + **Arguments:** + - `D`: The number of spatial dimensions. + - `L`: The domain extent. + - `N`: The number of points in each spatial dimension. + - `indexing`: The indexing scheme to use for `jax.numpy.meshgrid`. + Either `"ij"` or `"xy"`. Default is `"ij"`. + + **Returns:** + - `derivative_operator`: The derivative operator, shape `(D, ..., + N//2+1)`. + """ + return 1j * build_scaled_wavenumbers( + num_spatial_dims, domain_extent, num_points, indexing=indexing + ) + + +def build_laplace_operator( + derivative_operator: Complex[Array, "D ... (N//2)+1"], + *, + order: int = 2, +) -> Complex[Array, "1 ... 
(N//2)+1"]: + """ + Given the derivative operator of [`build_derivative_operator`], return the + Laplace operator. + + **Arguments:** + - `derivative_operator`: The derivative operator, shape `(D, ..., + N//2+1)`. + - `order`: The order of the Laplace operator. Default is `2`. + + **Returns:** + - `laplace_operator`: The Laplace operator, shape `(1, ..., N//2+1)`. + """ + if order % 2 != 0: + raise ValueError("Order must be even.") + + return jnp.sum(derivative_operator**order, axis=0, keepdims=True) + + +def build_gradient_inner_product_operator( + derivative_operator: Complex[Array, "D ... (N//2)+1"], + velocity: Float[Array, "D"], + *, + order: int = 1, +) -> Complex[Array, "1 ... (N//2)+1"]: + """ + Given the derivative operator of [`build_derivative_operator`] and a velocity + field, return the operator that computes the inner product of the gradient + with the velocity. + + **Arguments:** + - `derivative_operator`: The derivative operator, shape `(D, ..., + N//2+1)`. + - `velocity`: The velocity field, shape `(D,)`. + - `order`: The order of the gradient. Default is `1`. + + **Returns:** + - `operator`: The operator, shape `(1, ..., N//2+1)`. + """ + if order % 2 != 1: + raise ValueError("Order must be odd.") + + if velocity.shape != (derivative_operator.shape[0],): + raise ValueError( + f"Expected velocity shape to be {derivative_operator.shape[0]}, got {velocity.shape}." + ) + + # Need to move the channel/dimension axis last to enable autobroadcast over + # the arbitrary number of spatial axes, Then we can move this singleton axis + # back to the front + operator = jnp.swapaxes( + jnp.sum( + velocity + * jnp.swapaxes( + derivative_operator**order, + 0, + -1, + ), + axis=-1, + keepdims=True, + ), + 0, + -1, + ) + + return operator + + +def space_indices(num_spatial_dims: int) -> tuple[int, ...]: + """ + Returns the indices within a field array that correspond to the spatial + dimensions. + + **Arguments:** + - `num_spatial_dims`: The number of spatial dimensions. + + **Returns:** + - `indices`: The indices of the spatial dimensions. + """ + return tuple(range(-num_spatial_dims, 0)) + + +def spatial_shape(num_spatial_dims: int, num_points: int) -> tuple[int, ...]: + """ + Returns the shape of a spatial field array. + + **Arguments:** + - `num_spatial_dims`: The number of spatial dimensions. + - `num_points`: The number of points in each spatial dimension. + + **Returns:** + - `shape`: The shape of the spatial field array. + """ + return (num_points,) * num_spatial_dims + + +def wavenumber_shape(num_spatial_dims: int, num_points: int) -> tuple[int, ...]: + """ + Returns the spatial shape of a field in Fourier space (assuming the usage of + rfft, `jax.numpy.fft.rfftn`). + + **Arguments:** + - `num_spatial_dims`: The number of spatial dimensions. + - `num_points`: The number of points in each spatial dimension. + + **Returns:** + - `shape`: The shape of the spatial field array. + """ + return (num_points,) * (num_spatial_dims - 1) + (num_points // 2 + 1,) + + +def low_pass_filter_mask( + num_spatial_dims: int, + num_points: int, + *, + cutoff: int, + axis_separate: bool = True, + indexing: str = "ij", +) -> Bool[Array, "1 ... N"]: + """ + Create a low-pass filter mask in Fourier space. + + **Arguments:** + - `num_spatial_dims`: The number of spatial dimensions. + - `num_points`: The number of points in each spatial dimension. + - `cutoff`: The cutoff wavenumber. + - `axis_separate`: Whether to apply the cutoff to each axis separately. + Default is `True`. 
+        - `indexing`: The indexing scheme to use for `jax.numpy.meshgrid`.
+            Either `"ij"` or `"xy"`. Default is `"ij"`.
+
+    **Returns:**
+        - `mask`: The low-pass filter mask, shape `(1, ..., N//2+1)`.
+    """
+    wavenumbers = build_wavenumbers(num_spatial_dims, num_points, indexing=indexing)
+
+    if axis_separate:
+        mask = True
+        for wn_grid in wavenumbers:
+            mask = mask & (jnp.abs(wn_grid) <= cutoff)
+    else:
+        # Isotropic cutoff based on the Euclidean norm of the wavenumber vector
+        mask = jnp.linalg.norm(wavenumbers, axis=0) <= cutoff
+
+    mask = mask[jnp.newaxis, ...]
+
+    return mask
+
+
+def nyquist_filter_mask(
+    num_spatial_dims: int,
+    num_points: int,
+) -> Bool[Array, "1 ... N"]:
+    """
+    Creates a mask that, when multiplied with a field in Fourier space, removes
+    the Nyquist mode.
+
+    **Arguments:**
+        - `num_spatial_dims`: The number of spatial dimensions.
+        - `num_points`: The number of points in each spatial dimension.
+
+    **Returns:**
+        - `mask`: The Nyquist filter mask, shape `(1, ..., N//2+1)`.
+    """
+    if num_points % 2 == 1:
+        # Odd number of degrees of freedom (no issue with the Nyquist mode)
+        return jnp.ones(
+            (1, *wavenumber_shape(num_spatial_dims, num_points)), dtype=bool
+        )
+    else:
+        # Even number of degrees of freedom; the Nyquist mode only appears
+        # among the negative wavenumbers. This is problematic because the rfft
+        # in D >= 2 involves multiple regular FFTs after the rFFT.
+        nyquist_mode = num_points // 2 + 1
+        mode_below_nyquist = nyquist_mode - 1
+        return low_pass_filter_mask(
+            num_spatial_dims,
+            num_points,
+            cutoff=mode_below_nyquist - 1,
+            axis_separate=True,
+        )
+
+        # # Todo: Do we need the below?
+        # wavenumbers = build_wavenumbers(D, N, scaled=False)
+        # mask = True
+        # for wn_grid in wavenumbers:
+        #     mask = mask & (wn_grid != -mode_below_nyquist)
+        # return mask
+
+
+def build_scaling_array(
+    num_spatial_dims: int,
+    num_points: int,
+    *,
+    indexing: str = "ij",
+) -> Float[Array, "1 ... (N//2)+1"]:
+    """
+    Creates an array of the values that would be seen in the result of a
+    (real-valued) Fourier transform of a signal of amplitude 1.
+
+    **Arguments:**
+        - `num_spatial_dims`: The number of spatial dimensions.
+        - `num_points`: The number of points in each spatial dimension.
+        - `indexing`: The indexing scheme to use for `jax.numpy.meshgrid`.
+            Either `"ij"` or `"xy"`. Default is `"ij"`.
+
+    **Returns:**
+        - `scaling`: The scaling array, shape `(1, ..., N//2+1)`.
+ """ + right_most_wavenumbers = jnp.fft.rfftfreq(num_points, 1 / num_points) + other_wavenumbers = jnp.fft.fftfreq(num_points, 1 / num_points) + + right_most_scaling = jnp.where( + right_most_wavenumbers == 0, + num_points, + num_points / 2, + ) + other_scaling = jnp.where( + other_wavenumbers == 0, + num_points, + num_points / 2, + ) + + # If N is even, special treatment for the Nyquist mode + if num_points % 2 == 0: + # rfft has the Nyquist mode as positive wavenumber + right_most_scaling = jnp.where( + right_most_wavenumbers == num_points // 2, + num_points, + right_most_scaling, + ) + # standard fft has the Nyquist mode as negative wavenumber + other_scaling = jnp.where( + other_wavenumbers == -num_points // 2, + num_points, + other_scaling, + ) + + scaling_list = [ + other_scaling, + ] * (num_spatial_dims - 1) + [ + right_most_scaling, + ] + + scaling = jnp.prod( + jnp.stack( + jnp.meshgrid(*scaling_list, indexing=indexing), + ), + axis=0, + keepdims=True, + ) + + return scaling diff --git a/exponax/utils.py b/exponax/utils.py new file mode 100644 index 0000000..faab6fc --- /dev/null +++ b/exponax/utils.py @@ -0,0 +1,365 @@ +from typing import Union + +import jax +import jax.numpy as jnp +import jax.random as jr +import jax.tree_util as jtu +from jaxtyping import Array, Float, Complex, PyTree, PRNGKeyArray +from typing import Callable, Tuple + + +def get_grid( + num_spatial_dims: int, + domain_extent: float, + num_points: int, + *, + full: bool = False, + zero_centered: bool = False, + indexing: str = "ij", +) -> Float[Array, "D ... N"]: + """ + Return a grid in the spatial domain. A grid in d dimensions is an array of + shape (d,) + (num_points,)*d with the first axis representing all coordiate + inidices. + + Notice, that if `num_spatial_dims = 1`, the returned array has a singleton + dimension in the first axis, i.e., the shape is `(1, num_points)`. + + **Arguments:** + - `num_spatial_dims`: The number of spatial dimensions. + - `domain_extent`: The extent of the domain in each spatial dimension. + - `num_points`: The number of points in each spatial dimension. + - `full`: Whether to include the right boundary point in the grid. + Default: `False`. The right point is redundant for periodic boundary + conditions and is not considered a degree of freedom. Use this + option, for example, if you need a full grid for plotting. + - `zero_centered`: Whether to center the grid around zero. Default: + `False`. By default the grid considers a domain of (0, + domain_extent)^(num_spatial_dims). + - `indexing`: The indexing convention to use. Default: `'ij'`. + + **Returns:** + - `grid`: The grid in the spatial domain. Shape: `(num_spatial_dims, + ..., num_points)`. + """ + if full: + grid_1d = jnp.linspace(0, domain_extent, num_points + 1, endpoint=True) + else: + grid_1d = jnp.linspace(0, domain_extent, num_points, endpoint=False) + + if zero_centered: + grid_1d -= domain_extent / 2 + + grid_list = [ + grid_1d, + ] * num_spatial_dims + + grid = jnp.stack( + jnp.meshgrid(*grid_list, indexing=indexing), + ) + + return grid + + +def rollout( + stepper_fn: Union[Callable[[PyTree], PyTree], Callable[[PyTree, PyTree], PyTree]], + n: int, + *, + include_init: bool = False, + takes_aux: bool = False, + constant_aux: bool = True, +): + """ + Transform a stepper function into a function that autoregressively (i.e., + recursively applied to its own output) produces a trajectory of length `n`. 
+
+    Based on `takes_aux`, the stepper function is either fully autonomous, just
+    mapping state to state, or takes an additional auxiliary input. This can be
+    a force/control or additional metadata (like physical parameters, or time
+    for non-autonomous systems).
+
+    Args:
+        - `stepper_fn`: The time stepper to transform. If `takes_aux = False`
+            (default), expected signature is `u_next = stepper_fn(u)`, else
+            `u_next = stepper_fn(u, aux)`. `u` and `u_next` need to be PyTrees
+            of identical structure, in the easiest case just arrays of same
+            shape.
+        - `n`: The number of time steps to roll the trajectory out into the
+            future. If `include_init = False` (default), the result contains
+            the `n` future steps.
+        - `include_init`: Whether to include the initial condition in the
+            trajectory. If `True`, the arrays in the returning PyTree have shape
+            `(n + 1, ...)`, else `(n, ...)`. Default: `False`.
+        - `takes_aux`: Whether the stepper function takes an additional PyTree
+            as second argument.
+        - `constant_aux`: Whether the auxiliary input is constant over the
+            trajectory. If `True`, the auxiliary input is repeated `n` times,
+            otherwise the leading axis in the PyTree arrays has to be of length
+            `n`.
+
+    Returns:
+        - `rollout_stepper_fn`: A function that takes an initial condition `u_0`
+            and an auxiliary input `aux` (if `takes_aux = True`) and produces
+            the trajectory by autoregressively applying the stepper `n` times.
+            If `include_init = True`, the trajectory has shape `(n + 1, ...)`,
+            else `(n, ...)`. Returns a PyTree of the same structure as the
+            initial condition, but with an additional leading axis of length
+            `n`.
+    """
+
+    if takes_aux:
+
+        def scan_fn(u, aux):
+            u_next = stepper_fn(u, aux)
+            return u_next, u_next
+
+        def rollout_stepper_fn(u_0, aux):
+            if constant_aux:
+                aux = jtu.tree_map(
+                    lambda x: jnp.repeat(jnp.expand_dims(x, axis=0), n, axis=0), aux
+                )
+
+            _, trj = jax.lax.scan(scan_fn, u_0, aux, length=n)
+
+            if include_init:
+                trj_with_init = jtu.tree_map(
+                    lambda init, history: jnp.concatenate(
+                        [jnp.expand_dims(init, axis=0), history], axis=0
+                    ),
+                    u_0,
+                    trj,
+                )
+                return trj_with_init
+            else:
+                return trj
+
+        return rollout_stepper_fn
+
+    else:
+
+        def scan_fn(u, _):
+            u_next = stepper_fn(u)
+            return u_next, u_next
+
+        def rollout_stepper_fn(u_0):
+            _, trj = jax.lax.scan(scan_fn, u_0, None, length=n)
+
+            if include_init:
+                trj_with_init = jtu.tree_map(
+                    lambda init, history: jnp.concatenate(
+                        [jnp.expand_dims(init, axis=0), history], axis=0
+                    ),
+                    u_0,
+                    trj,
+                )
+                return trj_with_init
+            else:
+                return trj
+
+        return rollout_stepper_fn
+
+
+def repeat(
+    stepper_fn: Union[Callable[[PyTree], PyTree], Callable[[PyTree, PyTree], PyTree]],
+    n: int,
+    *,
+    takes_aux: bool = False,
+    constant_aux: bool = True,
+):
+    """
+    Transform a stepper function into a function that autoregressively (i.e.,
+    recursively applied to its own output) applies the stepper `n` times and
+    returns the final state.
+
+    Based on `takes_aux`, the stepper function is either fully autonomous, just
+    mapping state to state, or takes an additional auxiliary input. This can be
+    a force/control or additional metadata (like physical parameters, or time
+    for non-autonomous systems).
+
+    Args:
+        - `stepper_fn`: The time stepper to transform. If `takes_aux = False`
+            (default), expected signature is `u_next = stepper_fn(u)`, else
+            `u_next = stepper_fn(u, aux)`. `u` and `u_next` need to be PyTrees
+            of identical structure, in the easiest case just arrays of same
+            shape.
+        - `n`: The number of times to apply the stepper.
+        - `takes_aux`: Whether the stepper function takes an additional PyTree
+            as second argument.
+        - `constant_aux`: Whether the auxiliary input is constant over the
+            trajectory. If `True`, the auxiliary input is repeated `n` times,
+            otherwise the leading axis in the PyTree arrays has to be of length
+            `n`.
+
+    Returns:
+        - `repeated_stepper_fn`: A function that takes an initial condition
+            `u_0` and an auxiliary input `aux` (if `takes_aux = True`) and
+            produces the final state by autoregressively applying the stepper
+            `n` times. Returns a PyTree of the same structure as the initial
+            condition.
+    """
+
+    if takes_aux:
+
+        def scan_fn(u, aux):
+            u_next = stepper_fn(u, aux)
+            return u_next, None
+
+        def repeated_stepper_fn(u_0, aux):
+            if constant_aux:
+                aux = jtu.tree_map(
+                    lambda x: jnp.repeat(jnp.expand_dims(x, axis=0), n, axis=0), aux
+                )
+
+            final, _ = jax.lax.scan(scan_fn, u_0, aux, length=n)
+            return final
+
+        return repeated_stepper_fn
+
+    else:
+
+        def scan_fn(u, _):
+            u_next = stepper_fn(u)
+            return u_next, None
+
+        def repeated_stepper_fn(u_0):
+            final, _ = jax.lax.scan(scan_fn, u_0, None, length=n)
+            return final
+
+        return repeated_stepper_fn
+
+
+def stack_sub_trajectories(
+    trj: PyTree[Float[Array, "n_timesteps ..."]],
+    sub_len: int,
+) -> PyTree[Float[Array, "n_stacks sub_len ..."]]:
+    """
+    Slice a trajectory into subtrajectories of length `sub_len` and stack them
+    together. Useful for rollout training neural operators with temporal mixing.
+
+    !!! Note that this function can produce very large arrays.
+
+    **Arguments:**
+        - `trj`: The trajectory to slice. Expected shape: `(n_timesteps, ...)`.
+        - `sub_len`: The length of the subtrajectories. If you want to perform
+            rollout training with `k` steps, choose `sub_len = k + 1` to also
+            have an initial condition in the subtrajectories.
+
+    **Returns:**
+        - `sub_trjs`: The stacked subtrajectories. Expected shape: `(n_stacks,
+            sub_len, ...)`. `n_stacks` is the number of subtrajectories stacked
+            together, i.e., `n_timesteps - sub_len + 1`.
+    """
+    n_time_steps = [leaf.shape[0] for leaf in jtu.tree_leaves(trj)]
+
+    if len(set(n_time_steps)) != 1:
+        raise ValueError(
+            "All arrays in trj must have the same number of time steps in the leading axis"
+        )
+    else:
+        n_time_steps = n_time_steps[0]
+
+    if sub_len > n_time_steps:
+        raise ValueError(
+            "sub_len must be smaller than or equal to the number of time steps in trj"
+        )
+
+    n_sub_trjs = n_time_steps - sub_len + 1
+
+    sub_trjs = jtu.tree_map(
+        lambda trj: jnp.stack(
+            [trj[i : i + sub_len] for i in range(n_sub_trjs)], axis=0
+        ),
+        trj,
+    )
+
+    return sub_trjs
+
+
+import matplotlib.pyplot as plt
+from matplotlib.animation import FuncAnimation
+
+
+def get_animation(trj, *, vlim=(-1, 1)):
+    fig, ax = plt.subplots()
+    im = ax.imshow(
+        trj[0].squeeze().T, vmin=vlim[0], vmax=vlim[1], cmap="RdBu_r", origin="lower"
+    )
+    im.set_data(jnp.zeros_like(trj[0]).squeeze())
+
+    def animate(i):
+        im.set_data(trj[i].squeeze().T)
+        fig.suptitle(f"t_i = {i:04d}")
+        return im
+
+    plt.close(fig)
+
+    ani = FuncAnimation(fig, animate, frames=trj.shape[0], interval=100, blit=False)
+
+    return ani
+
+
+def get_grouped_animation(
+    trj, *, vlim=(-1, 1), grid=(3, 3), figsize=(10, 10), titles=None
+):
+    """
+    trj.shape = (n_trjs, n_timesteps, ...)
+ """ + fig, ax_s = plt.subplots(*grid, sharex=True, sharey=True, figsize=figsize) + im_s = [] + for i, ax in enumerate(ax_s.flatten()): + im = ax.imshow( + trj[i, 0].squeeze().T, + vmin=vlim[0], + vmax=vlim[1], + cmap="RdBu_r", + origin="lower", + ) + im.set_data(jnp.zeros_like(trj[i, 0]).squeeze()) + im_s.append(im) + + def animate(i): + for j, im in enumerate(im_s): + im.set_data(trj[j, i].squeeze().T) + if titles is not None: + ax_s.flatten()[j].set_title(titles[j]) + fig.suptitle(f"t_i = {i:04d}") + return im_s + + plt.close(fig) + + ani = FuncAnimation(fig, animate, frames=trj.shape[1], interval=100, blit=False) + + return ani + + +def build_ic_set( + ic_generator, + *, + num_points: int, + num_samples: int, + key: PRNGKeyArray, +) -> Float[Array, "S 1 ... N"]: + """ + Generate a set of initial conditions by sampling from a given initial + condition distribution and evaluating the function on the given grid. + + **Arguments:** + - `ic_generator`: A function that takes a PRNGKey and returns a + function that takes a grid and returns a sample from the initial + condition distribution. + - `num_samples`: The number of initial conditions to sample. + - `key`: The PRNGKey to use for sampling. + + **Returns:** + - `ic_set`: The set of initial conditions. Shape: `(S, 1, ..., N)`. + `S = num_samples`. + """ + + def scan_fn(k, _): + k, sub_k = jr.split(k) + ic = ic_generator(num_points, key=sub_k) + return k, ic + + _, ic_set = jax.lax.scan(scan_fn, key, None, length=num_samples) + + return ic_set diff --git a/ks_rollout.png b/ks_rollout.png new file mode 100644 index 0000000..516097e Binary files /dev/null and b/ks_rollout.png differ diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..95d6f51 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,21 @@ +[tool.black] +line-length = 88 +target-version = ['py38', 'py39', 'py310', 'py311', 'py312'] +include = '\.pyi?$' +exclude = ''' +( + /( + \.eggs + | \.git + | \.hg + | \.mypy_cache + | \.tox + | \.venv + | _build + | buck-out + | build + | dist + | docs + )/ +) +''' diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 0000000..52bdaf7 --- /dev/null +++ b/setup.cfg @@ -0,0 +1,8 @@ +[metadata] +description-file = README.md + +[flake8] +# for compatibility with black +max-line-length = 88 +select = C,E,F,W,B,B950 +extend-ignore = E203,E501,E731,W503,F722 diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..bfc1fe1 --- /dev/null +++ b/setup.py @@ -0,0 +1,8 @@ +from setuptools import setup + +setup( + name="exponax", + version="0.1.0", + description="JAX-based differentiable simulators for 1d periodic semi-linear PDEs based on exponential time differencing", + author="Felix Koehler", +) diff --git a/tests/test_builtin_solvers.py b/tests/test_builtin_solvers.py new file mode 100644 index 0000000..37f45b0 --- /dev/null +++ b/tests/test_builtin_solvers.py @@ -0,0 +1,131 @@ +import jax +import jax.numpy as jnp +import pytest +import exponax as ex + + +def test_instantiate(): + domain_extent = 10.0 + num_points = 25 + dt = 0.1 + + for num_spatial_dims in [1, 2, 3]: + for simulator in [ + ex.Advection, + ex.Diffusion, + ex.AdvectionDiffusion, + ex.Dispersion, + ex.HyperDiffusion, + ex.Burgers, + ex.KuramotoSivashinsky, + ex.KuramotoSivashinskyConservative, + ex.SwiftHohenberg, + ex.GrayScott, + ex.KortevegDeVries, + ex.FisherKPP, + ex.AllenCahn, + ex.CahnHilliard, + ]: + simulator(num_spatial_dims, domain_extent, num_points, dt) + + for simulator in [ + ex.NavierStokesVorticity2d, + 
ex.KolmogorovFlowVorticity2d, + ]: + simulator(domain_extent, num_points, dt) + + for num_spatial_dims in [1, 2, 3]: + ex.Poisson(num_spatial_dims, domain_extent, num_points) + + for num_spatial_dims in [1, 2, 3]: + for normalized_simulator in [ + ex.NormalizedLinearStepper, + ex.NormalizedConvectionStepper, + ex.NormalizedGradientNormStepper, + ]: + normalized_simulator(num_spatial_dims, num_points) + + +def test_linear_normalized_stepper(): + num_spatial_dims = 1 + domain_extent = 3.0 + num_points = 50 + dt = 0.1 + + u_0 = ex.RandomTruncatedFourierSeries( + num_spatial_dims, + domain_extent, + cutoff=5, + )(num_points, key=jax.random.PRNGKey(0)) + + for coefficients in ( + [ + 0.5, + ], # drag + [0.0, -0.3], # advection + [0.0, 0.0, 0.01], # diffusion + [0.0, -0.2, 0.01], # advection-diffusion + [0.0, 0.0, 0.0, 0.001], # dispersion + [0.0, 0.0, 0.0, 0.0, -0.0001], # hyperdiffusion + ): + regular_linear_stepper = ex.GeneralLinearStepper( + num_spatial_dims, + domain_extent, + num_points, + dt, + coefficients=coefficients, + ) + normalized_linear_stepper = ex.NormalizedLinearStepper( + num_spatial_dims, + num_points, + normalized_coefficients=ex.normalize_coefficients( + domain_extent, + dt, + coefficients, + ), + ) + + regular_linear_pred = regular_linear_stepper(u_0) + normalized_linear_pred = normalized_linear_stepper(u_0) + + assert regular_linear_pred == pytest.approx(normalized_linear_pred) + + +def test_nonlinear_normalized_stepper(): + num_spatial_dims = 1 + domain_extent = 3.0 + num_points = 50 + dt = 0.1 + diffusivity = 0.1 + convection_scale = 0.5 + + grid = ex.get_grid(num_spatial_dims, domain_extent, num_points) + u_0 = jnp.sin(2 * jnp.pi * grid / domain_extent) + 0.3 + + regular_burgers_stepper = ex.Burgers( + num_spatial_dims, + domain_extent, + num_points, + dt, + diffusivity=diffusivity, + convection_scale=convection_scale, + ) + normalized_burgers_stepper = ex.NormalizedConvectionStepper( + num_spatial_dims, + num_points, + dt=dt, + normalized_coefficients=ex.normalize_coefficients( + domain_extent, + dt, + [0.0, 0.0, diffusivity], + ), + normalized_convection_scale=ex.normalize_convection_scale( + domain_extent, + convection_scale, + ), + ) + + regular_burgers_pred = regular_burgers_stepper(u_0) + normalized_burgers_pred = normalized_burgers_stepper(u_0) + + assert regular_burgers_pred == pytest.approx(normalized_burgers_pred) diff --git a/tests/test_validation.py b/tests/test_validation.py new file mode 100644 index 0000000..e617cb9 --- /dev/null +++ b/tests/test_validation.py @@ -0,0 +1,66 @@ +import jax +import jax.numpy as jnp +import pytest +import exponax as ex + +### Linear steppers + +# linear steppers do not make spatial and temporal truncation errors, hence we +# can directly compare them with the analytical solution without performing a +# convergence study + + +def test_advection_1d(): + num_spatial_dims = 1 + domain_extent = 10.0 + num_points = 100 + dt = 0.1 + velocity = 0.1 + + analytical_solution = lambda t, x: jnp.sin( + 4 * 2 * jnp.pi * (x - velocity * t) / domain_extent + ) + + grid = ex.get_grid(num_spatial_dims, domain_extent, num_points) + u_0 = analytical_solution(0.0, grid) + u_1 = analytical_solution(dt, grid) + + stepper = ex.Advection( + num_spatial_dims, + domain_extent, + num_points, + dt, + velocity=velocity, + ) + + u_1_pred = stepper(u_0) + + assert u_1_pred == pytest.approx(u_1, rel=1e-4) + + +def test_diffusion_1d(): + num_spatial_dims = 1 + domain_extent = 10.0 + num_points = 100 + dt = 0.1 + diffusivity = 0.1 + + analytical_solution = 
lambda t, x: jnp.exp( + -4 * (2 * jnp.pi / domain_extent) ** 2 * diffusivity * t + ) * jnp.sin(4 * 2 * jnp.pi * x / domain_extent) + + grid = ex.get_grid(num_spatial_dims, domain_extent, num_points) + u_0 = analytical_solution(0.0, grid) + u_1 = analytical_solution(dt, grid) + + stepper = ex.Diffusion( + num_spatial_dims, + domain_extent, + num_points, + dt, + diffusivity=diffusivity, + ) + + u_1_pred = stepper(u_0) + + assert u_1_pred == pytest.approx(u_1, rel=1e-4)
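+
+
+# --- Illustrative sketch, not part of the original test suite ---
+# This additional check is an editorial example showing how the spectral
+# `derivative` utility defined in exponax/spectral.py (in this diff) can be
+# validated: for u(x) = sin(2*pi*k*x/L) on a periodic domain of extent L, the
+# spectral derivative should match the analytical derivative
+# (2*pi*k/L) * cos(2*pi*k*x/L) up to floating-point accuracy, since the mode k
+# is fully resolved. The import path `exponax.spectral` is assumed from the
+# file layout shown in this diff.
+def test_spectral_derivative_1d():
+    from exponax.spectral import derivative
+
+    domain_extent = 10.0
+    num_points = 100
+    wavenumber = 3
+
+    grid = ex.get_grid(1, domain_extent, num_points)
+    u = jnp.sin(2 * jnp.pi * wavenumber * grid / domain_extent)
+    du_dx_analytical = (2 * jnp.pi * wavenumber / domain_extent) * jnp.cos(
+        2 * jnp.pi * wavenumber * grid / domain_extent
+    )
+
+    # `derivative` expects a channel-leading field of shape (C, ..., N); the
+    # grid returned by `get_grid` already has the leading singleton axis.
+    du_dx_spectral = derivative(u, domain_extent)
+
+    # Absolute tolerance because the analytical derivative crosses zero.
+    assert du_dx_spectral == pytest.approx(du_dx_analytical, abs=1e-4)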