Skip to content

Commit

Permalink
tweaked and regenerated plots
Browse files Browse the repository at this point in the history
  • Loading branch information
niksirbi committed Oct 30, 2024
1 parent eabf8b0 commit 98a0933
Show file tree
Hide file tree
Showing 5 changed files with 138 additions and 46 deletions.
145 changes: 111 additions & 34 deletions brainglobe_template_builder/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,74 +188,151 @@ def plot_slices_single_column(

def pad_with_zeros(
stack: np.ndarray, target: int = 512
) -> tuple[np.ndarray, list[int]]:
) -> tuple[np.ndarray, list[tuple[int, int]]]:
"""Pad the stack with zeros to reach the target size in all dimensions."""
pad_sizes = [(target - s) // 2 for s in stack.shape]
pad_sizes = []
for s in stack.shape:
delta = target - s
if delta < 0:
raise ValueError(
f"Cannot pad dimension of size {s} to a "
f"smaller target size {target}."
)
left_pad = delta // 2
right_pad = delta - left_pad
pad_sizes.append((left_pad, right_pad))

padded_stack = np.pad(
stack,
(
(pad_sizes[0], pad_sizes[0]),
(pad_sizes[1], pad_sizes[1]),
(pad_sizes[2], pad_sizes[2]),
),
pad_sizes,
mode="constant",
)
return padded_stack, pad_sizes


def plot_orthographic(
img: np.ndarray,
show_slices: list[int],
slice_label_offset: int = 0,
pad_sizes: list[int] | None = None,
show_slices: list[int] | None = None,
pad_sizes: list[tuple[int, int]] | None = None,
mip_attenuation: float = 0.01,
save_path: Path | None = None,
) -> tuple[plt.Figure, np.ndarray]:
"""Plot orthographic views of a 3D image,
including a maximum intensity projection (MIP)."""

sc = AnatomicalSpace("ASR")
sc = AnatomicalSpace("ASR", shape=img.shape)
if pad_sizes is None:
pad_sizes = [0, 0, 0]
max_size = max(img.shape)
pad_sizes = [(0, 0)] * img.ndim

# Default to the middle slice along each axis
if show_slices is None:
show_slices = [img.shape[i] // 2 for i in range(img.ndim)]

fig, axs = plt.subplots(1, 4, figsize=(14, 4))
sections = [s.capitalize() for s in sc.sections] + ["MIP"]
axis_labels = [*sc.axis_labels, sc.axis_labels[1]]
frames = [
img.take(slc + pad_sizes[i], axis=i)
for i, slc in enumerate(show_slices)
]
mip = np.max(img, axis=1)

v_axis = sc.get_axis_idx("vertical")
mip, _ = _compute_attenuated_mip(img, v_axis, mip_attenuation)
sections = [s.capitalize() for s in sc.sections] + ["Top view"]
axis_labels = [*sc.axis_labels, sc.axis_labels[v_axis]]
slice_idxs = [s + p[0] for s, p in zip(show_slices, pad_sizes)]
frames = [img.take(s, axis=i) for i, s in enumerate(slice_idxs)]
frames.append(mip)
slice_labels = [slc + slice_label_offset for slc in show_slices]
slice_texts = [f"Slice {slc}" for slc in slice_labels] + [""]

for j, (view, labels) in enumerate(zip(sections, axis_labels)):
ax = axs[j]
ax.imshow(
frames[j],
cmap="gray",
aspect="equal",
origin="upper",
vmin=np.percentile(img, 1),
vmax=np.percentile(img, 99.9),
)
ax.set_title(view)
ax.text(
max_size / 2,
max_size / 20,
slice_texts[j],
ha="center",
va="top",
color="w",
)
ax.set_ylabel(labels[0])
ax.set_xlabel(labels[1])
ax.set_xticks([])
ax.set_yticks([])
for spine in ax.spines.values():
spine.set_visible(False)

# Indicate location of orthogonal slices with lines
if j < 3:
h, v = sc.index_pairs[j]
ax.axhline(slice_idxs[h], color="r", linestyle="--", alpha=0.5)
ax.axvline(slice_idxs[v], color="r", linestyle="--", alpha=0.5)

ax = _clear_spines_and_ticks(ax)

fig.subplots_adjust(
left=0.025, right=0.975, top=0.95, bottom=0.05, wspace=0.1, hspace=0
)
if save_path:
save_figure(fig, save_path.parent, save_path.name.split(".")[0])
return fig, axs


def _compute_attenuated_mip(
img: np.ndarray, axis: int, attenuation_factor: float
) -> tuple[np.ndarray, str]:
"""Compute the maximum intensity projection (MIP) with attenuation.
If the image is zero-padded, attenuation is only applied within the
non-zero region along the specified axis.
Parameters
----------
img : np.ndarray
Image volume.
axis : int
Axis along which to compute the MIP.
attenuation_factor : float
Attenuation factor for the MIP. 0 means no attenuation.
Returns
-------
tuple[np.ndarray, str]
MIP image and label. The label is "MIP" if no attenuation is applied,
and "MIP (attenuated)" otherwise.
"""

mip_label = "MIP"

if attenuation_factor < 0:
raise ValueError("Attenuation factor must be non-negative.")

if attenuation_factor < 1e-6:
# If the factor is too small, skip attenuation
mip = np.max(img, axis=axis)
return mip, mip_label

# Find the non-zero bounding box along the specified axis
other_axes = tuple(i for i in range(img.ndim) if i != axis)
non_zero_mask = np.any(img != 0, axis=other_axes)
non_zero_indices = np.nonzero(non_zero_mask)[0]
start, end = non_zero_indices[0], non_zero_indices[-1] + 1

# Trim the image along the attenuation axis (get rid of zero-padding)
slices = [slice(None)] * img.ndim
slices[axis] = slice(start, end)
trimmed_img = img[tuple(slices)]

# Apply attenuation to the trimmed image
attenuation = np.exp(
-attenuation_factor * np.arange(trimmed_img.shape[axis])
)
attenuation_shape = [1] * trimmed_img.ndim
attenuation_shape[axis] = trimmed_img.shape[axis]
attenuation = attenuation.reshape(attenuation_shape)
attenuated_img = trimmed_img.astype(np.float32) * attenuation

# Compute and return the attenuated MIP
mip = np.max(attenuated_img, axis=axis)
mip_label += " (attenuated)"

return mip, mip_label


def _clear_spines_and_ticks(ax: plt.Axes) -> plt.Axes:
"""Clear spines and ticks from a matplotlib axis."""
ax.set_xticks([])
ax.set_yticks([])
for spine in ax.spines.values():
spine.set_visible(False)
return ax
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,11 @@ num_iterations: 4

# Slice indices are for NiFTI images as output by ANTs
# Order: sagittal, coronal, axial
show_slices: [256, 184, 252] # coronal, transverse, sagittal
# The offset is necessary to translate the above indices
# to the packaged reference space, becasue of the padding
# added during post-processing.
# This offset is only used for printing slide indices in plots
slice_label_offset: 6
show_slices: [256, 184, 244] # coronal, transverse, sagittal

vmin_percentile: 1
vmax_percentile: 99.9
mip_attenuation: 0.01
animation_fps: [1, 2, 4, 8]
use4template_dir_suffix: "orig-asr_N4_aligned_padded_use4template"
example_subjects: ["sub-BC41o", "sub-BC63", "sub-BC71"]
17 changes: 17 additions & 0 deletions examples/plots/config_50um.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
atlas_forge_dir: "/media/ceph-niu/neuroinformatics/atlas-forge"
species: "BlackCap"
template_name: "template_sym_res-50um_n-18_average-trimean"
resolution_um: 50
transform_types: ["rigid", "similarity", "affine", "nlin"]
num_iterations: 4

# Slice indices are for NiFTI images as output by ANTs
# Order: sagittal, coronal, axial
show_slices: [128, 92, 122] # coronal, transverse, sagittal

vmin_percentile: 1
vmax_percentile: 99.9
mip_attenuation: 0.02
animation_fps: [1, 2, 4, 8]
use4template_dir_suffix: "orig-asr_N4_aligned_padded_use4template"
example_subjects: ["sub-BC41o", "sub-BC63", "sub-BC71"]
12 changes: 7 additions & 5 deletions examples/plots/template_and_individual.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
# Load matplotlib parameters (to allow for proper font export)
plt.style.use(current_dir / "plots.mplstyle")
# Load config file containing template building parameters
config = load_config(current_dir / "config.yaml")
config = load_config(current_dir / "config_25um.yaml")

# Setup directories based on config file
atlas_dir, template_dir, plots_dir = setup_directories(config)
Expand Down Expand Up @@ -97,9 +97,9 @@
# Plot the final template in orthographic view
fig, axs = plot_orthographic(
template_img,
config["show_slices"],
slice_label_offset=config["slice_label_offset"],
show_slices=config["show_slices"],
pad_sizes=pad_sizes,
mip_attenuation=config["mip_attenuation"],
save_path=plots_dir / "final_template_orthographic",
)
print("Plotted final template in orthographic view")
Expand All @@ -112,9 +112,11 @@

fig, axs = plot_orthographic(
subject_img,
config["show_slices"],
slice_label_offset=config["slice_label_offset"],
show_slices=config["show_slices"],
pad_sizes=pad_sizes,
mip_attenuation=config["mip_attenuation"],
save_path=plots_dir / f"{example_subject}_orthographic",
)
print("Plotted example subjects in orthographic view")

# %%
2 changes: 1 addition & 1 deletion examples/plots/template_building_progress.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
# Load matplotlib parameters (to allow for proper font export)
plt.style.use(current_dir / "plots.mplstyle")
# Load config file containing template building parameters
config = load_config(current_dir / "config.yaml")
config = load_config(current_dir / "config_25um.yaml")

# Setup directories based on config file
atlas_dir, template_dir, plots_dir = setup_directories(config)
Expand Down

0 comments on commit 98a0933

Please sign in to comment.