Skip to content

Commit

Permalink
Bugfix gpu selection in localcluster
Browse files Browse the repository at this point in the history
  • Loading branch information
Maxim Lippeveld committed May 22, 2024
1 parent d25a548 commit 34c8102
Showing 1 changed file with 12 additions and 10 deletions.
22 changes: 12 additions & 10 deletions src/scip/segmentation/cellpose.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,8 @@
from typing import Optional, List, Any, Mapping
import numpy
from cellpose import models
from dask.distributed import get_worker
from dask.distributed import get_worker, get_client, LocalCluster
import torch
from distributed import get_client


def segment_block(
Expand Down Expand Up @@ -59,14 +58,17 @@ def segment_block(
model = w.cellpose
else:
if gpu_accelerated:
# find all gpu enabled workers
gpu_workers = [
address
for address, w in get_client().scheduler_info()["workers"].items()
if "cellpose" in w["resources"]
]

gpu_id = gpu_workers.index(w.address)

if isinstance(get_client().cluster, LocalCluster):
gpu_id = '0'
else:
gpu_workers = [
address
for address, w in get_client().scheduler_info()["workers"].items()
if "cellpose" in w["resources"]
]
gpu_id = gpu_workers.index(w.address)

device = torch.device(f'cuda:{gpu_id}')
model = models.Cellpose(gpu=True, device=device, model_type='cyto2')
else:
Expand Down

0 comments on commit 34c8102

Please sign in to comment.