Skip to content

Commit

Permalink
manage large reconstructions
Browse files Browse the repository at this point in the history
  • Loading branch information
talonchandler committed Nov 8, 2024
1 parent 9b3fcc4 commit 75be3e5
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 19 deletions.
36 changes: 26 additions & 10 deletions examples/models/inplane_oriented_thick_pol3d_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,26 +66,42 @@
intensity_to_stokes_matrix,
)

# Reconstruct
fzyx_object_recon = (
inplane_oriented_thick_pol3d_vector.apply_inverse_transfer_function(
szyx_data,
singular_system,
intensity_to_stokes_matrix,
regularization_strength=1e-1,
)
)
# from waveorder.visuals.napari_visuals import add_transfer_function_to_viewer

# add_transfer_function_to_viewer(
# viewer,
# singular_system[1],
# zyx_scale=(z_pixel_size, yx_pixel_size, yx_pixel_size),
# layer_name="Singular Values",
# )
# import pdb; pdb.set_trace()


# Display
arrays = [
(fzyx_object_recon, "Object - recon"),
(szyx_data, "Data"),
(fzyx_object, "Object"),
]

for array in arrays:
viewer.add_image(torch.real(array[0]).cpu().numpy(), name=array[1])


# Reconstruct
for reg_strength in [0.005, 0.008, 0.01, 0.05, 0.1]:
fzyx_object_recon = (
inplane_oriented_thick_pol3d_vector.apply_inverse_transfer_function(
szyx_data,
singular_system,
intensity_to_stokes_matrix,
regularization_strength=reg_strength,
)
)
viewer.add_image(
torch.real(fzyx_object_recon).cpu().numpy(),
name=f"Object - recon, reg_strength={reg_strength}",
)

viewer.grid.enabled = True
viewer.grid.shape = (2, 5)
import pdb
Expand Down
60 changes: 51 additions & 9 deletions waveorder/models/inplane_oriented_thick_pol3d_vector.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import torch
import tqdm
import numpy as np

from torch import Tensor
from typing import Literal
from torch.nn.functional import avg_pool3d
from torch.nn.functional import avg_pool3d, interpolate
from waveorder import optics, sampling, stokes, util
from waveorder.visuals.napari_visuals import add_transfer_function_to_viewer

Expand Down Expand Up @@ -40,6 +41,7 @@ def calculate_transfer_function(
numerical_aperture_detection,
invert_phase_contrast=False,
fourier_oversample_factor=1,
transverse_downsample_factor=1,
):
if z_padding != 0:
raise NotImplementedError("Padding not implemented for this model")
Expand All @@ -58,15 +60,34 @@ def calculate_transfer_function(
yx_factor = int(np.ceil(yx_pixel_size / transverse_nyquist))
z_factor = int(np.ceil(z_pixel_size / axial_nyquist))

print("YX factor:", yx_factor)
print("Z factor:", z_factor)

tf_calculation_shape = (
zyx_shape[0] * z_factor * fourier_oversample_factor,
int(
np.ceil(
zyx_shape[1]
* yx_factor
* fourier_oversample_factor
/ transverse_downsample_factor
)
),
int(
np.ceil(
zyx_shape[2]
* yx_factor
* fourier_oversample_factor
/ transverse_downsample_factor
)
),
)

sfZYX_transfer_function, intensity_to_stokes_matrix = (
_calculate_wrap_unsafe_transfer_function(
swing,
scheme,
(
zyx_shape[0] * z_factor * fourier_oversample_factor,
zyx_shape[1] * yx_factor * fourier_oversample_factor,
zyx_shape[2] * yx_factor * fourier_oversample_factor,
),
tf_calculation_shape,
yx_pixel_size / yx_factor,
z_pixel_size / z_factor,
wavelength_illumination,
Expand Down Expand Up @@ -96,11 +117,29 @@ def calculate_transfer_function(
pooled_sfZYX_transfer_function.shape[1],
zyx_shape[0] + 2 * z_padding,
) + zyx_shape[1:]

cropped = sampling.nd_fourier_central_cuboid(
pooled_sfZYX_transfer_function, sfzyx_out_shape
)

# Compute singular system on cropped and downsampled
U, S, Vh = calculate_singular_system(cropped)

# Interpolate to final size in YX
def complex_interpolate(tensor, zyx_shape):
interpolated_real = interpolate(tensor.real, size=zyx_shape)
interpolated_imag = interpolate(tensor.imag, size=zyx_shape)
return interpolated_real + 1j * interpolated_imag

full_cropped = complex_interpolate(cropped, zyx_shape)
full_U = complex_interpolate(U, zyx_shape)
full_S = interpolate(S[None], size=zyx_shape)[0] # S is real
full_Vh = complex_interpolate(Vh, zyx_shape)

return (
sampling.nd_fourier_central_cuboid(
pooled_sfZYX_transfer_function, sfzyx_out_shape
),
full_cropped,
intensity_to_stokes_matrix,
(full_U, full_S, full_Vh),
)


Expand Down Expand Up @@ -142,6 +181,7 @@ def _calculate_wrap_unsafe_transfer_function(
z_frequencies = torch.fft.fftfreq(z_total, d=z_pixel_size)

# 2D pupils
print("\tCalculating pupils...")
ill_pupil = optics.generate_pupil(
radial_frequencies,
numerical_aperture_illumination,
Expand Down Expand Up @@ -187,6 +227,8 @@ def _calculate_wrap_unsafe_transfer_function(

P_3D = torch.abs(torch.fft.ifft(P, dim=-3)).type(torch.complex64)
S_3D = torch.fft.ifft(S, dim=-3)

print("\tCalculating greens tensor spectrum...")
G_3D = optics.generate_greens_tensor_spectrum(
zyx_shape=(z_total, zyx_shape[1], zyx_shape[2]),
zyx_pixel_size=(z_pixel_size, yx_pixel_size, yx_pixel_size),
Expand Down

0 comments on commit 75be3e5

Please sign in to comment.