From 45ca4b12ba1fa9777c8c287ccb1b56fafe60b4dc Mon Sep 17 00:00:00 2001 From: Michal Januszewski Date: Thu, 4 Jan 2024 01:19:04 -0800 Subject: [PATCH] Add support for block interpolation in backward mode. PiperOrigin-RevId: 595629332 --- processor/maps.py | 62 +++++++++++++++++++++++++++++++++-------------- 1 file changed, 44 insertions(+), 18 deletions(-) diff --git a/processor/maps.py b/processor/maps.py index cb6145a..77eba81 100644 --- a/processor/maps.py +++ b/processor/maps.py @@ -21,7 +21,6 @@ from connectomics.common import bounding_box from connectomics.volume import subvolume from connectomics.volume import subvolume_processor - import numpy as np from scipy import spatial from sofima import map_utils @@ -54,13 +53,14 @@ class ReconcileCrossBlockMaps(subvolume_processor.SubvolumeProcessor): def __init__( self, - cross_block_volinfo, - cross_block_inv_volinfo, - last_inv_volinfo, - main_inv_volinfo, - z_map, - stride, - xy_overlap=128, + cross_block_volinfo: str, + cross_block_inv_volinfo: str, + last_inv_volinfo: str, + main_inv_volinfo: str, + z_map: dict[int | str, int | str], + stride: int, + xy_overlap: int = 128, + backward: bool = False, input_volinfo=None, ): """Constructor. @@ -80,6 +80,8 @@ def __init__( in pixels of the output volume xy_overlap: neighboring subvolume overlap in the XY directions, in units of pixels of main input volume + backward: whether the mesh was solved in backward mode (proceeding from + higher z coordinates towards lower ones) input_volinfo: path to the high-res input volume (unused) """ del input_volinfo @@ -91,6 +93,7 @@ def __init__( self._z_map = {int(k): int(v) for k, v in z_map.items()} self._sorted_z = list(sorted(self._z_map.keys())) self._stride = stride + self._backward = backward def _open_volume(self, path: str) -> Any: """Returns a CZYX-shaped ndarray-like object.""" @@ -146,17 +149,30 @@ def _interpolate( 'cross_block' volume done: set of 'z' section coordinates that have already been processed """ - xblock_post = load_xblock(self._z_map[z1]) - if z0 > 0: + if self._backward: + xblock_post = load_xblock(self._z_map[z0]) + else: + xblock_post = load_xblock(self._z_map[z1]) + + if not self._backward and z0 > 0: xblock_pre = load_xblock(self._z_map[z0]) xblock_pre_inv = load_xblock_inv(self._z_map[z0]) + elif self._backward and z1 < self._sorted_z[-1]: + xblock_pre = load_xblock(self._z_map[z1]) + xblock_pre_inv = load_xblock_inv(self._z_map[z1]) else: xblock_pre_inv = xblock_pre = np.zeros_like(xblock_post) - if z1 != self._sorted_z[-1]: - block_end_inv = load_last_inv(z1) + if self._backward: + if z0 != self._sorted_z[0]: + block_end_inv = load_last_inv(z0) + else: + block_end_inv = load_main_inv(z0) else: - block_end_inv = load_main_inv(z1) + if z1 != self._sorted_z[-1]: + block_end_inv = load_last_inv(z1) + else: + block_end_inv = load_main_inv(z1) flat_box = bounding_box.BoundingBox( start=box.start, size=(box.size[0], box.size[1], 1) @@ -201,7 +217,7 @@ def _interpolate( self._stride, ) - b = z1 - z0 + block_size = z1 - z0 for z in range(max(box.start[2], z0), min(box.end[2], z1 + 1)): i = z - z0 # Each section can be processed only once. @@ -209,11 +225,21 @@ def _interpolate( continue rel_z = z - box.start[2] - if i == b: - data[:, rel_z : rel_z + 1, ...] = xblock_post + if i == block_size: + data[:, rel_z : rel_z + 1, ...] = ( + xblock_pre if self._backward else xblock_post + ) elif i == 0: - data[:, rel_z : rel_z + 1, ...] = xblock_pre + data[:, rel_z : rel_z + 1, ...] = ( + xblock_post if self._backward else xblock_pre + ) else: + + if self._backward: + scale = (block_size - i) / block_size + else: + scale = i / block_size + try: # The output coordinate map here is the inverse of the argument # passed to warp() in the comment above, i.e.: @@ -230,7 +256,7 @@ def _interpolate( interior_aligned, flat_box, self._stride, - offset * i / b, + offset * scale, flat_box, self._stride, )