Skip to content

Commit

Permalink
Flake8 updates
Browse files Browse the repository at this point in the history
  • Loading branch information
Maxim Lippeveld committed May 22, 2024
1 parent 34c8102 commit 216b0c5
Showing 1 changed file with 19 additions and 16 deletions.
35 changes: 19 additions & 16 deletions src/scip/segmentation/cellpose.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,20 @@
import torch


def _get_gpu_device(worker):
if isinstance(get_client().cluster, LocalCluster):
gpu_id = '0'
else:
gpu_workers = [
address
for address, w in get_client().scheduler_info()["workers"].items()
if "cellpose" in w["resources"]
]
gpu_id = gpu_workers.index(worker.address)

return torch.device(f'cuda:{gpu_id}')


def segment_block(
events: List[Mapping[str, Any]],
*,
Expand Down Expand Up @@ -53,27 +67,16 @@ def segment_block(
if len(events) == 0:
return events

w = get_worker()
if hasattr(w, "cellpose"):
model = w.cellpose
worker = get_worker()
if hasattr(worker, "cellpose"):
model = worker.cellpose
else:
if gpu_accelerated:

if isinstance(get_client().cluster, LocalCluster):
gpu_id = '0'
else:
gpu_workers = [
address
for address, w in get_client().scheduler_info()["workers"].items()
if "cellpose" in w["resources"]
]
gpu_id = gpu_workers.index(w.address)

device = torch.device(f'cuda:{gpu_id}')
device = _get_gpu_device(worker)
model = models.Cellpose(gpu=True, device=device, model_type='cyto2')
else:
model = models.Cellpose(gpu=False, model_type='cyto2')
w.cellpose = model
worker.cellpose = model

parents, _, _, _ = model.eval(
x=[e["pixels"][[parent_channel_index, dapi_channel_index]] for e in events],
Expand Down

0 comments on commit 216b0c5

Please sign in to comment.