Skip to content

Commit

Permalink
Merge pull request #733 from eugenioLR/master
Browse files Browse the repository at this point in the history
Automatic aspect ratio calculation
  • Loading branch information
SarthakJariwala authored Sep 3, 2024
2 parents 0a9c87e + 761f4e1 commit a3fe0ef
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 10 deletions.
37 changes: 27 additions & 10 deletions src/seaborn_image/_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ class ImageGrid:
Number of columns to display. Defaults to None.
height : int or float, optional
Size of the individual images. Defaults to 3.
aspect : int or float, optional
Aspect ratio of individual images. Defaults to 1.
aspect : int, float or 'auto', optional
Aspect ratio of individual images, when set to 'auto', it calculates the aspect ratio of the images passed. Defaults to 'auto'.
cmap : str or `matplotlib.colors.Colormap` or list, optional
Image colormap. If input data is a list of images,
`cmap` can be a list of colormaps. Defaults to None.
Expand Down Expand Up @@ -290,7 +290,8 @@ class ImageGrid:
... [pol, pl, retina],
... map_func=[gaussian, median, hessian],
... dx=[15, 100, None],
... units="nm")
... units="nm",
... aspect=1)
Change colorbar orientation
Expand Down Expand Up @@ -321,7 +322,7 @@ def __init__(
map_func_kw=None,
col_wrap=None,
height=3,
aspect=1,
aspect="auto",
cmap=None,
robust=False,
perc=(2, 98),
Expand Down Expand Up @@ -365,6 +366,10 @@ def __init__(
len(map_func) if len(map_func) >= len(data) else len(data)
)

if aspect == "auto":
aspect_ratios = [img.shape[1] / img.shape[0] for img in data]
aspect = min(aspect_ratios)

elif not isinstance(data, np.ndarray):
raise ValueError("image data must be a list of images or a 3d or 4d array.")

Expand All @@ -385,6 +390,9 @@ def __init__(
# no of columns should now be len of map_func list
col_wrap = len(map_func) if col_wrap is None else col_wrap

if aspect == "auto":
aspect = data.shape[1] / data.shape[0]

elif data.ndim in [3, 4]:
if data.ndim == 4 and data.shape[-1] not in [1, 3, 4]:
raise ValueError(
Expand Down Expand Up @@ -418,6 +426,12 @@ def __init__(

_nimages = len(slices)

if aspect == "auto":
# Select axis where width and height are
height_idx, width_idx = [i for i in range(data.ndim) if i != (axis % data.ndim)][:2]

aspect = data.shape[width_idx] / data.shape[height_idx]

# ---- 3D or 4D image with an individual map_func ----
map_func_type = self._check_map_func(map_func, map_func_kw)
# raise a ValueError if a list of map_func is provided for 3d image
Expand Down Expand Up @@ -765,7 +779,7 @@ def rgbplot(
*,
col_wrap=3,
height=3,
aspect=1,
aspect="auto",
cmap=None,
alpha=None,
origin=None,
Expand Down Expand Up @@ -794,8 +808,8 @@ def rgbplot(
Number of columns to display. Defaults to 3.
height : int or float, optional
Size of the individual images. Defaults to 3.
aspect : int or float, optional
Aspect ratio of individual images. Defaults to 1.
aspect : int, float or 'auto', optional
Aspect ratio of individual images, when set to 'auto', it calculates the aspect ratio of the images passed. Defaults to 'auto'.
cmap : str or `matplotlib.colors.Colormap` or list, optional
Image colormap or a list of colormaps. Defaults to None.
alpha : float or array-like, optional
Expand Down Expand Up @@ -975,8 +989,8 @@ class ParamGrid(object):
is not None and `row` is None. Defaults to None.
height : int or float, optional
Size of the individual images. Defaults to 3.
aspect : int or float, optional
Aspect ratio of individual images. Defaults to 1.
aspect : int, float or 'auto', optional
Aspect ratio of individual images, when set to 'auto', it calculates the aspect ratio of the images passed. Defaults to 'auto'.
cmap : str or `matplotlib.colors.Colormap`, optional
Image colormap. Defaults to None.
alpha : float or array-like, optional
Expand Down Expand Up @@ -1113,7 +1127,7 @@ def __init__(
col=None,
col_wrap=None,
height=3,
aspect=1,
aspect="auto",
cmap=None,
alpha=None,
origin=None,
Expand Down Expand Up @@ -1173,6 +1187,9 @@ def __init__(
ncol = col_wrap
nrow = int(np.ceil(len(kwargs[f"{col}"]) / col_wrap))

if aspect == "auto":
aspect = data.shape[1]/data.shape[0]

# Calculate the base figure size
figsize = (ncol * height * aspect, nrow * height)

Expand Down
21 changes: 21 additions & 0 deletions tests/test_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -871,7 +871,28 @@ def test_figure_size(self):
)
np.testing.assert_array_equal(g4.fig.get_size_inches(), (3 * 2 * 1.5, 2 * 2))
plt.close()

def test_auto_aspect(self):
imgsize0 = (10, 10)
g0 = isns.ImageGrid([np.zeros(imgsize0) for i in range(10)], aspect='auto')
assert np.isclose(imgsize0[1]/imgsize0[0], g0.aspect)
plt.close()

imgsize1 = (10, 5)
g1 = isns.ImageGrid([np.zeros(imgsize1) for i in range(10)], aspect='auto')
assert np.isclose(imgsize1[1]/imgsize1[0], g1.aspect)
plt.close()

imgsize2 = (5, 10)
g2 = isns.ImageGrid([np.zeros(imgsize2) for i in range(10)], aspect='auto')
assert np.isclose(imgsize2[1]/imgsize2[0], g2.aspect)
plt.close()

imglist = [np.zeros(imgsize0) for i in range(4)] + [np.zeros(imgsize1) for i in range(4)] + [np.zeros(imgsize2) for i in range(4)]
g3 = isns.ImageGrid(imglist, aspect='auto')
assert np.isclose(min([imgsize0[1]/imgsize0[0], imgsize1[1]/imgsize1[0], imgsize2[1]/imgsize2[0]]), g3.aspect)
plt.close()

def test_vmin_vmax(self):
g = isns.ImageGrid(cells, vmin=0.5, vmax=0.75)
for ax in g.axes.ravel():
Expand Down

0 comments on commit a3fe0ef

Please sign in to comment.