diff --git a/exponax/viz/_animate.py b/exponax/viz/_animate.py index 06457be..62ae9b4 100644 --- a/exponax/viz/_animate.py +++ b/exponax/viz/_animate.py @@ -5,7 +5,7 @@ from jaxtyping import Array, Float from matplotlib.animation import FuncAnimation -from ._plot import plot_spatio_temporal, plot_state_1d +from ._plot import plot_spatio_temporal, plot_state_1d, plot_state_2d N = TypeVar("N") @@ -149,17 +149,39 @@ def animate_spatial_temporal_facet( pass -def animate_state_2d(trj, *, vlim=(-1, 1)): +def animate_state_2d( + trj: Float[Array, "T 1 N N"], + *, + vlim: tuple[float, float] = (-1.0, 1.0), + domain_extent: float = None, + dt: float = None, + include_init: bool = False, + **kwargs, +): fig, ax = plt.subplots() - im = ax.imshow( - trj[0].squeeze().T, vmin=vlim[0], vmax=vlim[1], cmap="RdBu_r", origin="lower" + + if dt is not None: + time_range = (0, dt * trj.shape[0]) + if not include_init: + time_range = (dt, time_range[1]) + else: + time_range = (0, trj.shape[0] - 1) + + plot_state_2d( + trj[0], + vlim=vlim, + domain_extent=domain_extent, + ax=ax, ) - im.set_data(jnp.zeros_like(trj[0]).squeeze()) def animate(i): - im.set_data(trj[i].squeeze().T) - fig.suptitle(f"t_i = {i:04d}") - return im + ax.clear() + plot_state_2d( + trj[i], + vlim=vlim, + domain_extent=domain_extent, + ax=ax, + ) plt.close(fig) @@ -169,31 +191,49 @@ def animate(i): def animate_state_2d_facet( - trj, *, vlim=(-1, 1), grid=(3, 3), figsize=(10, 10), titles=None + trj: Union[Float[Array, "T C N N"], Float[Array, "B T 1 N N"]], + *, + facet_over_channels: bool = True, + vlim: tuple[float, float] = (-1.0, 1.0), + grid: tuple[int, int] = (3, 3), + figsize: tuple[float, float] = (10, 10), + titles=None, ): """ trj.shape = (n_trjs, n_timesteps, ...) """ + if facet_over_channels: + if trj.ndim != 4: + raise ValueError("trj must be a four-axis array.") + else: + if trj.ndim != 5: + raise ValueError("trj must be a five-axis array.") + + if facet_over_channels: + trj = jnp.swapaxes(trj, 0, 1) + trj = trj[:, :, None] + fig, ax_s = plt.subplots(*grid, sharex=True, sharey=True, figsize=figsize) - im_s = [] - for i, ax in enumerate(ax_s.flatten()): - im = ax.imshow( - trj[i, 0].squeeze().T, - vmin=vlim[0], - vmax=vlim[1], - cmap="RdBu_r", - origin="lower", + + for j, ax in enumerate(ax_s.flatten()): + plot_state_2d( + trj[j, 0], + vlim=vlim, + ax=ax, ) - im.set_data(jnp.zeros_like(trj[i, 0]).squeeze()) - im_s.append(im) + if titles is not None: + ax.set_title(titles[j]) def animate(i): - for j, im in enumerate(im_s): - im.set_data(trj[j, i].squeeze().T) + for j, ax in enumerate(ax_s.flatten()): + ax.clear() + plot_state_2d( + trj[j, i], + vlim=vlim, + ax=ax, + ) if titles is not None: - ax_s.flatten()[j].set_title(titles[j]) - fig.suptitle(f"t_i = {i:04d}") - return im_s + ax.set_title(titles[j]) plt.close(fig)