Skip to content

Commit

Permalink
Update visualization methods
Browse files Browse the repository at this point in the history
  • Loading branch information
MatthewFilipovich committed Dec 9, 2024
1 parent e8c52a1 commit 3780592
Show file tree
Hide file tree
Showing 9 changed files with 100 additions and 96 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ measurements = [

# Visualize the measured intensity distributions
for i, measurement in enumerate(measurements):
measurement.visualize(title=f"z={i}f", vmax=1, intensity=True)
measurement.visualize(title=f"z={i}f", vmax=1)
```

<p align="center">
Expand Down
2 changes: 1 addition & 1 deletion docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ This example demonstrates simulating a 4f imaging system using TorchOptics. The
# Visualize the measured intensity distributions
for i, measurement in enumerate(measurements):
measurement.visualize(title=f"z={i}f", vmax=1, intensity=True)
measurement.visualize(title=f"z={i}f", vmax=1)
.. figure:: _static/4f_simulation.png
:width: 700px
Expand Down
97 changes: 50 additions & 47 deletions docs/source/tutorials/4f_system.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion tests/test_planar_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def test_visualize(self, mock_show):

tensor = torch.randn(self.shape)

visual = plane._visualize(tensor, show=True, return_fig=True, bounds=True)
visual = plane._visualize(tensor, show=True, return_fig=True, show_bounds=True)

mock_show.assert_called_once() # Check if plt.show() was called
self.assertIsInstance(visual, matplotlib.pyplot.Figure)
Expand Down
6 changes: 3 additions & 3 deletions torchoptics/elements/detectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def forward(self, field: Field) -> Tensor:
intensity_flat, weight_flat = field.intensity().flatten(-2), self.weight.flatten(-2)
return linear(intensity_flat, weight_flat) * self.cell_area() # pylint: disable=not-callable

def visualize(self, *index: int, sum_weight: bool = False, **kwargs) -> Any:
def visualize(self, *index: int, **kwargs) -> Any:
"""
Visualizes the detector output or the weight matrix.
Expand All @@ -114,8 +114,8 @@ def visualize(self, *index: int, sum_weight: bool = False, **kwargs) -> Any:
sum_weight (bool): Whether to plot the sum of the weight matrix. Default: `False`.
**kwargs: Additional keyword arguments for visualization.
"""
data = self.weight.sum(dim=0) if sum_weight else self.weight
return self._visualize(data, index, **kwargs)
kwargs.update({"symbol": "W"})
return self._visualize(self.weight, index, **kwargs)

@staticmethod
def _validate_weight(tensor):
Expand Down
2 changes: 2 additions & 0 deletions torchoptics/elements/elements.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def visualize(self, **kwargs) -> Any:
Args:
**kwargs: Additional keyword arguments for visualization.
"""
kwargs.update({"symbol": r"\mathcal{M}"})
return self._visualize(self.modulation_profile, **kwargs)


Expand Down Expand Up @@ -110,4 +111,5 @@ def visualize(self, *index: int, **kwargs) -> Any:
*index (int): Index of the tensor to visualize.
**kwargs: Additional keyword arguments for visualization.
"""
kwargs.update({"symbol": "J"})
return self._visualize(self.polarized_modulation_profile, index, **kwargs)
6 changes: 3 additions & 3 deletions torchoptics/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ def outer(self, other: Field) -> Tensor:
)
return outer2d(self.data, other.data) * self.cell_area()

def visualize(self, *index: int, intensity=False, **kwargs) -> Any:
def visualize(self, *index: int, **kwargs) -> Any:
"""
Visualizes the field.
Expand All @@ -254,8 +254,8 @@ def visualize(self, *index: int, intensity=False, **kwargs) -> Any:
intensity (bool): Whether to visualize only the intensity. Default: `False`.
**kwargs: Additional keyword arguments for visualization.
"""
data = self.intensity() if intensity else self.data
return self._visualize(data, index, **kwargs)
kwargs.update({"symbol": r"\psi"})
return self._visualize(self.data, index, **kwargs)

def copy(self, **kwargs) -> Field:
"""
Expand Down
4 changes: 2 additions & 2 deletions torchoptics/planar_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,8 +132,8 @@ def geometry_str(self) -> str:
offset_str = f"({self.offset[0].item():.2e}, {self.offset[1].item():.2e})"
return f"shape={shape_str}, z={self.z.item():.2e}, spacing={spacing_str}, offset={offset_str}"

def _visualize(self, data: Tensor, index: tuple = (), bounds: bool = False, **kwargs) -> Any:
def _visualize(self, data: Tensor, index: tuple = (), show_bounds: bool = False, **kwargs) -> Any:
"""Visualizes the data tensor."""
if bounds:
if show_bounds:
kwargs.update({"extent": torch.cat((self.bounds()[2:], self.bounds()[:2])).cpu().detach()})
return visualize_tensor(data[index + (slice(None), slice(None))], **kwargs)
75 changes: 37 additions & 38 deletions torchoptics/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def visualize_tensor(
cmap: str = "inferno",
xlabel: Optional[str] = None,
ylabel: Optional[str] = None,
symbol: Optional[str] = None,
show: bool = True,
return_fig: bool = False,
) -> Optional[plt.Figure]:
Expand All @@ -35,6 +36,7 @@ def visualize_tensor(
cmap (str, optional): The colormap to use. Default: `"inferno"`.
xlabel (str, optional): The label for the x-axis. Default: `None`.
ylabel (str, optional): The label for the y-axis. Default: `None`.
symbol (str, optional): Symbol used in ax title. Default: `None`.
show (bool, optional): Whether to display the plot. Default: `True`.
return_fig (bool, optional): Whether to return the figure. Default: `False`.
"""
Expand All @@ -43,10 +45,39 @@ def visualize_tensor(
raise ValueError(f"Expected tensor to be 2D, but got shape {tensor.shape}.")
tensor = tensor.detach().cpu().view(tensor.shape[-2], tensor.shape[-1])

fig, axes = plt.subplots(2, 2, figsize=(10, 10)) if tensor.is_complex() else plt.subplots(figsize=(5, 5))
if tensor.is_complex():
fig, axes = plt.subplots(1, 2, figsize=(10, 5))
tensor = torch.where(tensor == -0.0 - 0.0j, 0, tensor) # Remove numerical artifacts

create_image_subplot( # Plot absolute square
tensor.abs().square(),
extent,
vmin,
vmax,
cmap,
xlabel,
ylabel,
axes[0],
rf"$|{symbol}|^2$" if symbol is not None else None,
)

create_image_subplot( # plot angle
tensor.angle(),
extent,
-torch.pi,
torch.pi,
"twilight_shifted",
xlabel,
ylabel,
axes[1],
r"$\arg \{" + symbol + r"\}$" if symbol is not None else None,
)

subplots_func = _create_complex_image_subplots if tensor.is_complex() else _create_image_subplot
subplots_func(tensor, extent, vmin, vmax, cmap, xlabel, ylabel, axes)
axes[1].get_images()[0].set_interpolation("none")
plt.subplots_adjust(wspace=0.4, hspace=0.4)
else:
fig, axes = plt.subplots(figsize=(5, 5))
create_image_subplot(tensor, extent, vmin, vmax, cmap, xlabel, ylabel, axes, symbol)

if title:
fig.suptitle(title, y=0.95)
Expand All @@ -57,41 +88,7 @@ def visualize_tensor(
return fig if return_fig else None


def _create_complex_image_subplots(
tensor: Tensor,
extent: Optional[Sequence[float]],
vmin: Optional[float],
vmax: Optional[float],
cmap: str,
xlabel: Optional[str],
ylabel: Optional[str],
axes: Any,
) -> None:
"""Creates subplots for visualizing a complex-valued tensor."""
components = [tensor.abs().square(), tensor.angle(), tensor.real, tensor.imag]
ax_titles = [r"$|\psi|^2$", r"$\arg \{ \psi \}$", r"$\Re \{\psi \}$", r"$\Im \{\psi \}$"]
cmap_list = [cmap, "twilight_shifted", "viridis", "viridis"]
vmin_list = [vmin, -torch.pi, None, None]
vmax_list = [vmax, torch.pi, None, None]

for i in range(4):
_create_image_subplot(
components[i],
extent=extent,
vmin=vmin_list[i],
vmax=vmax_list[i],
cmap=cmap_list[i],
xlabel=xlabel,
ylabel=ylabel,
ax=axes.flat[i], # type: ignore[attr-defined]
)
axes.flat[i].set_title(ax_titles[i]) # type: ignore[attr-defined]

axes[0, 1].get_images()[0].set_interpolation("none") # type: ignore[index]
plt.subplots_adjust(wspace=0.4, hspace=0.4)


def _create_image_subplot(
def create_image_subplot(
tensor: Tensor,
extent: Optional[Sequence[float]],
vmin: Optional[float],
Expand All @@ -100,6 +97,7 @@ def _create_image_subplot(
xlabel: Optional[str],
ylabel: Optional[str],
ax: Any,
ax_title: Optional[str],
) -> None:
"""Creates a subplot for visualizing a real-valued tensor."""
extent_tuple = tuple(extent) if extent is not None else None
Expand All @@ -109,3 +107,4 @@ def _create_image_subplot(
plt.colorbar(im, cax=cax, orientation="vertical")
ax.set_xlabel(xlabel) # type: ignore[arg-type]
ax.set_ylabel(ylabel) # type: ignore[arg-type]
ax.set_title(ax_title)

0 comments on commit 3780592

Please sign in to comment.