Skip to content

Commit

Permalink
Update brain heatmap plotting function
Browse files Browse the repository at this point in the history
  • Loading branch information
bjhardcastle committed Dec 4, 2024
1 parent 990faee commit 1d9e3a8
Show file tree
Hide file tree
Showing 3 changed files with 1,914 additions and 762 deletions.
32 changes: 16 additions & 16 deletions pdm.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

201 changes: 185 additions & 16 deletions src/dynamic_routing_analysis/ccf_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,15 +47,15 @@ def get_ccf_volume(left_hemisphere=True, right_hemisphere=False) -> npt.NDArray:
f"{path.suffix} files not supported, must be one of {supported}"
)
if path.protocol: # cloud path - download it
tempdir = tempfile.mkdtemp()
temp_path = upath.UPath(tempdir) / path.name
logger.info(f"Downloading CCF volume to temporary file {temp_path.as_posix()}")
temp_path.write_bytes(path.read_bytes())
path = temp_path
logger.info(f"Using CCF volume from {path.as_posix()}")

logger.info(f"Loading CCF volume from {path.as_posix()}")
volume, _ = nrrd.read(path, index_order="C") # ml, dv, ap
with tempfile.TemporaryDirectory() as tempdir:
temp_path = upath.UPath(tempdir) / path.name
logger.warning(f"Downloading CCF volume to temporary file {temp_path.as_posix()}")
temp_path.write_bytes(path.read_bytes())
path = temp_path
volume, _ = nrrd.read(path, index_order="C") # ml, dv, ap
else:
logger.info(f"Using CCF volume from {path.as_posix()}")
volume, _ = nrrd.read(path, index_order="C") # ml, dv, ap
ml_dim = AXIS_TO_DIM["ml"]
dims = [
(slice(0, volume.shape[ml_dim] // 2) if dim == ml_dim else slice(None))
Expand All @@ -77,6 +77,7 @@ def get_ccf_volume(left_hemisphere=True, right_hemisphere=False) -> npt.NDArray:
return volume


@functools.cache
def get_midline_ccf_ml() -> float:
return (
RESOLUTION_UM
Expand All @@ -88,15 +89,21 @@ def get_midline_ccf_ml() -> float:
)


def ccf_to_volume_index(coord: float) -> int:
return round(coord / RESOLUTION_UM)


@functools.cache
def get_ccf_structure_tree_df() -> pl.DataFrame:
t0 = time.time()
path = "https://raw.githubusercontent.com/cortex-lab/allenCCF/master/structure_tree_safe_2017.csv"
logging.info(f"Using CCF structure tree from {path}")
return (
pl.read_csv(path)
.with_columns(
df = pl.read_csv(path)
len_0 = len(df)
df = (
df.with_columns(
color_hex_int=pl.col("color_hex_triplet").str.to_integer(base=16),
color_hex_str=pl.lit("0x") + pl.col("color_hex_triplet"),
color_hex_str=pl.lit("#") + pl.col("color_hex_triplet"),
)
.with_columns(
r=pl.col("color_hex_triplet")
Expand All @@ -116,13 +123,52 @@ def get_ccf_structure_tree_df() -> pl.DataFrame:
color_rgb=pl.concat_list("r", "g", "b"),
)
.drop("r", "g", "b")
.with_columns(
parent_ids=pl.col("structure_id_path")
.str.split("/")
.cast(pl.List(int))
.list.drop_nulls()
.list.slice(offset=0, length=pl.col("depth")),
)
)
df = df.join(
other=(
df.explode(pl.col("parent_ids"))
.group_by(pl.col("parent_ids").alias("id"), maintain_order=True)
.agg(pl.col("id").alias("child_ids"))
),
on="id",
how="left",
).with_columns(
pl.col("child_ids").fill_null([]),
is_deepest=~pl.col("id").is_in(df["parent_structure_id"]),
)
assert not any(df.filter(pl.col("is_deepest"))["child_ids"].to_list())
# add list of deepest children for each area
df = df.join(
other=(
df.explode("child_ids")
.filter(pl.col("child_ids").is_in(df.filter(pl.col("is_deepest"))["id"]))
.group_by("id", maintain_order=True)
.agg(
pl.all().exclude("child_ids").first(),
pl.col("child_ids").alias("deepest_child_ids"),
)
),
on="id",
how="left",
).with_columns(
pl.col("deepest_child_ids").fill_null([]),
)
assert len(df) == len_0
logger.info(f"CCF structure tree loaded in {time.time() - t0:.2f}s")
return df


def get_ccf_structure_info(ccf_acronym_or_id: str | int) -> dict:
"""
>>> get_ccf_structure_info('MOs')
{'id': 993, 'atlas_id': 831, 'name': 'Secondary motor area', 'acronym': 'MOs', 'st_level': None, 'ontology_id': 1, 'hemisphere_id': 3, 'weight': 8690, 'parent_structure_id': 500, 'depth': 7, 'graph_id': 1, 'graph_order': 24, 'structure_id_path': '/997/8/567/688/695/315/500/993/', 'color_hex_triplet': '1F9D5A', 'neuro_name_structure_id': None, 'neuro_name_structure_id_path': None, 'failed': 'f', 'sphinx_id': 25, 'structure_name_facet': 1043755260, 'failed_facet': 734881840, 'safe_name': 'Secondary motor area', 'color_hex_int': 2071898, 'color_hex_str': '0x1F9D5A', 'color_rgb': [0.12156862745098039, 0.615686274509804, 0.3529411764705882]}
{'id': 993, 'atlas_id': 831, 'name': 'Secondary motor area', 'acronym': 'MOs', 'st_level': None, 'ontology_id': 1, 'hemisphere_id': 3, 'weight': 8690, 'parent_structure_id': 500, 'depth': 7, 'graph_id': 1, 'graph_order': 24, 'structure_id_path': '/997/8/567/688/695/315/500/993/', 'color_hex_triplet': '1F9D5A', 'neuro_name_structure_id': None, 'neuro_name_structure_id_path': None, 'failed': 'f', 'sphinx_id': 25, 'structure_name_facet': 1043755260, 'failed_facet': 734881840, 'safe_name': 'Secondary motor area', 'color_hex_int': 2071898, 'color_hex_str': '#1F9D5A', 'color_rgb': [0.12156862745098039, 0.615686274509804, 0.3529411764705882]}
"""
if not isinstance(ccf_acronym_or_id, int):
ccf_id: int = convert_ccf_acronyms_or_ids(ccf_acronym_or_id)
Expand All @@ -140,7 +186,53 @@ def get_ccf_structure_info(ccf_acronym_or_id: str | int) -> dict:
return results[0].limit(1).to_dicts()[0]


def get_all_parents(ccf_acronym_or_id: str | int) -> list[str]:
"""
>>> get_all_parents('MOs2/3')
['root', 'grey', 'CH', 'CTX', 'CTXpl', 'Isocortex', 'MO', 'MOs']
"""
info = get_ccf_structure_info(ccf_acronym_or_id)
parent_ids = [int(id_) for id_ in info["structure_id_path"].split("/")[1:-2]]
parent_acronyms = (
get_ccf_structure_tree_df()
.filter(
pl.col("id").is_in(parent_ids),
)["acronym"]
.to_list()
)
assert info["id"] not in parent_acronyms
return parent_acronyms


def get_all_children(ccf_acronym_or_id: str | int) -> list[str]:
"""
>>> get_all_children('MOs')
['MOs1', 'MOs2/3', 'MOs5', 'MOs6a', 'MOs6b']
"""
if not isinstance(ccf_acronym_or_id, int):
ccf_id: int = convert_ccf_acronyms_or_ids(ccf_acronym_or_id)
else:
ccf_id = ccf_acronym_or_id
children = (
get_ccf_structure_tree_df()
.filter(
pl.col("structure_id_path").str.contains(f"/{ccf_id}/"),
~pl.col("structure_id_path").str.ends_with(f"/{ccf_id}/"),
)["acronym"]
.to_list()
)
assert str(ccf_acronym_or_id) not in children
return children


def get_deepest_children(ccf_acronym_or_id: str | int) -> list[str]:
"""
>>> get_deepest_children('MOs')
['MOs1', 'MOs2/3', 'MOs5', 'MOs6a', 'MOs6b']
>>> get_deepest_children('MOs1')
['MOs1']
>>> assert 'VISpor' not in get_deepest_children('VIS')
"""
if not isinstance(ccf_acronym_or_id, int):
try:
ccf_id: int = convert_ccf_acronyms_or_ids(ccf_acronym_or_id)
Expand All @@ -160,6 +252,38 @@ def get_deepest_children(ccf_acronym_or_id: str | int) -> list[str]:
].to_list()


def group_child_labels_in_slice(
slice_array: npt.NDArray[np.uint32],
acronyms_or_ids: Iterable[str | int],
) -> npt.NDArray[np.uint32]:
"""
For a given slice and CCF areas (acronyms or IDs), return a new slice with the labels grouped,
so that all areas in the same group have the same label. For example, passing ["MOS"] would
change the label index of all child areas in the MOs tree to have the MOs value.
>>> mos_id = get_ccf_structure_info('MOs')['id']
>>> slice_array = get_ccf_volume()[:, :, 100]
>>> assert mos_id not in slice_array
>>> new_slice = group_child_labels_in_slice(slice_array, ['MOs'])
>>> assert mos_id in new_slice
"""
slice_array = slice_array.copy()
for ccf_acronym_or_id in acronyms_or_ids:
if not isinstance(ccf_acronym_or_id, int):
ccf_id: int = convert_ccf_acronyms_or_ids(ccf_acronym_or_id)
else:
ccf_id = ccf_acronym_or_id
children = get_ccf_immediate_children_ids(ccf_id)
if ccf_id in children:
children.remove(ccf_id)
logger.debug(
f"Grouping {children} under {convert_ccf_acronyms_or_ids(ccf_acronym_or_id)}"
)
for child in children:
slice_array[slice_array == child] = ccf_id
return slice_array


def get_ccf_immediate_children_ids(ccf_acronym_or_id: str | int) -> set[int]:
"""
>>> ids = get_ccf_immediate_children_ids('MOs')
Expand Down Expand Up @@ -476,13 +600,58 @@ def get_scatter_image(
return image


def project_first_nonzero_labels(
volume: npt.NDArray,
axis: int = AXIS_TO_DIM["dv"],
) -> npt.NDArray:
"""
Project the first non-zero label encountered from one side of the 3D volume.
Parameters:
volume (np.ndarray): 3D array containing non-zero area labels.
axis (int): Axis along which to project (0, 1, or 2).
Returns:
np.ndarray: 2D array with the projected labels.
"""
if volume.ndim != 3:
raise ValueError(f"Volume must be 3D: {volume.shape=}")
dims = tuple(range(volume.ndim))
if axis not in dims:
raise ValueError("Axis must be 0, 1, or 2.")
plane_dims = [d for d in dims if d != axis]
mask = volume > 0
idx_along_projection_axis = np.argmax(mask, axis=axis)
idx_in_plane_axes = [np.arange(volume.shape[d]) for d in plane_dims]
if axis == 0:
projection = volume[
idx_along_projection_axis,
idx_in_plane_axes[0][:, None],
idx_in_plane_axes[1],
]
elif axis == 1:
projection = volume[
idx_in_plane_axes[0][:, None],
idx_along_projection_axis,
idx_in_plane_axes[1],
]
elif axis == 2:
projection = volume[
idx_in_plane_axes[0][:, None],
idx_in_plane_axes[1],
idx_along_projection_axis,
]
projection = projection.astype(float)
projection[projection == 0] = np.nan
return projection


if __name__ == "__main__":
logging.basicConfig(
level=logging.DEBUG,
level=logging.WARNING,
format="%(asctime)s | %(name)s | %(levelname)s | %(funcName)s | %(message)s",
datefmt="%d-%b-%y %H:%M:%S",
)
logging.getLogger().setLevel(logging.DEBUG)

import doctest

Expand Down
Loading

0 comments on commit 1d9e3a8

Please sign in to comment.