Skip to content

Commit

Permalink
Merge pull request #44 from RobertoDF/ax-argument-to-heatmap
Browse files Browse the repository at this point in the history
Add ax arg to plot method
  • Loading branch information
IgorTatarnikov committed Jul 24, 2024
2 parents 9409304 + e8610a2 commit 743a26c
Show file tree
Hide file tree
Showing 2 changed files with 179 additions and 32 deletions.
171 changes: 139 additions & 32 deletions brainglobe_heatmap/heatmaps.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ class Heatmap:
def __init__(
self,
values: Dict,
position: Union[list, tuple, np.ndarray],
position: Union[list, tuple, np.ndarray, float],
orientation: Union[str, tuple] = "frontal",
hemisphere: str = "both",
title: Optional[str] = None,
Expand Down Expand Up @@ -184,21 +184,129 @@ def render(self, camera=None) -> Scene:
def plot(
self,
show_legend: bool = False,
xlabel: str = "μm",
ylabel: str = "μm",
xlabel: str = "µm",
ylabel: str = "µm",
hide_axes: bool = False,
filename: Optional[str] = None,
cbar_label: Optional[str] = None,
show_cbar: bool = True,
**kwargs,
) -> plt.Figure:
"""
Plots the heatmap in 2D using matplotlib
Plots the heatmap in 2D using matplotlib.
This method generates a 2D visualization of the heatmap data in
a standalone matplotlib figure.
Parameters
----------
show_legend : bool, optional
If True, displays a legend for the plotted regions.
Default is False.
xlabel : str, optional
Label for the x-axis. Default is "µm".
ylabel : str, optional
Label for the y-axis. Default is "µm".
hide_axes : bool, optional
If True, hides the axes for a cleaner look. Default is False.
filename : Optional[str], optional
Path to save the figure to. If None, the figure is not saved.
Default is None.
cbar_label : Optional[str], optional
Label for the colorbar. If None, no label is displayed.
Default is None.
show_cbar : bool, optional
If True, displays a colorbar alongside the subplot.
Default is True.
**kwargs : dict
Additional keyword arguments passed to the plotting function.
Returns
-------
plt.Figure
The matplotlib figure object for the plot.
Notes
-----
This method is used to generate a standalone plot of
the heatmap data.
"""

f, ax = plt.subplots(figsize=(9, 9))

f, ax = self.plot_subplot(
fig=f,
ax=ax,
show_legend=show_legend,
xlabel=xlabel,
ylabel=ylabel,
hide_axes=hide_axes,
cbar_label=cbar_label,
show_cbar=show_cbar,
**kwargs,
)

if filename is not None:
plt.savefig(filename, dpi=300)

plt.show()
return f

def plot_subplot(
self,
fig: plt.Figure,
ax: plt.Axes,
show_legend: bool = False,
xlabel: str = "µm",
ylabel: str = "µm",
hide_axes: bool = False,
cbar_label: Optional[str] = None,
show_cbar: bool = True,
**kwargs,
) -> Tuple[plt.Figure, plt.Axes]:
"""
Plots a heatmap in a subplot within a given figure and axes.
This method is responsible for plotting a single subplot within a
larger figure, allowing for the creation of complex multi-plot
visualizations.
Parameters
----------
fig : plt.Figure, optional
The figure object in which the subplot is plotted.
ax : plt.Axes, optional
The axes object in which the subplot is plotted.
show_legend : bool, optional
If True, displays a legend for the plotted regions.
Default is False.
xlabel : str, optional
Label for the x-axis. Default is "µm".
ylabel : str, optional
Label for the y-axis. Default is "µm".
hide_axes : bool, optional
If True, hides the axes for a cleaner look. Default is False.
cbar_label : Optional[str], optional
Label for the colorbar. If None, no label is displayed.
Default is None.
show_cbar : bool, optional
Display a colorbar alongside the subplot. Default is True.
**kwargs : dict
Additional keyword arguments passed to the plotting function.
Returns
-------
plt.Figure, plt.Axes
A tuple containing the figure and axes objects used for the plot.
Notes
-----
This method modifies the provided figure and axes objects in-place.
"""
projected, _ = self.slicer.get_structures_slice_coords(
self.regions_meshes, self.scene.root
)

f, ax = plt.subplots(figsize=(9, 9))
for r, coords in projected.items():
name, segment = r.split("_segment_")
ax.fill(
Expand All @@ -212,30 +320,33 @@ def plot(
alpha=0.3 if name == "root" else None,
)

# make colorbar
divider = make_axes_locatable(ax)
cax = divider.append_axes("right", size="5%", pad=0.05)

# cmap = mpl.cm.cool
norm = mpl.colors.Normalize(vmin=self.vmin, vmax=self.vmax)
if self.label_regions is True:
cbar = f.colorbar(
mpl.cm.ScalarMappable(
norm=None,
cmap=mpl.cm.get_cmap(self.cmap, len(self.values)),
),
cax=cax,
)
else:
cbar = f.colorbar(
mpl.cm.ScalarMappable(norm=norm, cmap=self.cmap), cax=cax
)
if show_cbar:
# make colorbar
divider = make_axes_locatable(ax)
cax = divider.append_axes("right", size="5%", pad=0.05)

# cmap = mpl.cm.cool
norm = mpl.colors.Normalize(vmin=self.vmin, vmax=self.vmax)
if self.label_regions is True:
cbar = fig.colorbar(
mpl.cm.ScalarMappable(
norm=None,
cmap=mpl.cm.get_cmap(self.cmap, len(self.values)),
),
cax=cax,
)
else:
cbar = fig.colorbar(
mpl.cm.ScalarMappable(norm=norm, cmap=self.cmap), cax=cax
)

if cbar_label is not None:
cbar.set_label(cbar_label)
if cbar_label is not None:
cbar.set_label(cbar_label)

if self.label_regions is True:
cbar.ax.set_yticklabels([r.strip() for r in self.values.keys()])
if self.label_regions is True:
cbar.ax.set_yticklabels(
[r.strip() for r in self.values.keys()]
)

# style axes
ax.invert_yaxis()
Expand All @@ -256,11 +367,7 @@ def plot(
ax.set_yticks([])
ax.set(xlabel="", ylabel="")

if filename is not None:
plt.savefig(filename, dpi=300)

if show_legend:
ax.legend()
plt.show()

return f
return fig, ax
40 changes: 40 additions & 0 deletions examples/heatmap_2d_subplots.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import matplotlib.pyplot as plt

import brainglobe_heatmap as bgh

data_dict = {
"VISpm": 0.0,
"VISp": 0.14285714285714285,
"VISl": 0.2857142857142857,
"VISli": 0.42857142857142855,
"VISal": 0.5714285714285714,
"VISrl": 0.7142857142857142,
"SSp-bfd": 0.8571428571428571,
"VISam": 1.0,
}

# Create a list of scenes to plot
# Note: it's important to keep reference to the scenes to avoid a
# segmentation fault
scenes = []
for distance in range(7500, 10500, 500):
scene = bgh.Heatmap(
data_dict,
position=distance,
orientation="frontal",
thickness=10,
format="2D",
cmap="Reds",
vmin=0,
vmax=1,
label_regions=False,
)
scenes.append(scene)

# Create a figure with 6 subplots and plot the scenes
fig, axs = plt.subplots(3, 2, figsize=(18, 12))
for scene, ax in zip(scenes, axs.flatten(), strict=False):
scene.plot_subplot(fig=fig, ax=ax, show_cbar=True, hide_axes=False)

plt.tight_layout()
plt.show()

0 comments on commit 743a26c

Please sign in to comment.