Feature: Add File IO package (#174)
### Description

This is a precursor to the upcoming save-prediction feature. Write
functions needed to be added and, because they mirror the structure of
the read functions, the read functions have been moved from
`dataset.dataset_utils` to a new `file_io` package. This package has
`read` and `write` subpackages that mirror each other's structure, for
ease of understanding.

Additionally, `get_read_func` is slightly refactored to match how
`get_write_func` is implemented. Read functions are stored in a
module-level dictionary keyed by `SupportedData`, and `get_read_func`
indexes this dictionary with the `data_type` passed. This avoids an
ever-growing chain of if/else statements. It wasn't strictly necessary,
as we do not plan to support a large number of data types, but the
option is always there.
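
The dictionary-based lookup described above can be sketched as follows. This is a minimal, self-contained illustration rather than the CAREamics source: the `SupportedData` enum here is a stand-in (assumed to be string-valued), and `read_tiff` is a placeholder that returns a string instead of an image array.

```python
from enum import Enum
from pathlib import Path
from typing import Callable, Dict, Union


class SupportedData(str, Enum):
    """Stand-in for the real enum; assumed to be string-valued."""

    ARRAY = "array"
    TIFF = "tiff"
    CUSTOM = "custom"


def read_tiff(file_path: Path) -> str:
    """Placeholder reader; the real function returns a numpy array."""
    return f"read {file_path}"


# One entry per supported format; adding a format means adding one line here.
READ_FUNCS: Dict[SupportedData, Callable] = {
    SupportedData.TIFF: read_tiff,
}


def get_read_func(data_type: Union[str, SupportedData]) -> Callable:
    """Look up the read function instead of branching with if/else."""
    # A str-valued enum member hashes and compares like its value,
    # so a plain "tiff" string also finds the TIFF entry.
    if data_type in READ_FUNCS:
        return READ_FUNCS[SupportedData(data_type)]
    raise NotImplementedError(f"Data type '{data_type}' is not supported.")
```

Supporting another format would then only require registering its reader in `READ_FUNCS`; `get_read_func` itself stays unchanged.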

An additional change that snuck into this PR is renaming
`SupportedData.get_extension` to `SupportedData.get_extension_pattern`
and adding a different `SupportedData.get_extension`. The new
`SupportedData.get_extension` returns the literal extension string
(`".tiff"` for TIFF) without the Unix wildcard characters and will be
used for saving predictions in a future PR.
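
To illustrate the difference between the two return values, here is a small sketch. The two string values are the ones returned for TIFF in this commit; the constant names and the `prediction_path` helper are hypothetical and only show how a literal extension might be used when building an output path.

```python
from fnmatch import fnmatch
from pathlib import Path

# Values from this commit's TIFF branches; the constant names are hypothetical.
TIFF_EXTENSION_PATTERN = "*.tif*"  # SupportedData.get_extension_pattern(TIFF)
TIFF_EXTENSION = ".tiff"  # SupportedData.get_extension(TIFF)


def is_tiff(file_path: Path) -> bool:
    """Match a file suffix against the wildcard pattern, as read_tiff does."""
    return fnmatch(file_path.suffix, TIFF_EXTENSION_PATTERN)


def prediction_path(input_path: Path, out_dir: Path) -> Path:
    """Build an output path with the literal extension (hypothetical helper)."""
    return out_dir / input_path.with_suffix(TIFF_EXTENSION).name


for name in ("image.tif", "image.tiff", "image.png"):
    print(name, is_tiff(Path(name)))
print(prediction_path(Path("data/image.tif"), Path("predictions")))
```

The pattern is suited to searching and validating existing files, while the literal string is what a writer needs to name a new file.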

- **What**:
  - Added a `file_io` package to contain functions to read and write image
    files.
  - `SupportedData.get_extension` modification and addition.
- **Why**: Partly an aesthetic choice; it removes the responsibility of
  file reading and writing from the `dataset` package, in keeping with
  the single-responsibility principle.
- **How**: Added write functions. Moved and slightly refactored read
  functions.

### Changes Made

- **Added**:
  - `file_io`, `file_io.read`, `file_io.write` packages.
  - `write_tiff` function
  - `get_write_func` function
  - New `SupportedData.get_extension` to return literal extension string
  - Tests for all the above new functions
- **Modified**:
  - Old `SupportedData.get_extension` renamed to
    `SupportedData.get_extension_pattern`
  - Renamed tests to mirror name change in function
- **Removed**:
  - File reading from `dataset.dataset_utils`

### Breaking changes

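- Code importing read functions (`read_tiff`, `read_zarr`,
  `get_read_func`) from `dataset.dataset_utils`; they now live in
  `careamics.file_io.read`.
- Code calling `SupportedData.get_extension` and expecting a wildcard
  pattern; that behaviour is now provided by
  `SupportedData.get_extension_pattern`.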

### Additional Notes and Examples

In a future PR, `SupportedData.get_extension` and
`SupportedData.get_extension_pattern` could be moved to the `file_io`
package. These functions do not need to be bound to `SupportedData`, and
none of the other "support" `Enum` classes have additional methods. It
might make sense to store these functions closer to where they are used.
This is again a stylistic choice; feel free to share other opinions.

---

**Please ensure your PR meets the following requirements:**

- [x] Code builds and passes tests locally, including doctests
- [x] New tests have been added (for bug fixes/features)
- [x] Pre-commit passes
- [ ] PR to the documentation exists (for bug fixes / features)

---------

Co-authored-by: jdeschamps <6367888+jdeschamps@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
3 people authored Jul 5, 2024
1 parent ff20596 commit 57ceab2
Showing 25 changed files with 320 additions and 62 deletions.
33 changes: 29 additions & 4 deletions src/careamics/config/support/supported_data.py
Original file line number Diff line number Diff line change
@@ -60,9 +60,9 @@ def _missing_(cls, value: object) -> str:
         return super()._missing_(value)

     @classmethod
-    def get_extension(cls, data_type: Union[str, SupportedData]) -> str:
+    def get_extension_pattern(cls, data_type: Union[str, SupportedData]) -> str:
         """
-        Path.rglob and fnmatch compatible extension.
+        Get Path.rglob and fnmatch compatible extension.

         Parameters
         ----------
@@ -72,13 +72,38 @@ def get_extension(cls, data_type: Union[str, SupportedData]) -> str:
         Returns
         -------
         str
-            Corresponding extension.
+            Corresponding extension pattern.
         """
         if data_type == cls.ARRAY:
-            raise NotImplementedError(f"Data {data_type} are not loaded from file.")
+            raise NotImplementedError(f"Data '{data_type}' is not loaded from a file.")
         elif data_type == cls.TIFF:
             return "*.tif*"
         elif data_type == cls.CUSTOM:
             return "*.*"
         else:
             raise ValueError(f"Data type {data_type} is not supported.")
+
+    @classmethod
+    def get_extension(cls, data_type: Union[str, SupportedData]) -> str:
+        """
+        Get file extension of corresponding data type.
+
+        Parameters
+        ----------
+        data_type : str or SupportedData
+            Data type.
+
+        Returns
+        -------
+        str
+            Corresponding extension.
+        """
+        if data_type == cls.ARRAY:
+            raise NotImplementedError(f"Data '{data_type}' is not loaded from a file.")
+        elif data_type == cls.TIFF:
+            return ".tiff"
+        elif data_type == cls.CUSTOM:
+            # TODO: improve this message
+            raise NotImplementedError("Custom extensions have to be passed elsewhere.")
+        else:
+            raise ValueError(f"Data type {data_type} is not supported.")
6 changes: 0 additions & 6 deletions src/careamics/dataset/dataset_utils/__init__.py
@@ -6,9 +6,6 @@
     "get_files_size",
     "list_files",
     "validate_source_target_files",
-    "read_tiff",
-    "get_read_func",
-    "read_zarr",
     "iterate_over_files",
     "WelfordStatistics",
 ]
@@ -19,7 +16,4 @@
 )
 from .file_utils import get_files_size, list_files, validate_source_target_files
 from .iterate_over_files import iterate_over_files
-from .read_tiff import read_tiff
-from .read_utils import get_read_func
-from .read_zarr import read_zarr
 from .running_stats import WelfordStatistics, compute_normalization_stats
2 changes: 1 addition & 1 deletion src/careamics/dataset/dataset_utils/file_utils.py
@@ -75,7 +75,7 @@ def list_files(
         raise FileNotFoundError(f"Data path {data_path} does not exist.")

     # get extension compatible with fnmatch and rglob search
-    extension = SupportedData.get_extension(data_type)
+    extension = SupportedData.get_extension_pattern(data_type)

     if data_type == SupportedData.CUSTOM and extension_filter != "":
         extension = extension_filter
2 changes: 1 addition & 1 deletion src/careamics/dataset/dataset_utils/iterate_over_files.py
@@ -9,10 +9,10 @@
 from torch.utils.data import get_worker_info

 from careamics.config import DataConfig, InferenceConfig
+from careamics.file_io.read import read_tiff
 from careamics.utils.logging import get_logger

 from .dataset_utils import reshape_array
-from .read_tiff import read_tiff

 logger = get_logger(__name__)

27 changes: 0 additions & 27 deletions src/careamics/dataset/dataset_utils/read_utils.py

This file was deleted.

2 changes: 1 addition & 1 deletion src/careamics/dataset/in_memory_dataset.py
@@ -9,12 +9,12 @@
 import numpy as np
 from torch.utils.data import Dataset

+from careamics.file_io.read import read_tiff
 from careamics.transforms import Compose

 from ..config import DataConfig
 from ..config.transformations import NormalizeModel
 from ..utils.logging import get_logger
-from .dataset_utils import read_tiff
 from .patching.patching import (
     PatchedOutput,
     Stats,
3 changes: 2 additions & 1 deletion src/careamics/dataset/iterable_dataset.py
@@ -12,10 +12,11 @@

 from careamics.config import DataConfig
 from careamics.config.transformations import NormalizeModel
+from careamics.file_io.read import read_tiff
 from careamics.transforms import Compose

 from ..utils.logging import get_logger
-from .dataset_utils import iterate_over_files, read_tiff
+from .dataset_utils import iterate_over_files
 from .dataset_utils.running_stats import WelfordStatistics
 from .patching.patching import Stats
 from .patching.random_patching import extract_patches_random
3 changes: 2 additions & 1 deletion src/careamics/dataset/iterable_pred_dataset.py
@@ -8,11 +8,12 @@
 from numpy.typing import NDArray
 from torch.utils.data import IterableDataset

+from careamics.file_io.read import read_tiff
 from careamics.transforms import Compose

 from ..config import InferenceConfig
 from ..config.transformations import NormalizeModel
-from .dataset_utils import iterate_over_files, read_tiff
+from .dataset_utils import iterate_over_files


 class IterablePredDataset(IterableDataset):
3 changes: 2 additions & 1 deletion src/careamics/dataset/iterable_tiled_pred_dataset.py
@@ -8,12 +8,13 @@
 from numpy.typing import NDArray
 from torch.utils.data import IterableDataset

+from careamics.file_io.read import read_tiff
 from careamics.transforms import Compose

 from ..config import InferenceConfig
 from ..config.tile_information import TileInformation
 from ..config.transformations import NormalizeModel
-from .dataset_utils import iterate_over_files, read_tiff
+from .dataset_utils import iterate_over_files
 from .tiling import extract_tiles


7 changes: 7 additions & 0 deletions src/careamics/file_io/__init__.py
@@ -0,0 +1,7 @@
+"""Functions relating reading and writing image files."""
+
+__all__ = ["read", "write", "get_read_func", "get_write_func"]
+
+from . import read, write
+from .read import get_read_func
+from .write import get_write_func
11 changes: 11 additions & 0 deletions src/careamics/file_io/read/__init__.py
@@ -0,0 +1,11 @@
+"""Functions relating to reading image files of different formats."""
+
+__all__ = [
+    "get_read_func",
+    "read_tiff",
+    "read_zarr",
+]
+
+from .get_func import get_read_func
+from .tiff import read_tiff
+from .zarr import read_zarr
56 changes: 56 additions & 0 deletions src/careamics/file_io/read/get_func.py
@@ -0,0 +1,56 @@
+"""Module to get read functions."""
+
+from pathlib import Path
+from typing import Callable, Dict, Protocol, Union
+
+from numpy.typing import NDArray
+
+from careamics.config.support import SupportedData
+
+from .tiff import read_tiff
+
+
+# This is very strict, function signature has to match including arg names
+# See WriteFunc notes
+class ReadFunc(Protocol):
+    """Protocol for type hinting read functions."""
+
+    def __call__(self, file_path: Path, *args, **kwargs) -> NDArray:
+        """
+        Type hinted callables must match this function signature (not including self).
+
+        Parameters
+        ----------
+        file_path : pathlib.Path
+            Path to file.
+        *args
+            Other positional arguments.
+        **kwargs
+            Other keyword arguments.
+        """
+
+
+READ_FUNCS: Dict[SupportedData, ReadFunc] = {
+    SupportedData.TIFF: read_tiff,
+}
+
+
+def get_read_func(data_type: Union[str, SupportedData]) -> Callable:
+    """
+    Get the read function for the data type.
+
+    Parameters
+    ----------
+    data_type : SupportedData
+        Data type.
+
+    Returns
+    -------
+    callable
+        Read function.
+    """
+    if data_type in READ_FUNCS:
+        data_type = SupportedData(data_type)  # mypy complaining about dict key type
+        return READ_FUNCS[data_type]
+    else:
+        raise NotImplementedError(f"Data type '{data_type}' is not supported.")
@@ -44,7 +44,9 @@ def read_tiff(file_path: Path, *args: list, **kwargs: dict) -> np.ndarray:
     ValueError
         If the axes length is incorrect.
     """
-    if fnmatch(file_path.suffix, SupportedData.get_extension(SupportedData.TIFF)):
+    if fnmatch(
+        file_path.suffix, SupportedData.get_extension_pattern(SupportedData.TIFF)
+    ):
         try:
             array = tifffile.imread(file_path)
         except (ValueError, OSError) as e:
File renamed without changes.
9 changes: 9 additions & 0 deletions src/careamics/file_io/write/__init__.py
@@ -0,0 +1,9 @@
+"""Functions relating to writing image files of different formats."""
+
+__all__ = [
+    "get_write_func",
+    "write_tiff",
+]
+
+from .get_func import get_write_func
+from .tiff import write_tiff
59 changes: 59 additions & 0 deletions src/careamics/file_io/write/get_func.py
@@ -0,0 +1,59 @@
+"""Module to get write functions."""
+
+from pathlib import Path
+from typing import Protocol, Union
+
+from numpy.typing import NDArray
+
+from careamics.config.support import SupportedData
+
+from .tiff import write_tiff
+
+
+# This is very strict, arguments have to be called file_path & img
+# Alternative? - doesn't capture *args & **kwargs
+# WriteFunc = Callable[[Path, NDArray], None]
+class WriteFunc(Protocol):
+    """Protocol for type hinting write functions."""
+
+    def __call__(self, file_path: Path, img: NDArray, *args, **kwargs) -> None:
+        """
+        Type hinted callables must match this function signature (not including self).
+
+        Parameters
+        ----------
+        file_path : pathlib.Path
+            Path to file.
+        img : numpy.ndarray
+            Image data to save.
+        *args
+            Other positional arguments.
+        **kwargs
+            Other keyword arguments.
+        """
+
+
+WRITE_FUNCS: dict[SupportedData, WriteFunc] = {
+    SupportedData.TIFF: write_tiff,
+}
+
+
+def get_write_func(data_type: Union[str, SupportedData]) -> WriteFunc:
+    """
+    Get the write function for the data type.
+
+    Parameters
+    ----------
+    data_type : SupportedData
+        Data type.
+
+    Returns
+    -------
+    callable
+        Write function.
+    """
+    if data_type in WRITE_FUNCS:
+        data_type = SupportedData(data_type)  # mypy complaining about dict key type
+        return WRITE_FUNCS[data_type]
+    else:
+        raise NotImplementedError(f"Data type {data_type} is not supported.")
39 changes: 39 additions & 0 deletions src/careamics/file_io/write/tiff.py
@@ -0,0 +1,39 @@
+"""Write tiff function."""
+
+from fnmatch import fnmatch
+from pathlib import Path
+
+import tifffile
+from numpy.typing import NDArray
+
+from careamics.config.support import SupportedData
+
+
+def write_tiff(file_path: Path, img: NDArray, *args, **kwargs) -> None:
+    """
+    Write tiff files.
+
+    Parameters
+    ----------
+    file_path : pathlib.Path
+        Path to file.
+    img : numpy.ndarray
+        Image data to save.
+    *args
+        Positional arguments passed to `tifffile.imwrite`.
+    **kwargs
+        Keyword arguments passed to `tifffile.imwrite`.
+
+    Raises
+    ------
+    ValueError
+        When the file extension of `file_path` does not match the Unix shell-style
+        pattern '*.tif*'.
+    """
+    if not fnmatch(
+        file_path.suffix, SupportedData.get_extension_pattern(SupportedData.TIFF)
+    ):
+        raise ValueError(
+            f"Unexpected extension '{file_path.suffix}' for save file type 'tiff'."
+        )
+    tifffile.imwrite(file_path, img, *args, **kwargs)
6 changes: 2 additions & 4 deletions src/careamics/lightning/predict_data_module.py
@@ -16,11 +16,9 @@
     IterablePredDataset,
     IterableTiledPredDataset,
 )
-from careamics.dataset.dataset_utils import (
-    get_read_func,
-    list_files,
-)
+from careamics.dataset.dataset_utils import list_files
 from careamics.dataset.tiling.collate_tiles import collate_tiles
+from careamics.file_io.read import get_read_func
 from careamics.utils import get_logger

 PredictDatasetType = Union[