Merge branch 'main' into unet_skip_gate_option
mzouink authored Nov 4, 2024
2 parents c7f8388 + 108db88 commit 82468b2
Showing 40 changed files with 2,962 additions and 231 deletions.
9 changes: 7 additions & 2 deletions .github/workflows/docs.yaml
@@ -16,10 +16,15 @@ jobs:
with:
fetch-depth: 0 # otherwise, you will fail to push refs to the dest repo
- name: install dacapo
# run: pip install .[docs]
run: pip install sphinx-autodoc-typehints sphinx-autoapi sphinx-click sphinx-rtd-theme myst-parser jupytext ipykernel nbsphinx
# run:
run: |
pip install sphinx-autodoc-typehints sphinx-autoapi sphinx-click sphinx-rtd-theme myst-parser jupytext ipykernel nbsphinx myst_nb
python -m ipykernel install --user --name python3
pip install .[docs]
- name: parse notebooks
run: jupytext --to notebook --execute ./docs/source/notebooks/*.py
# continue-on-error: true
- name: remove notebook scripts
run: rm ./docs/source/notebooks/*.py
- name: Build and Commit
20 changes: 15 additions & 5 deletions README.md
@@ -17,15 +17,20 @@
A framework for easy application of established machine learning techniques on large, multi-dimensional images.

`dacapo` allows you to configure machine learning jobs as combinations of
[DataSplits](http://docs/api.html#datasplits),
[Architectures](http://docs/api.html#architectures),
[Tasks](http://docs/api.html#tasks),
[Trainers](http://docs/api.html#trainers),
[DataSplits](https://janelia-cellmap.github.io/dacapo/autoapi/dacapo/experiments/datasplits/index.html),
[Architectures](https://janelia-cellmap.github.io/dacapo/autoapi/dacapo/experiments/architectures/index.html),
[Tasks](https://janelia-cellmap.github.io/dacapo/autoapi/dacapo/experiments/tasks/index.html),
[Trainers](https://janelia-cellmap.github.io/dacapo/autoapi/dacapo/experiments/trainers/index.html),
on arbitrarily large volumes of
multi-dimensional images. `dacapo` is not tied to a particular learning
framework, but currently only supports [`torch`](https://pytorch.org/) with
plans to support [`tensorflow`](https://www.tensorflow.org/).
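
As a rough sketch of how these pieces come together (the run name is hypothetical, and the import paths are assumptions based on the worker code later in this diff):

```python
# Minimal sketch, assuming a run config named "my_run" was stored earlier
# via the config store; mirrors the pattern in predict_worker.py below.
from dacapo.store.create_store import create_config_store
from dacapo.experiments import Run

config_store = create_config_store()                     # open the configured store
run_config = config_store.retrieve_run_config("my_run")  # datasplit + architecture + task + trainer
run = Run(run_config)                                    # materialize the job
```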


![DaCapo Diagram](https://raw.githubusercontent.com/janelia-cellmap/dacapo/main/docs/source/_static/dacapo_diagram.png)



## Installation and Setup
Currently, python>=3.10 is supported. We recommend creating a new conda environment for dacapo with python 3.10.
```
@@ -52,7 +57,12 @@ Tasks we support and approaches for those tasks:
- Semantic segmentation
- Signed distances
- One-hot encoding of different types of objects


## Example Tutorial
A minimal example tutorial can be found in the examples directory and opened in Colab here: <a target="_blank" href="https://colab.research.google.com/github/janelia-cellmap/dacapo/blob/main/examples/starter_tutorial/minimal_tutorial.ipynb">
<img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

## Helpful Resources & Tools
- Chunked data, zarr, and n5
- OME-Zarr: a cloud-optimized bioimaging file format with international community support (doi: [10.1101/2023.02.17.528834](https://pubmed.ncbi.nlm.nih.gov/36865282/))
1 change: 1 addition & 0 deletions dacapo/__init__.py
@@ -8,3 +8,4 @@
from .validate import validate, validate_run # noqa
from .predict import predict # noqa
from .blockwise import run_blockwise, segment_blockwise # noqa
from . import predict_local
1 change: 0 additions & 1 deletion dacapo/blockwise/__init__.py
@@ -1,3 +1,2 @@
from .blockwise_task import DaCapoBlockwiseTask
from .scheduler import run_blockwise, segment_blockwise
from . import global_vars
1 change: 0 additions & 1 deletion dacapo/blockwise/global_vars.py

This file was deleted.

27 changes: 4 additions & 23 deletions dacapo/blockwise/predict_worker.py
@@ -17,7 +17,6 @@

import numpy as np
import click
from dacapo.blockwise import global_vars

import logging

@@ -28,20 +27,6 @@
path = __file__


def is_global_run_set(run_name) -> bool:
if global_vars.current_run is not None:
if global_vars.current_run.name == run_name:
return True
else:
logger.error(
f"Found global run {global_vars.current_run.name} but looking for {run_name}"
)
return False
else:
logger.error("No global run is set.")
return False


@click.group()
@click.option(
"--log-level",
@@ -131,14 +116,10 @@ def io_loop():
compute_context = create_compute_context()
device = compute_context.device

if is_global_run_set(run_name):
logger.warning("Using global run variable")
run = global_vars.current_run
else:
logger.warning("initiating local run in predict_worker")
config_store = create_config_store()
run_config = config_store.retrieve_run_config(run_name)
run = Run(run_config)
logger.warning("initiating local run in predict_worker")
config_store = create_config_store()
run_config = config_store.retrieve_run_config(run_name)
run = Run(run_config)

if iteration is not None and compute_context.distribute_workers:
# create weights store
2 changes: 2 additions & 0 deletions dacapo/compute_context/local_torch.py
@@ -60,6 +60,8 @@ def device(self):
if free < self.oom_limit: # less than 1 GB free, decrease chance of OOM
return torch.device("cpu")
return torch.device("cuda")
elif torch.backends.mps.is_available():
return torch.device("mps")
else:
return torch.device("cpu")
return torch.device(self._device)
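
The new `mps` branch extends the fallback chain to Apple-silicon GPUs. A standalone sketch of the same selection order in plain `torch` (omitting the free-memory check on the CUDA path):

```python
import torch

def pick_device() -> torch.device:
    # Same preference order local_torch.py now uses:
    # CUDA first, then Apple-silicon MPS, then CPU.
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")
```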
dacapo/experiments/datasplits/datasets/arrays/concat_array.py
@@ -455,7 +455,7 @@ def __getitem__(self, roi: Roi) -> np.ndarray:
axis=0,
)
if concatenated.shape[0] == 1:
print(
logger.info(
f"Concatenated array has only one channel: {self.name} {concatenated.shape}"
)
return concatenated
5 changes: 3 additions & 2 deletions dacapo/experiments/datasplits/datasplit_generator.py
@@ -8,6 +8,7 @@
from zarr.n5 import N5FSStore
import numpy as np
from dacapo.experiments.datasplits.datasets.arrays import (
ArrayConfig,
ZarrArrayConfig,
ZarrArray,
ResampledArrayConfig,
@@ -916,8 +917,8 @@ def __generate_semantic_seg_dataset_crop(self, dataset: DatasetSpec):
current_targets = self.targets
targets_str = "_".join(self.targets)

target_images = {}
target_masks = {}
target_images = dict[str, ArrayConfig]()
target_masks = dict[str, ArrayConfig]()

missing_classes = [c for c in current_targets if c not in classes]
found_classes = [c for c in current_targets if c in classes]
54 changes: 37 additions & 17 deletions dacapo/experiments/tasks/post_processors/argmax_post_processor.py
@@ -1,6 +1,7 @@
from upath import UPath as Path
from dacapo.blockwise import run_blockwise
import dacapo.blockwise
import daisy
from daisy import Roi, Coordinate
from funlib.persistence import open_ds
from dacapo.utils.array_utils import to_ndarray, save_ndarray
from dacapo.experiments.datasplits.datasets.arrays.zarr_array import ZarrArray
from dacapo.store.array_store import LocalArrayIdentifier
from .argmax_post_processor_parameters import ArgmaxPostProcessorParameters
@@ -118,31 +119,50 @@ def process(
]
)

write_size = [
b * v
for b, v in zip(
block_size[-self.prediction_array.dims :],
self.prediction_array.voxel_size,
)
]

output_array = ZarrArray.create_from_array_identifier(
output_array_identifier,
[dim for dim in self.prediction_array.axes if dim != "c"],
self.prediction_array.roi,
None,
self.prediction_array.voxel_size,
np.uint8,
block_size * self.prediction_array.voxel_size,
)

read_roi = Roi((0, 0, 0), self.prediction_array.voxel_size * block_size)
# run blockwise post-processing
run_blockwise(
worker_file=str(
Path(Path(dacapo.blockwise.__file__).parent, "argmax_worker.py")
),
read_roi = Roi((0, 0, 0), write_size[-self.prediction_array.dims :])
input_array = open_ds(
self.prediction_array_identifier.container.path,
self.prediction_array_identifier.dataset,
)

def process_block(block):
# Apply argmax to each block of data
data = np.argmax(
to_ndarray(input_array, block.read_roi),
axis=self.prediction_array.axes.index("c"),
).astype(np.uint8)
save_ndarray(data, block.write_roi, output_array)

# Define the task for blockwise processing
task = daisy.Task(
f"argmax_{output_array.dataset}",
total_roi=self.prediction_array.roi,
read_roi=read_roi,
write_roi=read_roi,
num_workers=num_workers,
max_retries=2, # TODO: make this an option
timeout=None, # TODO: make this an option
######
input_array_identifier=self.prediction_array_identifier,
output_array_identifier=output_array_identifier,
process_function=process_block,
check_function=None,
read_write_conflict=False,
fit="overhang",
max_retries=0,
timeout=None,
)

return output_array
# Run the task blockwise
return daisy.run_blockwise([task], multiprocessing=False)
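
This refactor swaps the spawned `argmax_worker.py` subprocess for an in-process `daisy.Task`. A self-contained sketch of the pattern, using funlib's accessors directly rather than dacapo's `array_utils` helpers; the zarr paths, block size, and worker count are placeholders, not values from this commit:

```python
import daisy
import numpy as np
from daisy import Roi
from funlib.persistence import open_ds

# Hypothetical datasets: a channel-first probability volume and a
# pre-created uint8 label volume in the same container.
input_array = open_ds("preds.zarr", "probs")
output_array = open_ds("preds.zarr", "labels", mode="a")

read_roi = Roi((0, 0, 0), (64, 64, 64))  # one block, in world units

def process_block(block: daisy.Block):
    # Collapse the channel axis to one label per voxel, then write the block.
    data = input_array.to_ndarray(block.read_roi)
    labels = np.argmax(data, axis=0).astype(np.uint8)
    output_array[block.write_roi] = labels

task = daisy.Task(
    "argmax_sketch",
    total_roi=input_array.roi,
    read_roi=read_roi,
    write_roi=read_roi,
    process_function=process_block,
    check_function=None,
    read_write_conflict=False,
    fit="overhang",
    num_workers=4,
    max_retries=0,
)
daisy.run_blockwise([task], multiprocessing=False)
```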
dacapo/experiments/tasks/post_processors/watershed_post_processor.py
@@ -3,6 +3,10 @@
from dacapo.blockwise.scheduler import segment_blockwise
from dacapo.experiments.datasplits.datasets.arrays import ZarrArray
from dacapo.store.array_store import LocalArrayIdentifier
from dacapo.utils.array_utils import to_ndarray, save_ndarray
from funlib.persistence import open_ds
import daisy
import mwatershed as mws

from .watershed_post_processor_parameters import WatershedPostProcessorParameters
from .post_processor import PostProcessor
@@ -123,29 +127,15 @@ def process(
np.uint64,
block_size * self.prediction_array.voxel_size,
)
input_array = open_ds(
self.prediction_array_identifier.container.path,
self.prediction_array_identifier.dataset,
)

read_roi = Roi((0, 0, 0), self.prediction_array.voxel_size * block_size)
# run blockwise prediction
pars = {
"offsets": self.offsets,
"bias": parameters.bias,
"context": parameters.context,
}
segment_blockwise(
segment_function_file=str(
Path(Path(dacapo.blockwise.__file__).parent, "watershed_function.py")
),
context=parameters.context,
total_roi=self.prediction_array.roi,
read_roi=read_roi.grow(parameters.context, parameters.context),
write_roi=read_roi,
num_workers=num_workers,
max_retries=2, # TODO: make this an option
timeout=None, # TODO: make this an option
######
input_array_identifier=self.prediction_array_identifier,
output_array_identifier=output_array_identifier,
parameters=pars,
data = to_ndarray(input_array, output_array.roi).astype(float)
segmentation = mws.agglom(
data - parameters.bias, offsets=self.offsets, randomized_strides=True
)
save_ndarray(segmentation, self.prediction_array.roi, output_array)

return output_array
return output_array_identifier
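
The watershed post-processor now agglomerates the whole ROI in a single `mwatershed` call instead of dispatching `watershed_function.py` blockwise. A toy version of that call (the offsets and bias are illustrative values, not dacapo defaults):

```python
import numpy as np
import mwatershed as mws

# Toy affinities: one channel per offset, shape (n_offsets, z, y, x).
offsets = [(1, 0, 0), (0, 1, 0), (0, 0, 1)]  # nearest-neighbor offsets
affs = np.random.rand(len(offsets), 32, 32, 32)

bias = 0.5  # subtracting the bias turns affinities into signed merge scores
segmentation = mws.agglom(
    affs.astype(float) - bias,  # same bias subtraction as process() above
    offsets=offsets,
    randomized_strides=True,
)
```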
7 changes: 7 additions & 0 deletions dacapo/experiments/trainers/gp_augments/elastic_config.py
@@ -62,6 +62,12 @@ class ElasticAugmentConfig(AugmentConfig):
"3D rotations."
},
)
augmentation_probability: float = attr.ib(
default=1.,
metadata={
"help_text": "Probability of applying the augmentations."
},
)

def node(self, _raw_key=None, _gt_key=None, _mask_key=None):
"""
@@ -87,4 +93,5 @@ def node(self, _raw_key=None, _gt_key=None, _mask_key=None):
rotation_interval=self.rotation_interval,
subsample=self.subsample,
uniform_3d_rotation=self.uniform_3d_rotation,
augmentation_probability=self.augmentation_probability,
)
7 changes: 7 additions & 0 deletions dacapo/experiments/trainers/gp_augments/intensity_config.py
@@ -35,6 +35,12 @@ class IntensityAugmentConfig(AugmentConfig):
"help_text": "Set to False if modified values should not be clipped to [0, 1]"
},
)
augmentation_probability: float = attr.ib(
default=1.,
metadata={
"help_text": "Probability of applying the augmentation."
},
)

def node(self, raw_key: gp.ArrayKey, _gt_key=None, _mask_key=None):
"""
@@ -58,4 +64,5 @@ def node(self, raw_key: gp.ArrayKey, _gt_key=None, _mask_key=None):
shift_min=self.shift[0],
shift_max=self.shift[1],
clip=self.clip,
p=self.augmentation_probability,
)
9 changes: 8 additions & 1 deletion dacapo/experiments/trainers/gp_augments/simple_config.py
@@ -20,6 +20,13 @@ class SimpleAugmentConfig(AugmentConfig):
This class is a subclass of AugmentConfig.
"""

augmentation_probability: float = attr.ib(
default=1.,
metadata={
"help_text": "Probability of applying the augmentations."
},
)

def node(self, _raw_key=None, _gt_key=None, _mask_key=None):
"""
Get a gp.SimpleAugment node.
@@ -36,4 +43,4 @@ def node(self, _raw_key=None, _gt_key=None, _mask_key=None):
>>> node = simple_augment_config.node()
"""
return gp.SimpleAugment()
return gp.SimpleAugment(p=self.augmentation_probability)
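
All three augment configs gain the same knob: an `attr.ib` float defaulting to 1.0 that is forwarded as the node's `p` argument, so existing configs keep applying their augmentations unconditionally. A sketch of the shared pattern (the config class is hypothetical, and `p` support assumes a gunpowder version that accepts it):

```python
import attr
import gunpowder as gp

@attr.s
class MyAugmentConfig:  # hypothetical, for illustration only
    augmentation_probability: float = attr.ib(
        default=1.0,  # 1.0 preserves the old always-apply behavior
        metadata={"help_text": "Probability of applying the augmentation."},
    )

    def node(self):
        # Forward the probability as the node's `p`, as the configs above do.
        return gp.SimpleAugment(p=self.augmentation_probability)
```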
8 changes: 4 additions & 4 deletions dacapo/experiments/trainers/gunpowder_trainer.py
@@ -297,7 +297,7 @@ def iterate(self, num_iterations, model, optimizer, device):
"""
t_start_fetch = time.time()

print("Starting iteration!")
logger.debug("Starting iteration!")

for iteration in range(self.iteration, self.iteration + num_iterations):
raw, gt, target, weight, mask = self.next()
@@ -309,12 +309,12 @@
param.grad = None

t_start_prediction = time.time()
predicted = model.forward(torch.as_tensor(raw[raw.roi]).to(device).float())
predicted = model.forward(torch.as_tensor(raw[raw.roi]).float().to(device))
predicted.retain_grad()
loss = self._loss.compute(
predicted,
torch.as_tensor(target[target.roi]).to(device).float(),
torch.as_tensor(weight[weight.roi]).to(device).float(),
torch.as_tensor(target[target.roi]).float().to(device),
torch.as_tensor(weight[weight.roi]).float().to(device),
)
loss.backward()
optimizer.step()
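
Reordering `.to(device).float()` into `.float().to(device)` casts on the host first, so float64 numpy buffers are not copied to the GPU at double width and down-cast there. A tiny illustration (assumed shapes; the guard just lets the sketch run without a GPU):

```python
import numpy as np
import torch

raw = np.random.rand(1, 1, 64, 64)  # numpy defaults to float64

# Old order: ships 8-byte doubles to the device, then casts.
#   x = torch.as_tensor(raw).to("cuda").float()
# New order: casts to float32 on the host, halving the transfer.
x = torch.as_tensor(raw).float()
if torch.cuda.is_available():
    x = x.to("cuda")
```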
11 changes: 4 additions & 7 deletions dacapo/plot.py
@@ -426,21 +426,18 @@ def plot_runs(
criteria=run.validation_score_name
)
colors_val = itertools.cycle(plt.cm.tab20.colors)
for dataset, color_v in zip(run.validation_scores.datasets, colors_val):
for dataset in run.validation_scores.datasets:
dataset_data = validation_score_data.sel(datasets=dataset)
include_validation_figure = True
x = [score.iteration for score in run.validation_scores.scores]
cc = next(colors_val)
for i in range(dataset_data.data.shape[1]):
current_name = (
f"{i}_{dataset.name}_{name}_{run.validation_score_name}"
)
for i, cc in zip(range(dataset_data.data.shape[1]), colors_val):
current_name = f"{i}_{dataset.name}"
validation_ax.plot(
x,
dataset_data.data[:, i],
label=current_name,
color=cc,
alpha=0.5 + 0.2 * i,
# alpha=0.5 + 0.2 * i,
)

if include_loss_figure:
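
The plotting fix zips each series index directly against the color cycle instead of drawing one color per dataset and distinguishing series by alpha. The zip/cycle pairing behaves like this (toy values):

```python
import itertools

colors = itertools.cycle(["r", "g", "b"])  # stands in for plt.cm.tab20.colors
series = range(5)

# zip() advances the cycle once per series, wrapping around safely
# when there are more series than base colors.
for i, color in zip(series, colors):
    print(i, color)  # 0 r, 1 g, 2 b, 3 r, 4 g
```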