Skip to content

Commit

Permalink
Improve plotting
Browse files Browse the repository at this point in the history
  • Loading branch information
Ceyron committed Mar 22, 2024
1 parent 230e34f commit 404ec22
Show file tree
Hide file tree
Showing 3 changed files with 124 additions and 15 deletions.
11 changes: 11 additions & 0 deletions exponax/viz/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,17 @@
You do not have to use them, as all states are pure jax arrays, plotting with
any library is straightforward.
Supported visualization methods:
- Display 1d states as line plots
- Display 1d trajectories as spatio-temporal image plots
- Display 2d states as image plots
All the methods also have a `facet` version, which allows you to plot multiple
states at once.
All plotting routines (three main routines and their three facet counterparts)
can be animated over another axis (some notion of time).
"""

from ._animate import animate_state_1d, animate_state_2d, animate_state_2d_facet
Expand Down
99 changes: 96 additions & 3 deletions exponax/viz/_animate.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
from typing import TypeVar
from typing import TypeVar, Union

import jax.numpy as jnp
import matplotlib.pyplot as plt
from jaxtyping import Array, Float
from matplotlib.animation import FuncAnimation

from ._plot import plot_state_1d
from ._plot import plot_spatio_temporal, plot_state_1d

N = TypeVar("N")


def animate_state_1d(
trj: Float[Array, "T B N"],
trj: Float[Array, "T C N"],
*,
vlim: tuple[float, float] = (-1, 1),
domain_extent: float = None,
Expand Down Expand Up @@ -56,6 +56,99 @@ def animate(i):
return ani


def animate_state_1d_facet(
trj: Float[Array, "T B C N"],
*,
vlim: tuple[float, float] = (-1.0, 1.0),
labels: list[str] = None,
titles: list[str] = None,
domain_extent: float = None,
grid: tuple[int, int] = (3, 3),
figsize: tuple[float, float] = (10, 10),
**kwargs,
):
if trj.ndim != 4:
raise ValueError("states must be a four-axis array.")

fig, ax_s = plt.subplots(*grid, figsize=figsize)

for i, ax in enumerate(ax_s.flatten()):
plot_state_1d(
trj[i],
vlim=vlim,
domain_extent=domain_extent,
labels=labels,
ax=ax,
**kwargs,
)
if titles is not None:
ax.set_title(titles[i])

return fig


def animate_spatio_temporal(
trjs: Float[Array, "S T C N"],
*,
vlim: tuple[float, float] = (-1.0, 1.0),
domain_extent: float = None,
dt: float = None,
include_init: bool = False,
**kwargs,
):
fig, ax = plt.subplots()

plot_spatio_temporal(
trjs[0],
vlim=vlim,
domain_extent=domain_extent,
dt=dt,
include_init=include_init,
ax=ax,
**kwargs,
)

def animate(i):
ax.clear()
plot_spatio_temporal(
trjs[i],
vlim=vlim,
domain_extent=domain_extent,
dt=dt,
include_init=include_init,
ax=ax,
**kwargs,
)

plt.close(fig)

ani = FuncAnimation(fig, animate, frames=trjs.shape[0], interval=100, blit=False)

return ani


def animate_spatial_temporal_facet(
trjs: Union[Float[Array, "S T C N"], Float[Array, "B S T 1 N"]],
*,
facet_over_channels: bool = True,
vlim: tuple[float, float] = (-1.0, 1.0),
domain_extent: float = None,
dt: float = None,
include_init: bool = False,
grid: tuple[int, int] = (3, 3),
figsize: tuple[float, float] = (10, 10),
**kwargs,
):
if facet_over_channels:
if trjs.ndim != 4:
raise ValueError("trjs must be a four-axis array.")
else:
if trjs.ndim != 5:
raise ValueError("states must be a five-axis array.")
# TODO
pass


def animate_state_2d(trj, *, vlim=(-1, 1)):
fig, ax = plt.subplots()
im = ax.imshow(
Expand Down
29 changes: 17 additions & 12 deletions exponax/viz/_plot.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import TypeVar
from typing import TypeVar, Union

import jax
import matplotlib.pyplot as plt
Expand Down Expand Up @@ -75,7 +75,7 @@ def plot_state_1d_facet(


def plot_spatio_temporal(
trj: Float[Array, "T N"],
trj: Float[Array, "T 1 N"],
*,
vlim: tuple[float, float] = (-1.0, 1.0),
ax=None,
Expand All @@ -84,10 +84,8 @@ def plot_spatio_temporal(
include_init: bool = False,
**kwargs,
):
if trj.ndim != 2:
raise ValueError(
"trj must be a two-axis array. Extract the channel you want to plot."
)
if trj.ndim != 3:
raise ValueError("trj must be a two-axis array.")

trj_wrapped = jax.vmap(wrap_bc)(trj)

Expand Down Expand Up @@ -124,8 +122,9 @@ def plot_spatio_temporal(


def plot_spatio_temporal_facet(
trjs: Float[Array, "B T N"],
trjs: Union[Float[Array, "T C N"], Float[Array, "B T 1 N"]],
*,
facet_over_channels: bool = True,
vlim: tuple[float, float] = (-1.0, 1.0),
grid: tuple[int, int] = (3, 3),
figsize: tuple[float, float] = (10, 10),
Expand All @@ -135,16 +134,22 @@ def plot_spatio_temporal_facet(
include_init: bool = False,
**kwargs,
):
if trjs.ndim != 3:
raise ValueError(
"trjs must be a three-axis array. Extract the channel you want to plot."
)
if facet_over_channels:
if trjs.ndim != 3:
raise ValueError("trjs must be a three-axis array.")
else:
if trjs.ndim != 4:
raise ValueError("trjs must be a four-axis array.")

fig, ax_s = plt.subplots(*grid, sharex=True, sharey=True, figsize=figsize)

for i, ax in enumerate(ax_s.flatten()):
if facet_over_channels:
single_trj = trjs[:, i : i + 1]
else:
single_trj = trjs[i]
plot_spatio_temporal(
trjs[i],
single_trj,
vlim=vlim,
ax=ax,
domain_extent=domain_extent,
Expand Down

0 comments on commit 404ec22

Please sign in to comment.