Skip to content

Commit

Permalink
Merge pull request #67 from peeplika/WID-218
Browse files Browse the repository at this point in the history
[WID-218] Add controls for channel coloring - TS widget with plotly
  • Loading branch information
liadomide authored Apr 17, 2024
2 parents 35f18a4 + b51df6b commit fcc188d
Showing 1 changed file with 39 additions and 10 deletions.
49 changes: 39 additions & 10 deletions tvbwidgets/ui/ts/plotly_ts_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

import numpy as np
import ipywidgets as widgets
import matplotlib.pyplot as plt
import matplotlib.colors as mlt
from IPython.core.display_functions import display
from plotly_resampler import register_plotly_resampler, FigureWidgetResampler
from tvbwidgets.ui.ts.base_ts_widget import TimeSeriesWidgetBase
Expand All @@ -20,12 +22,14 @@ def __init__(self, **kwargs):
self.fig = None
self.data = None
self.ch_names = []
self.ch_picked = []
self.raw = None
self.sample_freq = 0
self.start_time = 0
self.end_time = 0
self.std_step = 0
self.amplitude = 1
self.colormap = None

# plot & UI
self.checkboxes = dict()
Expand All @@ -36,8 +40,15 @@ def __init__(self, **kwargs):
self.plot_area.children += (self.output,)
self.scaling_title = widgets.Label(value='Increase/Decrease signal scaling (current scaling value to the right)')
self.scaling_slider = widgets.IntSlider(value=1, layout=widgets.Layout(width='30%'))

super().__init__([self.plot_area, widgets.VBox([self.scaling_title, self.scaling_slider],
self.colormaps = ['turbo', 'brg', 'gist_stern_r', 'nipy_spectral_r', 'coolwarm','plasma', 'magma', 'viridis', \
'cividis', 'twilight', 'twilight_shifted', 'CMRmap_r', 'Blues', \
'BuGn', 'BuPu', 'Greens', 'PuRd', 'RdPu', 'Spectral', 'YlGnBu', \
'YlOrBr', 'YlOrRd', 'cubehelix_r', 'gist_earth_r', 'terrain_r', \
'rainbow_r', 'pink_r', 'gist_ncar_r', 'uni-color(black)']
self.colormap_dropdown = widgets.Dropdown(options=self.colormaps, description='Colormap:', disabled=False)
self.colormap_dropdown.observe(self.update_colormap, names='value')

super().__init__([self.plot_area, widgets.VBox([self.colormap_dropdown, self.scaling_title, self.scaling_slider],
layout=widgets.Layout(margin='0px 0px 0px 80px')),
self.info_and_channels_area],
layout=self.DEFAULT_BORDER)
Expand All @@ -46,6 +57,7 @@ def __init__(self, **kwargs):
# =========================================== SETUP ================================================================
def _populate_from_data_wrapper(self, data_wrapper):
super()._populate_from_data_wrapper(data_wrapper=data_wrapper)
self.ch_picked = list(range(len(self.ch_names)))
del self.ch_order, self.ch_types # delete these as we don't use them in plotly
# populate channel selection area
self.channels_area = self._create_channel_selection_area(array_wrapper=data_wrapper)
Expand All @@ -62,9 +74,16 @@ def add_traces_to_plot(self, data, ch_names):
# traces will be added from bottom to top, so reverse the lists to put the first channel on top
data = data[::-1]
ch_names = ch_names[::-1]
if self.colormap == "uni-color(black)":
colormap = plt.get_cmap('gray')
colors = colormap(np.linspace(0, 0, len(ch_names)))
else:
colormap = plt.get_cmap(self.colormap)
colors = colormap(np.linspace(0.3, 1, len(ch_names)))
colors = [mlt.to_hex(color, keep_alpha=False) for color in colors]

self.fig.add_traces(
[dict(y=ts * self.amplitude + i * self.std_step, name=ch_name, customdata=ts, hovertemplate='%{customdata}')
[dict(y=ts * self.amplitude + i * self.std_step, name=ch_name, customdata=ts, hovertemplate='%{customdata}', line_color = colors[i])
for i, (ch_name, ts) in enumerate(zip(ch_names, data))]
)

Expand Down Expand Up @@ -137,6 +156,14 @@ def plot_ts_with_plotly(self, data=None, ch_names=None):
self.output.clear_output(wait=True)
display(self.fig)

def update_colormap(self,change):
self.colormap = change['new']
self.fig.data = []
data = self.raw[:, :][0]
data = data[self.ch_picked, :]
ch_names = [self.ch_names[i] for i in self.ch_picked]
self.add_traces_to_plot(data, ch_names)

# ================================================= SCALING ========================================================
def _setup_scaling_slider(self):
# set min and max scaling values
Expand All @@ -152,8 +179,10 @@ def update_scaling(self, val):
# delete old traces
self.fig.data = []
data = self.raw[:, :][0]
data = data[self.ch_picked, :]
ch_names = [self.ch_names[i] for i in self.ch_picked]
self.add_traces_to_plot(data, ch_names)

self.add_traces_to_plot(data, self.ch_names)

# =========================================== CHANNELS SELECTION ===================================================
def _create_channel_selection_area(self, array_wrapper, no_checkbox_columns=5):
Expand Down Expand Up @@ -193,26 +222,26 @@ def _create_channel_selection_area(self, array_wrapper, no_checkbox_columns=5):
def _update_ts(self, btn):
self.logger.debug('Updating TS')
ch_names = list(self.ch_names)

# save selected channels using their index in the ch_names list
picks = []
self.ch_picked = []
for cb in list(self.checkboxes.values()):
ch_index = ch_names.index(cb.description) # get the channel index
if cb.value:
picks.append(ch_index) # list with number representation of channels
self.ch_picked.append(ch_index) # list with number representation of channels

# if unselect all
# TODO: should we remove just the traces and leave the channel names and the ticks??
if not picks:
if not self.ch_picked:
self.fig.data = [] # remove traces
self.fig.layout.annotations = [] # remove channel names
self.fig.layout.yaxis.tickvals = [] # remove ticks between channel names and traces
return

# get data and names for selected channels; self.raw is updated before redrawing starts
data, _ = self.raw[:, :]
data = data[picks, :]
ch_names = [ch_names[i] for i in picks]
data = data[self.ch_picked, :]
ch_names = [ch_names[i] for i in self.ch_picked]

# redraw the entire plot
self.plot_ts_with_plotly(data, ch_names)
Expand Down

0 comments on commit fcc188d

Please sign in to comment.