Skip to content

Commit

Permalink
Merge pull request #2761 from bmorris3/data-assoc
Browse files Browse the repository at this point in the history
Implement associations between Data layers
  • Loading branch information
bmorris3 authored Mar 21, 2024
2 parents 6c09924 + 00785d7 commit 9bb02fc
Show file tree
Hide file tree
Showing 5 changed files with 147 additions and 15 deletions.
85 changes: 81 additions & 4 deletions jdaviz/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,9 @@ def __init__(self, configuration=None, *args, **kwargs):
self.hub.subscribe(self, SubsetUpdateMessage,
handler=lambda msg: self._clear_object_cache(msg.subset.label))

# Store for associations between Data entries:
self._data_associations = self._init_data_associations()

# Subscribe to messages that result in changes to the layers
self.hub.subscribe(self, AddDataMessage,
handler=self._on_layers_changed)
Expand Down Expand Up @@ -485,9 +488,14 @@ def _on_layers_changed(self, msg):
if hasattr(msg, 'data'):
layer_name = msg.data.label
is_wcs_only = msg.data.meta.get(_wcs_only_label, False)
is_not_child = self._get_assoc_data_parent(layer_name) is None
children_layers = self._get_assoc_data_children(layer_name)

elif hasattr(msg, 'subset'):
layer_name = msg.subset.label
is_wcs_only = False
is_not_child = True
children_layers = []
else:
raise NotImplementedError(f"cannot recognize new layer from {msg}")

Expand All @@ -502,13 +510,28 @@ def _on_layers_changed(self, msg):
self.state.layer_icons = {**self.state.layer_icons,
layer_name: orientation_icons.get(layer_name,
wcs_only_refdata_icon)}
else:
elif is_not_child:
self.state.layer_icons = {
**self.state.layer_icons,
layer_name: alpha_index(len([ln for ln, ic in self.state.layer_icons.items()
if not ic.startswith('mdi-')]))
if not ic.startswith('mdi-') and
self._get_assoc_data_parent(ln) is None]))
}

# all remaining layers at this point have a parent:
child_layer_icons = {}
for layer_name in self.state.layer_icons:
children_layers = self._get_assoc_data_children(layer_name)
if children_layers is not None:
parent_icon = self.state.layer_icons[layer_name]
for i, child_layer in enumerate(children_layers, start=1):
child_layer_icons[child_layer] = f'{parent_icon}{i}'

self.state.layer_icons = {
**self.state.layer_icons,
**child_layer_icons
}

def _change_reference_data(self, new_refdata_label, viewer_id=None):
"""
Change reference data to Data with ``data_label``.
Expand Down Expand Up @@ -1263,7 +1286,7 @@ def merge_overlapping_spectral_regions(self, subset_name, att):

return new_state

def add_data(self, data, data_label=None, notify_done=True):
def add_data(self, data, data_label=None, notify_done=True, parent=None):
"""
Add data to the Glue ``DataCollection``.
Expand All @@ -1278,10 +1301,12 @@ def add_data(self, data, data_label=None, notify_done=True):
The name associated with this data. If none is given, label is pulled
from the input data (if `~glue.core.data.Data`) or a generic name is
generated.
notify_done: bool
notify_done : bool
Flag controlling whether a snackbar message is set when the data is
added to the app. Set to False to avoid overwhelming the user if
lots of data is getting loaded at once.
parent : str, optional
Associate the added Data entry as the child of layer ``parent``.
"""

if not data_label and hasattr(data, "label"):
Expand All @@ -1292,6 +1317,23 @@ def add_data(self, data, data_label=None, notify_done=True):

self.data_collection[data_label] = data

# manage associated Data entries:
self._add_assoc_data_as_parent(data_label)
if parent is not None:
data_collection_labels = [data.label for data in self.data_collection]
if parent not in data_collection_labels:
raise ValueError(f'parent "{parent}" is not a valid data label in '
f'the data collection: {data_collection_labels}.')

# Does the parent Data have a parent? If so, raise error:
parent_of_parent = self._get_assoc_data_parent(parent)
if parent_of_parent is not None:
raise NotImplementedError('Data associations are currently supported '
'between root layers (without parents) and their '
f'children, but the proposed parent "{parent}" has '
f'parent "{parent_of_parent}".')
self._set_assoc_data_as_child(data_label, new_parent_label=parent)

# Send out a toast message
if notify_done:
snackbar_message = SnackbarMessage(
Expand Down Expand Up @@ -2010,6 +2052,17 @@ def set_data_visibility(self, viewer_reference, data_label, visible=True, replac
if layer.layer.data.label != data_label:
layer.visible = False

# if Data has children, update their visibilities to match Data:
assoc_children = self._get_assoc_data_children(data_label)
for layer in viewer.layers:
for data_label in assoc_children:
if layer.layer.data.label == data_label:
if visible and not layer.visible:
layer.visible = True
layer.update()
else:
layer.visible = visible

# update data menu - selected_data_items should be READ ONLY, not modified by the user/UI
selected_items = viewer_item['selected_data_items']
data_id = self._data_id_from_label(data_label)
Expand Down Expand Up @@ -2595,3 +2648,27 @@ def get_tray_item_from_name(self, name):
raise KeyError(f'{name} not found in app.state.tray_items')

return tray_item

def _init_data_associations(self):
# assume all Data are parents:
data_associations = {
data.label: {'parent': None, 'children': []}
for data in self.data_collection
}
return data_associations

def _add_assoc_data_as_parent(self, data_label):
self._data_associations[data_label] = {'parent': None, 'children': []}

def _set_assoc_data_as_child(self, data_label, new_parent_label):
# Data has a new parent:
self._data_associations[data_label]['parent'] = new_parent_label
# parent has a new child:
self._data_associations[new_parent_label]['children'].append(data_label)

def _get_assoc_data_children(self, data_label):
# intentionally not recursive for now, just one generation:
return self._data_associations.get(data_label, {}).get('children', [])

def _get_assoc_data_parent(self, data_label):
return self._data_associations.get(data_label, {}).get('parent')
25 changes: 17 additions & 8 deletions jdaviz/configs/imviz/plugins/parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@


@data_parser_registry("imviz-data-parser")
def parse_data(app, file_obj, ext=None, data_label=None):
def parse_data(app, file_obj, ext=None, data_label=None, parent=None):
"""Parse a data file into Imviz.
Parameters
Expand Down Expand Up @@ -74,27 +74,27 @@ def parse_data(app, file_obj, ext=None, data_label=None):
else: # Assume RGB
pf = rgb2gray(im)
pf = pf[::-1, :] # Flip it
_parse_image(app, pf, data_label, ext=ext)
_parse_image(app, pf, data_label, ext=ext, parent=parent)

elif file_obj_lower.endswith('.asdf'):
try:
if HAS_ROMAN_DATAMODELS:
with rdd.open(file_obj) as pf:
_parse_image(app, pf, data_label, ext=ext)
_parse_image(app, pf, data_label, ext=ext, parent=parent)
except TypeError:
# if roman_datamodels cannot parse the file, load it with asdf:
with asdf.open(file_obj) as af:
_parse_image(app, af, data_label, ext=ext)
_parse_image(app, af, data_label, ext=ext, parent=parent)

elif file_obj_lower.endswith('.reg'):
# This will load DS9 regions as Subset but only if there is already data.
app._jdaviz_helper.load_regions_from_file(file_obj)

else: # Assume FITS
with fits.open(file_obj) as pf:
_parse_image(app, pf, data_label, ext=ext)
_parse_image(app, pf, data_label, ext=ext, parent=parent)
else:
_parse_image(app, file_obj, data_label, ext=ext)
_parse_image(app, file_obj, data_label, ext=ext, parent=parent)


def get_image_data_iterator(app, file_obj, data_label, ext=None):
Expand Down Expand Up @@ -168,7 +168,7 @@ def get_image_data_iterator(app, file_obj, data_label, ext=None):
return data_iter


def _parse_image(app, file_obj, data_label, ext=None):
def _parse_image(app, file_obj, data_label, ext=None, parent=None):
if app is None:
raise ValueError("app is None, cannot proceed")
if data_label is None:
Expand All @@ -186,7 +186,16 @@ def _parse_image(app, file_obj, data_label, ext=None):
data.coords.bounding_box = None
if not data.meta.get(_wcs_only_label, False):
data_label = app.return_data_label(data_label, alt_name="image_data")
app.add_data(data, data_label)

# TODO: generalize/centralize this for use in other configs too
if parent is not None and ext == 'DQ':
# nans are used to mark "good" flags in the DQ colormap, so
# convert DQ array to float to support nans:
cid = data.get_component("DQ")
data_arr = np.float32(cid.data)
data_arr[data_arr == 0] = np.nan
data.update_components({cid: data_arr})
app.add_data(data, data_label, parent=parent)

# Do not link image data here. We do it at the end in Imviz.load_data()

Expand Down
18 changes: 15 additions & 3 deletions jdaviz/configs/imviz/plugins/viewers.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,16 @@ def blink_once(self, reversed=False):
# Simple blinking of images - this will make it so that only one
# layer is visible at a time and cycles through the layers.

# Exclude Subsets: They are global
# Exclude Subsets (they are global) and children via associated data

def is_parent(data):
return self.session.jdaviz_app._get_assoc_data_parent(data.label) is None

valid = [ilayer for ilayer, layer in enumerate(self.state.layers)
if layer_is_image_data(layer.layer)]
if layer_is_image_data(layer.layer) and is_parent(layer.layer)]
children = [ilayer for ilayer, layer in enumerate(self.state.layers)
if layer_is_image_data(layer.layer) and not is_parent(layer.layer)]

n_layers = len(valid)

if n_layers == 1:
Expand Down Expand Up @@ -116,7 +123,12 @@ def blink_once(self, reversed=False):
next_layer = valid[(valid.index(visible[-1]) + delta) % n_layers]
self.state.layers[next_layer].visible = True

for ilayer in (set(valid) - set([next_layer])):
# make invisible all parent layers other than the next layer:
layers_to_set_not_visible = set(valid) - set([next_layer])
# no child layers are visible by default:
layers_to_set_not_visible.update(set(children))

for ilayer in layers_to_set_not_visible:
self.state.layers[ilayer].visible = False

# We can display the active data label in Compass plugin.
Expand Down
12 changes: 12 additions & 0 deletions jdaviz/core/template_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -1372,6 +1372,12 @@ def __init__(self, plugin, items, selected, viewer,
self._update_layer_items()
self.update_wcs_only_filter(only_wcs_layers)

# ignore layers that are children in associations:
def is_parent(data):
return self.app._get_assoc_data_parent(data.label) is None

self.add_filter(is_parent)

def _get_viewer(self, viewer):
# newer will likely be the viewer name in most cases, but viewer id in the case
# of additional viewers in imviz.
Expand Down Expand Up @@ -3046,6 +3052,12 @@ def __init__(self, plugin, items, selected,
# initialize items from original viewers
self._on_data_changed()

# ignore layers that are children in associations:
def is_parent(data):
return self.app._get_assoc_data_parent(data.label) is None

self.add_filter(is_parent)

def _cubeviz_include_spatial_subsets(self):
"""
Call this method to prepend spatial subsets to the list of datasets (and listen for newly
Expand Down
22 changes: 22 additions & 0 deletions jdaviz/tests/test_app.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pytest

import numpy as np
from jdaviz import Application, Specviz
from jdaviz.configs.default.plugins.gaussian_smooth.gaussian_smooth import GaussianSmooth

Expand Down Expand Up @@ -170,3 +171,24 @@ def test_viewer_renaming_imviz(imviz_helper):
old_reference='non-existent',
new_reference='this-is-forbidden'
)


def test_data_associations(imviz_helper):
shape = (10, 10)

data_parent = np.ones(shape, dtype=float)
data_child = np.zeros(shape, dtype=int)

imviz_helper.load_data(data_parent, data_label='parent_data')
imviz_helper.load_data(data_child, data_label='child_data', parent='parent_data')

assert imviz_helper.app._get_assoc_data_children('parent_data') == ['child_data']
assert imviz_helper.app._get_assoc_data_parent('child_data') == 'parent_data'

with pytest.raises(NotImplementedError):
# we don't (yet) allow children of children:
imviz_helper.load_data(data_child, data_label='grandchild_data', parent='child_data')

with pytest.raises(ValueError):
# ensure the parent actually exists:
imviz_helper.load_data(data_child, data_label='child_data', parent='absent parent')

0 comments on commit 9bb02fc

Please sign in to comment.