Skip to content

Commit

Permalink
Facet 3d animation
Browse files Browse the repository at this point in the history
  • Loading branch information
Ceyron committed Jun 11, 2024
1 parent 2213c1c commit 2966ec1
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 1 deletion.
2 changes: 2 additions & 0 deletions exponax/viz/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
animate_spatial_temporal_facet,
animate_state_1d_facet,
animate_state_2d_facet,
animate_state_3d_facet,
)
from ._plot import (
plot_spatio_temporal,
Expand Down Expand Up @@ -63,4 +64,5 @@
"plot_spatio_temporal_2d",
"animate_state_3d",
"plot_state_3d_facet",
"animate_state_3d_facet",
]
102 changes: 101 additions & 1 deletion exponax/viz/_animate_facet.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
from typing import TypeVar, Union
from typing import Literal, TypeVar, Union

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from jaxtyping import Array, Float
from matplotlib.animation import FuncAnimation

from .._utils import wrap_bc
from ._plot import plot_state_1d, plot_state_2d
from ._volume import volume_render_state_3d, zigzag_alpha

N = TypeVar("N")

Expand Down Expand Up @@ -246,3 +249,100 @@ def animate(i):
ani = FuncAnimation(fig, animate, frames=trj.shape[1], interval=100, blit=False)

return ani


def animate_state_3d_facet(
trj: Union[Float[Array, "T C N N N"], Float[Array, "B T 1 N N N"]],
*,
facet_over_channels: bool = True,
vlim: tuple[float, float] = (-1.0, 1.0),
grid: tuple[int, int] = (3, 3),
figsize: tuple[float, float] = (10, 10),
titles=None,
dt: float = None,
include_init: bool = False,
domain_extent: float = None,
bg_color: Union[
Literal["black"],
Literal["white"],
tuple[jnp.int8, jnp.int8, jnp.int8, jnp.int8],
] = "white",
resolution: int = 384,
cmap: str = "RdBu_r",
transfer_function: callable = zigzag_alpha,
distance_scale: float = 10.0,
gamma_correction: float = 2.4,
chunk_size: int = 64,
**kwargs,
):
if facet_over_channels:
if trj.ndim != 5:
raise ValueError("trj must be a five-axis array.")
else:
if trj.ndim != 6:
raise ValueError("trj must be a six-axis array.")

if facet_over_channels:
trj = jnp.swapaxes(trj, 0, 1)
trj = trj[:, :, None]

trj_wrapped = jax.vmap(jax.vmap(wrap_bc))(trj)

imgs = []
for facet_entry_trj in trj_wrapped:
facet_entry_trj_no_channel = facet_entry_trj[:, 0]
imgs.append(
volume_render_state_3d(
facet_entry_trj_no_channel,
vlim=vlim,
domain_extent=domain_extent,
bg_color=bg_color,
resolution=resolution,
cmap=cmap,
transfer_function=transfer_function,
distance_scale=distance_scale,
gamma_correction=gamma_correction,
chunk_size=chunk_size,
**kwargs,
)
)

# shape = (B, T, resolution, resolution, 3)
imgs = jnp.stack(imgs)

if include_init:
temporal_grid = jnp.arange(trj.shape[1])
else:
temporal_grid = jnp.arange(1, trj.shape[1] + 1)

if dt is not None:
temporal_grid *= dt

fig, ax_s = plt.subplots(*grid, figsize=figsize)

# num_subplots = trj.shape[0]

for j, ax in enumerate(ax_s.flatten()):
ax.imshow(imgs[j, 0])
ax.axis("off")
# if j >= num_subplots:
# ax.remove()
# else:
if titles is not None:
ax.set_title(titles[j])
title = fig.suptitle(f"t = {temporal_grid[0]:.2f}")

def animate(i):
for j, ax in enumerate(ax_s.flatten()):
ax.clear()
ax.imshow(imgs[j, i])
ax.axis("off")
if titles is not None:
ax.set_title(titles[j])
title.set_text(f"t = {temporal_grid[i]:.2f}")

ani = FuncAnimation(fig, animate, frames=trj.shape[1], interval=100, blit=False)

plt.close(fig)

return ani

0 comments on commit 2966ec1

Please sign in to comment.