Skip to content

Commit

Permalink
Change the way KdV handles the dispersivity argument
Browse files Browse the repository at this point in the history
  • Loading branch information
Ceyron committed Mar 13, 2024
1 parent ee57af8 commit 8f38646
Showing 1 changed file with 23 additions and 22 deletions.
45 changes: 23 additions & 22 deletions exponax/stepper/_korteweg_de_vries.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import TypeVar, Union
from typing import TypeVar

import jax.numpy as jnp
from jaxtyping import Array, Complex, Float
from jaxtyping import Array, Complex

from .._base_stepper import BaseStepper
from .._spectral import build_gradient_inner_product_operator, build_laplace_operator
Expand All @@ -12,10 +12,10 @@

class KortewegDeVries(BaseStepper):
convection_scale: float
pure_dispersivity: Float[Array, "D"]
advect_over_diffuse_dispersivity: Float[Array, "D"]
dispersivity: float
diffusivity: float
dealiasing_fraction: float
advect_over_diffuse: bool
single_channel: bool

def __init__(
Expand All @@ -26,24 +26,18 @@ def __init__(
dt: float,
*,
convection_scale: float = -6.0,
pure_dispersivity: Union[Float[Array, "D"], float] = 1.0,
dispersivity: float = 1.0,
advect_over_diffuse: bool = False,
single_channel: bool = False,
advect_over_diffuse_dispersivity: Union[Float[Array, "D"], float] = 0.0,
diffusivity: float = 0.0,
order: int = 2,
dealiasing_fraction: float = 2 / 3,
num_circle_points: int = 16,
circle_radius: float = 1.0,
):
self.convection_scale = convection_scale
if isinstance(pure_dispersivity, float):
pure_dispersivity = jnp.ones(num_spatial_dims) * pure_dispersivity
if isinstance(advect_over_diffuse_dispersivity, float):
advect_over_diffuse_dispersivity = (
jnp.ones(num_spatial_dims) * advect_over_diffuse_dispersivity
)
self.pure_dispersivity = pure_dispersivity
self.advect_over_diffuse_dispersivity = advect_over_diffuse_dispersivity
self.dispersivity = dispersivity
self.advect_over_diffuse = advect_over_diffuse
self.diffusivity = diffusivity
self.single_channel = single_channel
self.dealiasing_fraction = dealiasing_fraction
Expand All @@ -69,17 +63,24 @@ def _build_linear_operator(
self,
derivative_operator: Complex[Array, "D ... (N//2)+1"],
) -> Complex[Array, "1 ... (N//2)+1"]:
dispersion_velocity = self.dispersivity * jnp.ones(self.num_spatial_dims)
laplace_operator = build_laplace_operator(derivative_operator, order=2)
linear_operator = (
-build_gradient_inner_product_operator(
derivative_operator, self.pure_dispersivity, order=3
if self.advect_over_diffuse:
linear_operator = (
-build_gradient_inner_product_operator(
derivative_operator, self.advect_over_diffuse_dispersivity, order=1
)
* laplace_operator
+ self.diffusivity * laplace_operator
)
- build_gradient_inner_product_operator(
derivative_operator, self.advect_over_diffuse_dispersivity, order=1
else:
linear_operator = (
-build_gradient_inner_product_operator(
derivative_operator, dispersion_velocity, order=3
)
+ self.diffusivity * laplace_operator
)
* laplace_operator
+ self.diffusivity * laplace_operator
)

return linear_operator

def _build_nonlinear_fun(
Expand Down

0 comments on commit 8f38646

Please sign in to comment.