Skip to content

Commit

Permalink
More robust tensor conversion and updates in training script
Browse files Browse the repository at this point in the history
  • Loading branch information
constantinpape committed Jul 11, 2024
1 parent da53c11 commit f06b922
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 3 deletions.
9 changes: 7 additions & 2 deletions experiments/unet-segmentation/dsb/train_boundaries.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,16 @@ def train_boundaries(args):

patch_shape = (1, 256, 256)
train_loader = get_dsb_loader(
args.input, patch_shape, split="train",
args.input, patch_shape=patch_shape, split="train",
download=True, boundaries=True, batch_size=args.batch_size
)

# Uncomment this for checking the loader.
# from torch_em.util.debug import check_loader
# check_loader(train_loader, 4)

val_loader = get_dsb_loader(
args.input, patch_shape, split="test",
args.input, patch_shape=patch_shape, split="test",
boundaries=True, batch_size=args.batch_size
)
loss = torch_em.loss.DiceLoss()
Expand Down
1 change: 1 addition & 0 deletions torch_em/util/debug.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ def _check_napari(loader, n_samples, instance_labels, model=None, device=None, r
v.add_image(y)
if pred is not None:
v.add_image(pred)

napari.run()


Expand Down
9 changes: 8 additions & 1 deletion torch_em/util/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,14 @@ def ensure_tensor(tensor, dtype=None):
if isinstance(tensor, np.ndarray):
if np.dtype(tensor.dtype) in DTYPE_MAP:
tensor = tensor.astype(DTYPE_MAP[tensor.dtype])
tensor = torch.from_numpy(tensor)
# Try to convert the tensor, even if it has wrong byte-order
try:
tensor = torch.from_numpy(tensor)
except ValueError:
tensor = tensor.view(tensor.dtype.newbyteorder())
if np.dtype(tensor.dtype) in DTYPE_MAP:
tensor = tensor.astype(DTYPE_MAP[tensor.dtype])
tensor = torch.from_numpy(tensor)

assert torch.is_tensor(tensor), f"Cannot convert {type(tensor)} to torch"
if dtype is not None:
Expand Down

0 comments on commit f06b922

Please sign in to comment.