From 4d0f673d02503ef79785ce5a4e56c897e8340472 Mon Sep 17 00:00:00 2001 From: Thomas Robitaille Date: Mon, 12 Aug 2024 14:11:17 +0100 Subject: [PATCH] Initial work on supporting units in the scatter viewer --- glue/viewers/scatter/state.py | 40 ++++- glue/viewers/scatter/tests/test_viewer.py | 175 +++++++++++++++++++++- glue/viewers/scatter/viewer.py | 30 +++- 3 files changed, 239 insertions(+), 6 deletions(-) diff --git a/glue/viewers/scatter/state.py b/glue/viewers/scatter/state.py index 9fb0095f4..ee341688b 100644 --- a/glue/viewers/scatter/state.py +++ b/glue/viewers/scatter/state.py @@ -2,7 +2,7 @@ import numpy as np -from glue.core import BaseData, Subset +from glue.core import BaseData, Subset, Data from glue.config import colormaps from glue.viewers.matplotlib.state import (MatplotlibDataViewerState, @@ -14,6 +14,7 @@ from glue.core.data_combo_helper import ComponentIDComboHelper, ComboHelper from glue.core.exceptions import IncompatibleAttribute from glue.viewers.common.stretch_state_mixin import StretchStateMixin +from glue.core.units import find_unit_choices, UnitConverter from matplotlib.projections import get_projection_names @@ -34,6 +35,9 @@ class ScatterViewerState(MatplotlibDataViewerState): x_limits_percentile = DDCProperty(100, docstring="Percentile to use when automatically determining x limits") y_limits_percentile = DDCProperty(100, docstring="Percentile to use when automatically determining y limits") + x_display_unit = DDSCProperty(docstring='The units to use to display the x-axis.') + y_display_unit = DDSCProperty(docstring='The units to use to display the y-axis') + def __init__(self, **kwargs): super(ScatterViewerState, self).__init__() @@ -43,11 +47,13 @@ def __init__(self, **kwargs): self.x_lim_helper = StateAttributeLimitsHelper(self, attribute='x_att', lower='x_min', upper='x_max', log='x_log', margin=0.04, + display_units='x_display_unit', limits_cache=self.limits_cache) self.y_lim_helper = StateAttributeLimitsHelper(self, attribute='y_att', lower='y_min', upper='y_max', log='y_log', margin=0.04, + display_units='y_display_unit', limits_cache=self.limits_cache) self.add_callback('layers', self._layers_changed) @@ -68,6 +74,9 @@ def __init__(self, **kwargs): self.add_callback('x_log', self._reset_x_limits) self.add_callback('y_log', self._reset_y_limits) + self.add_callback('x_att', self._update_x_display_unit_choices) + self.add_callback('y_att', self._update_y_display_unit_choices) + if self.using_polar: self.full_circle() @@ -197,6 +206,35 @@ def _layers_changed(self, *args): self._layers_data_cache = layers_data + def _update_x_display_unit_choices(self, *args): + + # NOTE: only Data and its subclasses support specifying units + if self.x_att is None or not isinstance(self.x_att.parent, Data): + ScatterViewerState.x_display_unit.set_choices(self, []) + return + + component = self.x_att.parent.get_component(self.x_att) + if component.units: + x_choices = find_unit_choices([(self.x_att.parent, self.x_att, component.units)]) + else: + x_choices = [''] + ScatterViewerState.x_display_unit.set_choices(self, x_choices) + self.x_display_unit = component.units + + def _update_y_display_unit_choices(self, *args): + + # NOTE: only Data and its subclasses support specifying units + if self.y_att is None or not isinstance(self.y_att.parent, Data): + ScatterViewerState.y_display_unit.set_choices(self, []) + return + + component = self.y_att.parent.get_component(self.y_att) + if component.units: + y_choices = find_unit_choices([(self.y_att.parent, self.y_att, component.units)]) + else: + y_choices = [''] + ScatterViewerState.y_display_unit.set_choices(self, y_choices) + self.y_display_unit = component.units def display_func_slow(x): if x == 'Linear': diff --git a/glue/viewers/scatter/tests/test_viewer.py b/glue/viewers/scatter/tests/test_viewer.py index 0e6e5df52..dc4ad52b4 100644 --- a/glue/viewers/scatter/tests/test_viewer.py +++ b/glue/viewers/scatter/tests/test_viewer.py @@ -1,5 +1,5 @@ import numpy as np -from numpy.testing import assert_allclose +from numpy.testing import assert_allclose, assert_equal import matplotlib.pyplot as plt @@ -8,8 +8,9 @@ from glue.viewers.scatter.viewer import SimpleScatterViewer from glue.core.application_base import Application from glue.core.data import Data -from glue.core.link_helpers import LinkSame +from glue.core.link_helpers import LinkSame, LinkSameWithUnits from glue.core.data_derived import IndexedData +from glue.core.roi import RectangularROI @visual_test @@ -131,3 +132,173 @@ def test_indexed_data(): assert viewer.state.x_att is data_2d.main_components[0] assert viewer.state.y_att is data_2d.main_components[1] + + +def test_unit_conversion(): + + d1 = Data(a=[1, 2, 3], b=[2, 3, 4]) + d1.get_component('a').units = 'm' + d1.get_component('b').units = 's' + + d2 = Data(c=[2000, 1000, 3000], d=[0.001, 0.002, 0.004]) + d2.get_component('c').units = 'mm' + d2.get_component('d').units = 'ks' + + # d3 is the same as d2 but we will link it differently + d3 = Data(e=[2000, 1000, 3000], f=[0.001, 0.002, 0.004]) + d3.get_component('e').units = 'mm' + d3.get_component('f').units = 'ks' + + d4 = Data(g=[2, 2, 3], h=[1, 2, 1]) + d4.get_component('g').units = 'kg' + d4.get_component('h').units = 'm/s' + + app = Application() + session = app.session + + data_collection = session.data_collection + data_collection.append(d1) + data_collection.append(d2) + data_collection.append(d3) + data_collection.append(d4) + + data_collection.add_link(LinkSameWithUnits(d1.id['a'], d2.id['c'])) + data_collection.add_link(LinkSameWithUnits(d1.id['b'], d2.id['d'])) + data_collection.add_link(LinkSame(d1.id['a'], d3.id['e'])) + data_collection.add_link(LinkSame(d1.id['b'], d3.id['f'])) + data_collection.add_link(LinkSame(d1.id['a'], d4.id['g'])) + data_collection.add_link(LinkSame(d1.id['b'], d4.id['h'])) + + viewer = app.new_data_viewer(SimpleScatterViewer) + viewer.add_data(d1) + viewer.add_data(d2) + viewer.add_data(d3) + viewer.add_data(d4) + + assert viewer.layers[0].enabled + assert viewer.layers[1].enabled + assert viewer.layers[2].enabled + assert viewer.layers[3].enabled + + assert viewer.state.x_min == 0.92 + assert viewer.state.x_max == 3.08 + assert viewer.state.y_min == 1.92 + assert viewer.state.y_max == 4.08 + + roi = RectangularROI(0.5, 2.5, 1.5, 4.5) + viewer.apply_roi(roi) + + assert len(d1.subsets) == 1 + assert_equal(d1.subsets[0].to_mask(), [1, 1, 0]) + + # Because of the LinkSameWithUnits, the points actually appear in the right + # place even before we set the display units. + assert len(d2.subsets) == 1 + assert_equal(d2.subsets[0].to_mask(), [0, 1, 0]) + + # d3 is only linked with LinkSame not LinkSameWithUnits so currently the + # points are outside the visible axes + assert len(d3.subsets) == 1 + assert_equal(d3.subsets[0].to_mask(), [0, 0, 0]) + + # As we haven't set display units yet, the values for this dataset are shown + # on the same scale as for d1 as if the units had never been set. + assert len(d4.subsets) == 1 + assert_equal(d4.subsets[0].to_mask(), [0, 1, 0]) + + # Now try setting the units explicitly + + viewer.state.x_display_unit = 'km' + viewer.state.y_display_unit = 'ms' + + assert_allclose(viewer.state.x_min, 0.92e-3) + assert_allclose(viewer.state.x_max, 3.08e-3) + assert_allclose(viewer.state.y_min, 1.92e3) + assert_allclose(viewer.state.y_max, 4.08e3) + + roi = RectangularROI(0.5e-3, 2.5e-3, 1.5e3, 4.5e3) + viewer.apply_roi(roi) + + # d1 and d2 will be as above, but d3 will now work correctly while d4 should + # not be shown. + + assert_equal(d1.subsets[1].to_mask(), [1, 1, 0]) + assert_equal(d2.subsets[1].to_mask(), [0, 1, 0]) + assert_equal(d3.subsets[1].to_mask(), [0, 0, 0]) + assert_equal(d4.subsets[1].to_mask(), [0, 1, 0]) + + + # # Change the limits to make sure they are always converted + # viewer.state.x_min = 5e8 + # viewer.state.x_max = 4e9 + # viewer.state.y_min = 0.5 + # viewer.state.y_max = 3.5 + + # roi = XRangeROI(1.4e9, 2.1e9) + # viewer.apply_roi(roi) + + # assert len(d1.subsets) == 1 + # assert_equal(d1.subsets[0].to_mask(), [0, 1, 0]) + + # assert len(d2.subsets) == 1 + # assert_equal(d2.subsets[0].to_mask(), [0, 1, 0]) + + # viewer.state.x_display_unit = 'GHz' + # viewer.state.y_display_unit = 'mJy' + + # x, y = viewer.state.layers[0].profile + # assert_allclose(x, [1, 2, 3]) + # assert_allclose(y, [1000, 2000, 3000]) + + # x, y = viewer.state.layers[1].profile + # assert_allclose(x, 2.99792458 / np.array([1, 2, 3])) + # assert_allclose(y, [2000, 1000, 3000]) + + # assert viewer.state.x_min == 0.5 + # assert viewer.state.x_max == 4. + + # # Units get reset because they were originally 'native' and 'native' to a + # # specific unit always trigger resetting the limits since different datasets + # # might be converted in different ways. + # assert viewer.state.y_min == 1000. + # assert viewer.state.y_max == 3000. + + # # Now set the limits explicitly again and make sure in future they are converted + # viewer.state.y_min = 500. + # viewer.state.y_max = 3500. + + # roi = XRangeROI(0.5, 1.2) + # viewer.apply_roi(roi) + + # assert len(d1.subsets) == 1 + # assert_equal(d1.subsets[0].to_mask(), [1, 0, 0]) + + # assert len(d2.subsets) == 1 + # assert_equal(d2.subsets[0].to_mask(), [0, 0, 1]) + + # viewer.state.x_display_unit = 'cm' + # viewer.state.y_display_unit = 'Jy' + + # roi = XRangeROI(15, 35) + # viewer.apply_roi(roi) + + # assert len(d1.subsets) == 1 + # assert_equal(d1.subsets[0].to_mask(), [1, 0, 0]) + + # assert len(d2.subsets) == 1 + # assert_equal(d2.subsets[0].to_mask(), [0, 1, 1]) + + # assert_allclose(viewer.state.x_min, (4 * u.GHz).to_value(u.cm, equivalencies=u.spectral())) + # assert_allclose(viewer.state.x_max, (0.5 * u.GHz).to_value(u.cm, equivalencies=u.spectral())) + # assert_allclose(viewer.state.y_min, 0.5) + # assert_allclose(viewer.state.y_max, 3.5) + + # # Regression test for a bug that caused unit changes to not work on y axis + # # if reference data was not first layer + + # viewer.state.reference_data = d2 + # viewer.state.y_display_unit = 'mJy' + + + # data_collection.add_link(LinkSame(d1.id['a'], d2.id['e'])) + # data_collection.add_link(LinkSame(d1.id['b'], d2.id['f'])) diff --git a/glue/viewers/scatter/viewer.py b/glue/viewers/scatter/viewer.py index c986fd706..00707fefd 100644 --- a/glue/viewers/scatter/viewer.py +++ b/glue/viewers/scatter/viewer.py @@ -8,6 +8,8 @@ from glue.viewers.matplotlib.viewer import SimpleMatplotlibViewer from glue.viewers.scatter.state import ScatterViewerState from glue.viewers.scatter.layer_artist import ScatterLayerArtist +from glue.core.units import UnitConverter + __all__ = ['MatplotlibScatterMixin', 'SimpleScatterViewer'] @@ -152,9 +154,31 @@ def apply_roi(self, roi, override_mode=None): x_date = 'datetime' in self.state.x_kinds y_date = 'datetime' in self.state.y_kinds - if x_date or y_date: - roi = roi.transformed(xfunc=mpl_to_datetime64 if x_date else None, - yfunc=mpl_to_datetime64 if y_date else None) + converter = UnitConverter() + + xfunc = None + if x_date: + xfunc = mpl_to_datetime64 + else: + if self.state.x_display_unit: + xfunc = lambda x: converter.to_native(self.state.x_att.parent, + self.state.x_att, x, + self.state.x_display_unit) + + yfunc = None + if y_date: + yfunc = mpl_to_datetime64 + else: + if self.state.y_display_unit: + yfunc = lambda y: converter.to_native(self.state.y_att.parent, + self.state.y_att, y, + self.state.y_display_unit) + + print(xfunc) + print(yfunc) + + if xfunc or yfunc: + roi = roi.transformed(xfunc=xfunc, yfunc=yfunc) use_transform = not self.using_rectilinear() subset_state = roi_to_subset_state(roi,