Skip to content

Commit

Permalink
working single-bit masking
Browse files Browse the repository at this point in the history
  • Loading branch information
bmorris3 committed Mar 26, 2024
1 parent 09ab7e2 commit 0582952
Show file tree
Hide file tree
Showing 4 changed files with 218 additions and 96 deletions.
88 changes: 64 additions & 24 deletions jdaviz/configs/default/plugins/data_quality/data_quality.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,18 +147,19 @@ def init_decoding(self, event={}):
]
dq_layer.composite._allow_bad_alpha = True

flag_bits = np.float32([flag['flag'] for flag in self.decoded_flags])
flag_bits = np.array([flag['flag'] for flag in self.decoded_flags])

with delay_callback(dq_layer.state, 'alpha', 'cmap', 'v_min', 'v_max', 'stretch'):
dq_layer.state.cmap = cmap
dq_layer.state.stretch = 'lookup'
stretch_object = dq_layer.state.stretch_object
stretch_object.flags = flag_bits

dq_layer.state.stretch = 'lookup'
stretch_object = dq_layer.state.stretch_object
stretch_object.flags = flag_bits
with delay_callback(dq_layer.state, 'alpha', 'cmap', 'v_min', 'v_max'):
if len(flag_bits):
dq_layer.state.v_min = min(flag_bits)
dq_layer.state.v_max = max(flag_bits)

dq_layer.state.v_min = min(flag_bits)
dq_layer.state.v_max = max(flag_bits)
dq_layer.state.alpha = 0.9
dq_layer.state.cmap = cmap

@observe('decoded_flags')
def update_cmap(self, event={}):
Expand All @@ -167,34 +168,73 @@ def update_cmap(self, event={}):
layer for layer in viewer.layers if
layer.layer.label == self.dq_layer_selected
]
flag_bits = np.float32([flag['flag'] for flag in self.decoded_flags])
flag_bits = np.array([flag['flag'] for flag in self.decoded_flags])
rgb_colors = [hex2color(flag['color']) for flag in self.decoded_flags]
hidden_flags = np.array([flag['flag'] for flag in self.decoded_flags if not flag['show']])

# update the colors of the listed colormap without
# reassigning the layer.state.cmap object
cmap = dq_layer.state.cmap
cmap.colors = rgb_colors
cmap._init()
with delay_callback(dq_layer.state, 'v_min', 'v_max', 'alpha', 'stretch', 'cmap'):
# set correct stretch and limits:
# dq_layer.state.stretch = 'lookup'
stretch_object = dq_layer.state.stretch_object
stretch_object.flags = flag_bits
stretch_object.dq_array = dq_layer.get_image_data()
stretch_object.hidden_flags = hidden_flags

# update the colors of the listed colormap without
# reassigning the layer.state.cmap object
cmap = dq_layer.state.cmap
cmap.colors = rgb_colors
cmap._init()

with delay_callback(dq_layer.state, 'v_min', 'v_max', 'alpha'):
# trigger updates to cmap in viewer:
dq_layer.update()

# set correct stretch and limits:
dq_layer.state.stretch = 'lookup'
dq_layer.state.v_min = min(flag_bits)
dq_layer.state.v_max = max(flag_bits)
if len(flag_bits):
dq_layer.state.v_min = min(flag_bits)
dq_layer.state.v_max = max(flag_bits)

dq_layer.state.alpha = 0.9

def update_visibility(self, index):
self.decoded_flags[index]['show'] = not self.decoded_flags[index]['show']
self.send_state('decoded_flags')
self.update_cmap()

def vue_update_visibility(self, index):
self.update_visibility(index)

def update_color(self, index, color):
self.decoded_flags[index]['color'] = color
self.send_state('decoded_flags')
self.update_cmap()

def vue_update_color(self, args):
index, color = args
self.update_color(index, color)

@observe('science_layer_selected')
def mission_or_instrument_from_meta(self, event):
if not hasattr(self, 'science_layer'):
return

layer = self.science_layer.selected_obj
if len(layer):
# this is defined for JWST and ROMAN, should be upper case:
telescope = layer[0].layer.meta.get('telescope', None)
if not len(layer):
return

# this is defined for JWST and ROMAN, should be upper case:
telescope = layer[0].layer.meta.get('telescope', None)

if telescope is not None:
self.flag_map_selected = telescope_names[telescope.lower()]

def vue_hide_all_flags(self, event):
for flag in self.decoded_flags:
flag['show'] = False
self.send_state('decoded_flags')
self.update_cmap()

if telescope is not None:
self.flag_map_selected = telescope_names[telescope.lower()]
def vue_show_all_flags(self, event):
for flag in self.decoded_flags:
flag['show'] = True
self.send_state('decoded_flags')
self.update_cmap()
77 changes: 63 additions & 14 deletions jdaviz/configs/default/plugins/data_quality/data_quality.vue
Original file line number Diff line number Diff line change
Expand Up @@ -49,17 +49,50 @@
</v-row>

<j-plugin-section-header>Quality Flags</j-plugin-section-header>
<v-row class="row-no-padding">
<v-col cols=6>
<j-tooltip tipid='plugin-line-lists-erase-all'>
<v-btn
tile
:elevation=0
x-small
dense
color="turquoise"
dark
style="padding-left: 8px; padding-right: 6px;"
@click="show_all_flags">
<v-icon left small dense style="margin-right: 2px">mdi-eye</v-icon>
Show All
</v-btn>
</j-tooltip>
</v-col>
<v-col cols=6 style="text-align: right">
<j-tooltip tipid='plugin-line-lists-plot-all'>
<v-btn
tile
:elevation=0
x-small
dense
color="turquoise"
dark
style="padding-left: 8px; padding-right: 6px;"
@click="hide_all_flags">
<v-icon left small dense style="margin-right: 2px">mdi-eye-off</v-icon>
Hide All
</v-btn>
</j-tooltip>
</v-col>
</v-row>

<v-row style="max-width: calc(100% - 80px)">
<v-col>
Color
</v-col>
<v-col>
<strong>Flag</strong>
</v-col>
<v-col>
(Decomposed)
<v-col cols=8>
<strong>Flag</strong> (Decomposed)
</v-col>
</v-row>

<v-row>
<v-expansion-panels accordion>
<v-expansion-panel v-for="(item, index) in decoded_flags" key=":item">
Expand All @@ -78,21 +111,30 @@
</template>
<div @click.stop="" style="text-align: end; background-color: white">
<v-color-picker v-model="decoded_flags[index].color"
@update:color="throttledSetColor($event.hexa)">
@update:color="throttledSetColor(index, $event.hexa)">
></v-color-picker>
</div>
</v-menu>
</j-tooltip>
</v-col>
<v-col>
<div> <strong>{{item.flag}}</strong> ({{Object.keys(item.decomposed).join(', ')}})</div>
<v-col cols=8>
<div><strong>{{item.flag}}</strong> ({{Object.keys(item.decomposed).join(', ')}})</div>
</v-col>
</v-row>
</v-expansion-panel-header>
<v-expansion-panel-content>
<v-col v-for="(item, key, index) in item.decomposed">
<span>{{item.name}} ({{key}}): {{item.description}}</span>
<v-row no-gutters style="..." align="center">
<v-col cols=2 align="left">
<v-btn :color="item.show ? 'accent' : 'default'" icon @click="toggle_visibility(index)">
<v-icon>{{item.show ? "mdi-eye" : "mdi-eye-off"}}</v-icon>
</v-btn>
</v-col>
<v-col cols=8 align="left" style="...">
<v-row v-for="(item, key, index) in item.decomposed">
<span><strong>{{item.name}}</strong> ({{key}}): {{item.description}}</span>
</v-row>
</v-col>
</v-row>
</v-expansion-panel-content>
<v-expansion-panel>
</v-expansion-panels>
Expand All @@ -103,12 +145,19 @@

<script>
module.exports = {
created() {
this.throttledSetColor = _.throttle(
(v) => { this.color = v },
100);
created() {
this.throttledSetColor = _.throttle(
(index, color) => {
this.update_color([index, color])
},
100);
},
methods: {
toggle_visibility(index) {
this.update_visibility(index)
}
}
}
</script>

<style>
Expand Down
98 changes: 91 additions & 7 deletions jdaviz/configs/default/plugins/data_quality/dq_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import numpy as np
from matplotlib.colors import ListedColormap, rgb2hex
from glue.config import stretches
from astropy.table import Table

# paths to CSV files with DQ flag mappings:
Expand All @@ -12,6 +13,85 @@
}


class LookupStretch:
"""
Stretch class specific to DQ arrays.
Attributes
----------
flags : array-like
DQ flags.
"""

def __init__(self, flags=None, hidden_flags=None):
# Default x, y values(0-1) range chosen for a typical initial spline shape.
# Can be modified if required.
print('initializing stretch')
if flags is None:
flags = np.linspace(0, 1, 5)
if hidden_flags is None:
hidden_flags = []

self.flags = np.asarray(flags)
self.hidden_flags = np.asarray(hidden_flags).astype(int)

@property
def flag_range(self):
return np.max(self.flags) - np.min(self.flags)

@property
def scaled_flags(self):
# renormalize the flags on range (0, 1):
return (self.flags - np.min(self.flags)) / self.flag_range

def dq_array_to_flag_index(self, values):
# Find the index of the closest entry in `scaled_flags`
# for each of `values` using array broadcasting:
return np.argmin(
np.abs(
np.nan_to_num(values, nan=-10).flatten()[None, :] -
self.scaled_flags[:, None]
), axis=0
).reshape(values.shape)

def __call__(self, values, out=None, clip=False):
# For our uses, we can ignore `out` and `clip`, but those would need
# to be implemented before contributing this class upstream.

# find closest index in `self.flags` for each value in `values`:
if hasattr(values, 'squeeze'):
values = values.squeeze()

# `values` will have already been passed through
# astropy.visualization.ManualInterval and normalized on (0, 1)
# before they arrive here. First, remove that interval and get
# back the integer values:
values_integer = np.round(values * self.flag_range + np.min(self.flags))

# normalize by the number of flags, onto interval (0, 1):
renormed = self.dq_array_to_flag_index(values) / len(self.flags)

if len(self.hidden_flags):
# mask that is True for `values` in the hidden flags list:
value_is_hidden = np.isin(
np.nan_to_num(values_integer, nan=-10),
self.hidden_flags
)
else:
value_is_hidden = False

# preserve NaNs in values, and make hidden flags NaNs:
return np.where(
np.isnan(values) | value_is_hidden,
np.nan,
renormed
)


if "lookup" not in stretches:
stretches.add("lookup", LookupStretch, display="DQ")


def load_flag_map(mission_or_instrument=None, path=None):
"""
Load a flag map from disk.
Expand Down Expand Up @@ -77,7 +157,7 @@ def write_flag_map(flag_mapping, csv_path, **kwargs):
table.write(csv_path, format='ascii.csv', **kwargs)


def generate_listed_colormap(n_flags=None, seed=3):
def generate_listed_colormap(n_flags=None, rgba_colors=None, seed=3):
"""
Generate a list of random "light" colors of length ``n_flags``.
Expand All @@ -86,6 +166,8 @@ def generate_listed_colormap(n_flags=None, seed=3):
n_flags : int
Number of colors in the listed colormap, should match the
number of unique DQ flags (before they're decomposed).
rgba_colors : list of tuples
List of RGBA tuples for each color in the colormap.
seed : int
Seed for the random number generator used to
draw random colors.
Expand All @@ -100,12 +182,13 @@ def generate_listed_colormap(n_flags=None, seed=3):
rng = np.random.default_rng(seed)
default_alpha = 1

# Generate random colors that are generally "light", i.e. with
# RGB values in the upper half of the interval (0, 1):
rgba_colors = [
tuple(np.insert(rng.uniform(size=2), rng.integers(0, 3), 1).tolist() + [default_alpha])
for _ in range(n_flags)
]
if rgba_colors is None:
# Generate random colors that are generally "light", i.e. with
# RGB values in the upper half of the interval (0, 1):
rgba_colors = [
tuple(np.insert(rng.uniform(size=2), rng.integers(0, 3), 1).tolist() + [default_alpha])
for _ in range(n_flags)
]

cmap = ListedColormap(rgba_colors)

Expand Down Expand Up @@ -162,6 +245,7 @@ def decode_flags(flag_map, unique_flags, rgba_colors):
'flag': int(bit),
'decomposed': {bit: flag_map[bit] for bit in decoded_bits},
'color': rgb2hex(color),
'show': True,
})

return decoded_flags
Loading

0 comments on commit 0582952

Please sign in to comment.