Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ax arg to plot method #44

Merged
merged 18 commits into from
Jul 24, 2024
Merged
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 36 additions & 27 deletions brainglobe_heatmap/heatmaps.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,11 +173,13 @@ def render(self, **kwargs) -> Scene:
def plot(
self,
show_legend: bool = False,
xlabel: str = "μm",
ylabel: str = "μm",
xlabel: str = "µm",
ylabel: str = "µm",
hide_axes: bool = False,
filename: Optional[str] = None,
cbar_label: Optional[str] = None,
ax: Optional[plt.Axes] = None,
show_cbar: bool = True,
**kwargs,
) -> plt.Figure:
"""
Expand All @@ -187,7 +189,11 @@ def plot(
self.regions_meshes, self.scene.root
)

f, ax = plt.subplots(figsize=(9, 9))
if ax is None:
f, ax = plt.subplots(figsize=(9, 9))
else:
f = plt.gcf()

for r, coords in projected.items():
name, segment = r.split("_segment_")
ax.fill(
Expand All @@ -201,30 +207,33 @@ def plot(
alpha=0.3 if name == "root" else None,
)

# make colorbar
divider = make_axes_locatable(ax)
cax = divider.append_axes("right", size="5%", pad=0.05)

# cmap = mpl.cm.cool
norm = mpl.colors.Normalize(vmin=self.vmin, vmax=self.vmax)
if self.label_regions is True:
cbar = f.colorbar(
mpl.cm.ScalarMappable(
norm=None,
cmap=mpl.cm.get_cmap(self.cmap, len(self.values)),
),
cax=cax,
)
else:
cbar = f.colorbar(
mpl.cm.ScalarMappable(norm=norm, cmap=self.cmap), cax=cax
)
if show_cbar:
# make colorbar
divider = make_axes_locatable(ax)
cax = divider.append_axes("right", size="5%", pad=0.05)

# cmap = mpl.cm.cool
norm = mpl.colors.Normalize(vmin=self.vmin, vmax=self.vmax)
if self.label_regions is True:
cbar = f.colorbar(
mpl.cm.ScalarMappable(
norm=None,
cmap=mpl.cm.get_cmap(self.cmap, len(self.values)),
),
cax=cax,
)
else:
cbar = f.colorbar(
mpl.cm.ScalarMappable(norm=norm, cmap=self.cmap), cax=cax
)

if cbar_label is not None:
cbar.set_label(cbar_label)
if cbar_label is not None:
cbar.set_label(cbar_label)

if self.label_regions is True:
cbar.ax.set_yticklabels([r.strip() for r in self.values.keys()])
if self.label_regions is True:
cbar.ax.set_yticklabels(
[r.strip() for r in self.values.keys()]
)

# style axes
ax.invert_yaxis()
Expand All @@ -250,6 +259,6 @@ def plot(

if show_legend:
ax.legend()
plt.show()
# plt.show()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This breaks backwards compatibility right?
Could this be made to still behave as before (unless you pass the new arguments), please?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sure, thanks for the quick answer.

Copy link
Contributor Author

@RobertoDF RobertoDF May 30, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

import matplotlib.pyplot as plt
import numpy as np
import brainglobe_heatmap as bgh

data_dict={
 'VISpm': 0.0,
 'VISp': 0.14285714285714285,
 'VISl': 0.2857142857142857,
 'VISli': 0.42857142857142855,
  'VISal': 0.5714285714285714,
 'VISrl': 0.7142857142857142,
 'SSp-bfd': 0.8571428571428571,
 'VISam': 1.0
}
fig, axs = plt.subplots(3,2, figsize=(18,12))
show_cbar = False

for ax, d in zip(axs.flatten(),range(7500,10500,500)):

    scene = bgh.Heatmap(
        data_dict,
        position=(d),
        orientation="frontal",  # or 'sagittal', or 'horizontal' or a tuple (x,y,z)
        thickness=10,
        format="2D",
        cmap="Set2",
        vmin=0,
        vmax=1,
        label_regions=True
    )
    scene.plot(ax=ax, show_cbar=show_cbar, hide_axes=True)

plt.tight_layout()
plt.show()

Does this work?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this work?

I can see the figure in a fresh conda environment with Python 3.10, but not in fresh conda environments with Python 3.11 or 3.12. In all cases, I get a segmentation fault (after closing the figure in 3.10).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am on Ubuntu 22.04

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Even more weird: if i define the plotting function outside jupyter (so I import it with from Utils.Utils import plot_brainrender_heatmaps) everything works perfectly without crashing.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey all! I ran into the same seg fault on Ubuntu. A small change in the example script fixed it for me:

import matplotlib.pyplot as plt
import numpy as np
import brainglobe_heatmap as bgh

data_dict={
 'VISpm': 0.0,
 'VISp': 0.14285714285714285,
 'VISl': 0.2857142857142857,
 'VISli': 0.42857142857142855,
  'VISal': 0.5714285714285714,
 'VISrl': 0.7142857142857142,
 'SSp-bfd': 0.8571428571428571,
 'VISam': 1.0
}
show_cbar = True

fig, axs = plt.subplots(3, 2, figsize=(18, 12))
scenes = []

for ax, d in zip(axs.flatten(),range(7500,10500,500)):
    scene = bgh.Heatmap(
        data_dict,
        position=(d),
        orientation="frontal",  # or 'sagittal', or 'horizontal' or a tuple (x,y,z)
        thickness=10,
        format="2D",
        cmap="Reds",
        vmin=0,
        vmax=1,
        label_regions=False
    )
    scene.plot_subplot(fig=fig, ax=ax, show_cbar=show_cbar, hide_axes=False)
    scenes.append(scene)

plt.tight_layout()
plt.show()

My best guess is that matplotlib does something lazily when plotting, and the garbage collector was being really eager and freeing the memory for each scene before plt.show() was called. Keeping a reference to each scene active fixed the seg fault for me!

I'm not sure if this is something we should address within the implementation or if it's best to publish this example script somewhere in the docs as the "proper" way of doing it.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you have a sense of how easy it would be to address this in the implementation?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this could be addressed by somehow changing Heatmap to allow storing multiple slicers internally, that way instead of this being done in a for loop you'd initialize one scene with multiple 2D positions. Then the plot() function can have internal logic to check if multiple slicers are present and provide a figure with a subplot for each slicer.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just realised I didn't actually answer your question! In terms of ease I'd say it wouldn't be too bad, just a bit more refactoring. The main thing I'd worry about is adding too many conditionals, or how to expose this to the user without overloading the position argument in the __init__ further!


return f
return ax
Loading