Skip to content

Commit

Permalink
Improved Viz module (including new wrapped volume renderer) (#5)
Browse files Browse the repository at this point in the history
* Lighweight wrap around volume rendering

* Add gamma correction

* Change default bg color to white

* Change to batch rendering by default

* Add routine for spatio-temporal of 2d states

* Change default to batch rendering

* Add single channel animation routine

* Faceting 3d states

* Facet 3d animation

* Implement spatio temporal facet

* Remove arguments not yet understood by vape

* Add documentation

* Forward changes in interface

* Remove ax arg

* Adapt return structure

* add Docs

* Add docs

* Adapt interfaces

* Change return structure

* Final change to plotting interface

* close figure facets

* Add arguments for compatibility

* Add dummy function

* Add docs

* Add dummy function

* Remove debug print

* Include temporal grid with state 2d animation

* Add docs remove ax

* Add missing time suptitle

* Add docs

* Allow changing cmap

* Start writing tests for viz routines

* Fix formatting
  • Loading branch information
Ceyron authored Jun 12, 2024
1 parent 783fc59 commit 1476cd3
Show file tree
Hide file tree
Showing 7 changed files with 989 additions and 24 deletions.
26 changes: 24 additions & 2 deletions exponax/viz/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,33 @@
can be animated over another axis (some notion of time).
"""

from ._animate import animate_spatio_temporal, animate_state_1d, animate_state_2d
from ._animate import (
animate_spatio_temporal,
animate_state_1d,
animate_state_2d,
animate_state_3d,
)
from ._animate_facet import (
animate_spatial_temporal_facet,
animate_state_1d_facet,
animate_state_2d_facet,
animate_state_3d_facet,
)
from ._plot import (
plot_spatio_temporal,
plot_spatio_temporal_2d,
plot_state_1d,
plot_state_2d,
plot_state_3d,
)
from ._plot import plot_spatio_temporal, plot_state_1d, plot_state_2d
from ._plot_facet import (
plot_spatio_temporal_2d_facet,
plot_spatio_temporal_facet,
plot_state_1d_facet,
plot_state_2d_facet,
plot_state_3d_facet,
)
from ._volume import volume_render_state_3d

# from IPython.display import HTML

Expand All @@ -45,4 +60,11 @@
"animate_state_2d_facet",
"animate_spatio_temporal",
"animate_spatial_temporal_facet",
"volume_render_state_3d",
"plot_state_3d",
"plot_spatio_temporal_2d",
"animate_state_3d",
"plot_state_3d_facet",
"animate_state_3d_facet",
"plot_spatio_temporal_2d_facet",
]
136 changes: 130 additions & 6 deletions exponax/viz/_animate.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
from typing import TypeVar
from typing import Literal, TypeVar, Union

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from jaxtyping import Array, Float
from matplotlib.animation import FuncAnimation

from .._utils import wrap_bc
from ._plot import plot_spatio_temporal, plot_state_1d, plot_state_2d
from ._volume import volume_render_state_3d, zigzag_alpha

N = TypeVar("N")

Expand Down Expand Up @@ -91,6 +94,7 @@ def animate_spatio_temporal(
trjs: Float[Array, "S T C N"],
*,
vlim: tuple[float, float] = (-1.0, 1.0),
cmap: str = "RdBu_r",
domain_extent: float = None,
dt: float = None,
include_init: bool = False,
Expand All @@ -114,6 +118,7 @@ def animate_spatio_temporal(
- `trjs`: The trajectory of states to animate. Must be a four-axis array
with shape `(n_timesteps_outer, n_time_steps, n_channels, n_spatial)`.
- `vlim`: The limits of the colorbar. Default is `(-1, 1)`.
- `cmap`: The colormap to use. Default is `"RdBu_r"`.
- `domain_extent`: The extent of the spatial domain. Default is `None`. This
affects the x-axis limits of the plot.
- `dt`: The time step between each frame. Default is `None`. If provided,
Expand All @@ -136,6 +141,7 @@ def animate_spatio_temporal(
plot_spatio_temporal(
trjs[0],
vlim=vlim,
cmap=cmap,
domain_extent=domain_extent,
dt=dt,
include_init=include_init,
Expand All @@ -148,6 +154,7 @@ def animate(i):
plot_spatio_temporal(
trjs[i],
vlim=vlim,
cmap=cmap,
domain_extent=domain_extent,
dt=dt,
include_init=include_init,
Expand All @@ -166,6 +173,7 @@ def animate_state_2d(
trj: Float[Array, "T 1 N N"],
*,
vlim: tuple[float, float] = (-1.0, 1.0),
cmap: str = "RdBu_r",
domain_extent: float = None,
dt: float = None,
include_init: bool = False,
Expand All @@ -186,6 +194,7 @@ def animate_state_2d(
- `trj`: The trajectory of states to animate. Must be a four-axis array with
shape `(n_timesteps, 1, n_spatial, n_spatial)`.
- `vlim`: The limits of the colorbar. Default is `(-1, 1)`.
- `cmap`: The colormap to use. Default is `"RdBu_r"`.
- `domain_extent`: The extent of the spatial domain. Default is `None`. This
affects the x- and y-axis limits of the plot.
- `dt`: The time step between each frame. Default is `None`. If provided,
Expand All @@ -205,31 +214,146 @@ def animate_state_2d(

fig, ax = plt.subplots()

if dt is not None:
time_range = (0, dt * trj.shape[0])
if not include_init:
time_range = (dt, time_range[1])
if include_init:
temporal_grid = jnp.arange(trj.shape[0])
else:
time_range = (0, trj.shape[0] - 1)
temporal_grid = jnp.arange(1, trj.shape[0] + 1)

if dt is not None:
temporal_grid *= dt

plot_state_2d(
trj[0],
vlim=vlim,
cmap=cmap,
domain_extent=domain_extent,
ax=ax,
)
ax.set_title(f"t = {temporal_grid[0]:.2f}")

def animate(i):
ax.clear()
plot_state_2d(
trj[i],
vlim=vlim,
cmap=cmap,
domain_extent=domain_extent,
ax=ax,
)
ax.set_title(f"t = {temporal_grid[i]:.2f}")

plt.close(fig)

ani = FuncAnimation(fig, animate, frames=trj.shape[0], interval=100, blit=False)

return ani


def animate_state_3d(
trj: Float[Array, "T 1 N N N"],
*,
vlim: tuple[float, float] = (-1.0, 1.0),
domain_extent: float = None,
dt: float = None,
include_init: bool = False,
bg_color: Union[
Literal["black"],
Literal["white"],
tuple[jnp.int8, jnp.int8, jnp.int8, jnp.int8],
] = "white",
resolution: int = 384,
cmap: str = "RdBu_r",
transfer_function: callable = zigzag_alpha,
distance_scale: float = 10.0,
gamma_correction: float = 2.4,
chunk_size: int = 64,
**kwargs,
):
"""
Animate a trajectory of 3d states as volume renderings.
Requires the input to be a five-axis array with a leading time axis, a
channel axis, and three spatial axes. Only the zeroth dimension in the
channel axis is plotted.
Periodic boundary conditions will be applied to the spatial axes (the state
is wrapped around).
**Arguments**:
- `trj`: The trajectory of states to animate. Must be a five-axis array with
shape `(n_timesteps, 1, n_spatial, n_spatial, n_spatial)`.
- `vlim`: The limits of the colorbar. Default is `(-1, 1)`.
- `domain_extent`: (Unused as of now)
- `dt`: The time step between each frame. Default is `None`. If provided,
a title will be displayed with the current time. If not provided, just
the frames are counted.
- `include_init`: Whether to the state starts at an initial condition (t=0)
or at the first frame in the trajectory. This affects is the the time
range is [0, (T-1)dt] or [dt, Tdt]. Default is `False`.
- `bg_color`: The background color. Either `"black"`, `"white"`, or a tuple
of RGBA values. Default is `"white"`.
- `resolution`: The resolution of the output image (affects render time).
Default is `384`.
- `cmap`: The colormap to use. Default is `"RdBu_r"`.
- `transfer_function`: The transfer function to use. Default is `zigzag_alpha`.
- `distance_scale`: The distance scale. Default is `10.0`.
- `gamma_correction`: The gamma correction. Default is `2.4`.
- `chunk_size`: The chunk size. Default is `64`.
**Returns**:
- `ani`: The animation object.
**Note:**
- This function requires the `vape` volume renderer package.
"""
if trj.ndim != 5:
raise ValueError("trj must be a five-axis array.")

fig, ax = plt.subplots()

if include_init:
temporal_grid = jnp.arange(trj.shape[0])
else:
temporal_grid = jnp.arange(1, trj.shape[0] + 1)

if dt is not None:
temporal_grid *= dt

trj_wrapped = jax.vmap(wrap_bc)(trj)
trj_wrapped_no_channel = trj_wrapped[:, 0]

imgs = volume_render_state_3d(
trj_wrapped_no_channel,
vlim=vlim,
bg_color=bg_color,
resolution=resolution,
cmap=cmap,
transfer_function=transfer_function,
distance_scale=distance_scale,
gamma_correction=gamma_correction,
chunk_size=chunk_size,
**kwargs,
)

ax.imshow(imgs[0])
ax.axis("off")
ax.set_title(f"t = {temporal_grid[0]:.2f}")

def animate(i):
ax.clear()
ax.imshow(imgs[i])
ax.axis("off")
ax.set_title(f"t = {temporal_grid[i]:.2f}")

ani = FuncAnimation(fig, animate, frames=trj.shape[0], interval=100, blit=False)

plt.close(fig)

return ani


def animate_spatio_temporal_2d():
raise NotImplementedError("This function is not yet implemented.")
Loading

0 comments on commit 1476cd3

Please sign in to comment.