Skip to content

Commit

Permalink
Add support for block interpolation in backward mode.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 595629332
  • Loading branch information
mjanusz authored and copybara-github committed Jan 4, 2024
1 parent 00ac445 commit 45ca4b1
Showing 1 changed file with 44 additions and 18 deletions.
62 changes: 44 additions & 18 deletions processor/maps.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
from connectomics.common import bounding_box
from connectomics.volume import subvolume
from connectomics.volume import subvolume_processor

import numpy as np
from scipy import spatial
from sofima import map_utils
Expand Down Expand Up @@ -54,13 +53,14 @@ class ReconcileCrossBlockMaps(subvolume_processor.SubvolumeProcessor):

def __init__(
self,
cross_block_volinfo,
cross_block_inv_volinfo,
last_inv_volinfo,
main_inv_volinfo,
z_map,
stride,
xy_overlap=128,
cross_block_volinfo: str,
cross_block_inv_volinfo: str,
last_inv_volinfo: str,
main_inv_volinfo: str,
z_map: dict[int | str, int | str],
stride: int,
xy_overlap: int = 128,
backward: bool = False,
input_volinfo=None,
):
"""Constructor.
Expand All @@ -80,6 +80,8 @@ def __init__(
in pixels of the output volume
xy_overlap: neighboring subvolume overlap in the XY directions, in units
of pixels of main input volume
backward: whether the mesh was solved in backward mode (proceeding from
higher z coordinates towards lower ones)
input_volinfo: path to the high-res input volume (unused)
"""
del input_volinfo
Expand All @@ -91,6 +93,7 @@ def __init__(
self._z_map = {int(k): int(v) for k, v in z_map.items()}
self._sorted_z = list(sorted(self._z_map.keys()))
self._stride = stride
self._backward = backward

def _open_volume(self, path: str) -> Any:
"""Returns a CZYX-shaped ndarray-like object."""
Expand Down Expand Up @@ -146,17 +149,30 @@ def _interpolate(
'cross_block' volume
done: set of 'z' section coordinates that have already been processed
"""
xblock_post = load_xblock(self._z_map[z1])
if z0 > 0:
if self._backward:
xblock_post = load_xblock(self._z_map[z0])
else:
xblock_post = load_xblock(self._z_map[z1])

if not self._backward and z0 > 0:
xblock_pre = load_xblock(self._z_map[z0])
xblock_pre_inv = load_xblock_inv(self._z_map[z0])
elif self._backward and z1 < self._sorted_z[-1]:
xblock_pre = load_xblock(self._z_map[z1])
xblock_pre_inv = load_xblock_inv(self._z_map[z1])
else:
xblock_pre_inv = xblock_pre = np.zeros_like(xblock_post)

if z1 != self._sorted_z[-1]:
block_end_inv = load_last_inv(z1)
if self._backward:
if z0 != self._sorted_z[0]:
block_end_inv = load_last_inv(z0)
else:
block_end_inv = load_main_inv(z0)
else:
block_end_inv = load_main_inv(z1)
if z1 != self._sorted_z[-1]:
block_end_inv = load_last_inv(z1)
else:
block_end_inv = load_main_inv(z1)

flat_box = bounding_box.BoundingBox(
start=box.start, size=(box.size[0], box.size[1], 1)
Expand Down Expand Up @@ -201,19 +217,29 @@ def _interpolate(
self._stride,
)

b = z1 - z0
block_size = z1 - z0
for z in range(max(box.start[2], z0), min(box.end[2], z1 + 1)):
i = z - z0
# Each section can be processed only once.
if z in done:
continue
rel_z = z - box.start[2]

if i == b:
data[:, rel_z : rel_z + 1, ...] = xblock_post
if i == block_size:
data[:, rel_z : rel_z + 1, ...] = (
xblock_pre if self._backward else xblock_post
)
elif i == 0:
data[:, rel_z : rel_z + 1, ...] = xblock_pre
data[:, rel_z : rel_z + 1, ...] = (
xblock_post if self._backward else xblock_pre
)
else:

if self._backward:
scale = (block_size - i) / block_size
else:
scale = i / block_size

try:
# The output coordinate map here is the inverse of the argument
# passed to warp() in the comment above, i.e.:
Expand All @@ -230,7 +256,7 @@ def _interpolate(
interior_aligned,
flat_box,
self._stride,
offset * i / b,
offset * scale,
flat_box,
self._stride,
)
Expand Down

0 comments on commit 45ca4b1

Please sign in to comment.