diff --git a/iblatlas/plots.py b/iblatlas/plots.py index 92738b0..44bb2f0 100644 --- a/iblatlas/plots.py +++ b/iblatlas/plots.py @@ -828,7 +828,8 @@ def plot_scalar_on_barplot(acronyms, values, errors=None, order=True, ax=None, b def plot_swanson_vector(acronyms=None, values=None, ax=None, hemisphere=None, br=None, orientation='landscape', empty_color='silver', vmin=None, vmax=None, cmap='viridis', annotate=False, annotate_n=10, - annotate_order='top', annotate_list=None, mask=None, mask_color='w', fontsize=10, **kwargs): + annotate_order='top', annotate_list=None, mask=None, mask_color='w', fontsize=10, + show_cbar=False, extend='neither', **kwargs): """ Function to plot scalar value per allen region on the swanson projection. Plots on a vecortised version of the swanson projection @@ -854,8 +855,12 @@ def plot_swanson_vector(acronyms=None, values=None, ax=None, hemisphere=None, br Minimum value to restrict the colormap vmax: float Maximum value to restrict the colormap - cmap: string - matplotlib named colormap to use + cmap: string or matplotlib.colors.Colormap + matplotlib colormap to use + show_cbar: bool, default=False + Whether to display a colorbar. + extend: str, default='neither' + Which side of the colorbar to extend. See `colorbar` documentation. annotate : bool, default=False If true, labels the regions with acronyms. annotate_n: int @@ -886,6 +891,8 @@ def plot_swanson_vector(acronyms=None, values=None, ax=None, hemisphere=None, br if ax is None: fig, ax = plt.subplots() ax.set_axis_off() + else: + fig = ax.get_figure() if hemisphere != 'both' and acronyms is not None and not isinstance(acronyms[0], str): # If negative atlas ids are passed in and we are not going to lateralise (e.g hemisphere='both') @@ -894,11 +901,22 @@ def plot_swanson_vector(acronyms=None, values=None, ax=None, hemisphere=None, br if acronyms is not None: ibr, vals = br.propagate_down(acronyms, values) - colormap = matplotlib.colormaps.get_cmap(cmap) - vmin = vmin or np.nanmin(vals) - vmax = vmax or np.nanmax(vals) + + if isinstance(cmap, matplotlib.colors.Colormap): + colormap = cmap + elif isinstance(cmap, str): + colormap = matplotlib.colormaps.get_cmap(cmap) + else: + raise ValueError("`cmap` option must be of type `str` or `matplotlib.colors.Colormap`") + + vmin = vmin if vmin is not None else np.nanmin(vals) + vmax = vmax if vmax is not None else np.nanmax(vals) norm = colors.Normalize(vmin=vmin, vmax=vmax) rgba_color = colormap(norm(vals), bytes=True) + if show_cbar: + fig.colorbar(cm.ScalarMappable(norm=norm, cmap=cmap), + ax=ax, orientation='vertical', extend=extend, + ) if mask is not None: imr, _ = br.propagate_down(mask, np.ones_like(mask))