diff --git a/tvbwidgets/ui/connectivity_ipy/connectivity_widget.py b/tvbwidgets/ui/connectivity_ipy/connectivity_widget.py index 31756dd..b7ee63e 100644 --- a/tvbwidgets/ui/connectivity_ipy/connectivity_widget.py +++ b/tvbwidgets/ui/connectivity_ipy/connectivity_widget.py @@ -17,7 +17,7 @@ from tvbwidgets.ui.connectivity_ipy.outputs_3d import PyVistaOutput from tvbwidgets.ui.connectivity_ipy.operations import ConnectivityOperations from tvbwidgets.ui.connectivity_ipy.config import ConnectivityConfig -from tvbwidgets.ui.connectivity_ipy.global_context import CONTEXT +from tvbwidgets.ui.connectivity_ipy.global_context import CONTEXT, ObservableAttrs DROPDOWN_KEY = 'dropdown' @@ -45,7 +45,7 @@ def __init__(self, **kwargs): self.__draw_connectivity() self.__show_plot() - CONTEXT.observe(lambda *args: self.__show_plot(), 'connectivity') + CONTEXT.observe(lambda *args: self.__show_plot(), ObservableAttrs.CONNECTIVITY) def add_datatype(self, datatype): # type: (HasTraits) -> None pass @@ -99,7 +99,7 @@ def on_change(change): def on_ctx_change(value): dropdown.value = value - CONTEXT.observe(on_ctx_change, 'matrix') + CONTEXT.observe(on_ctx_change, ObservableAttrs.MATRIX) self.widgets_map[DROPDOWN_KEY] = dropdown self.children = (dropdown, *self.children) @@ -112,7 +112,7 @@ def __init__(self, **kwargs): super(Connectivity3DViewer, self).__init__([self.output], *kwargs) self.init_view_connectivity() - CONTEXT.observe(lambda *args: self.init_view_connectivity(), 'connectivity') + CONTEXT.observe(lambda *args: self.init_view_connectivity(), ObservableAttrs.CONNECTIVITY) def init_view_connectivity(self): self.output.plotter.clear() @@ -226,6 +226,19 @@ def add_datatype(self, datatype): """ pass + def get_connectivity(self, gid=None): + # type: (str|None) -> Connectivity + """ + Get a connectivity with the gid provided from the context history. + if gid=None return the current connectivity set on CONTEXT + """ + if gid is None: + return CONTEXT.connectivity + conn = list(filter(lambda c: c.gid.hex == gid, CONTEXT.connectivities_history)) + if not len(conn): + return None + return conn[0] + def __init__(self, connectivity, **kwargs): style = self.DEFAULT_BORDER super().__init__(**kwargs, layout=style) diff --git a/tvbwidgets/ui/connectivity_ipy/global_context.py b/tvbwidgets/ui/connectivity_ipy/global_context.py index 74767e6..7d4a0ba 100644 --- a/tvbwidgets/ui/connectivity_ipy/global_context.py +++ b/tvbwidgets/ui/connectivity_ipy/global_context.py @@ -4,10 +4,19 @@ # # (c) 2022-2023, TVB Widgets Team # +import enum from typing import Callable from tvb.datatypes.connectivity import Connectivity +class ObservableAttrs(str, enum.Enum): + """ + Enum representing observable attributes of the GlobalContext singleton + """ + MATRIX = 'matrix' + CONNECTIVITY = 'connectivity' + + class SingletonMeta(type): _instances = {} @@ -35,7 +44,7 @@ def matrix(self, next_value): prev_value = self.__matrix self.__matrix = next_value if prev_value != next_value: - self.__notify_observers('matrix', next_value) + self.__notify_observers(ObservableAttrs.MATRIX, next_value) @property def connectivity(self): @@ -47,10 +56,12 @@ def connectivity(self, next_value): # type: (Connectivity) -> None previous = self.__connectivity self.__connectivity = next_value - if previous != next_value: - self.__notify_observers('connectivity', next_value) + if not len(self.connectivities_history): + self.connectivities_history = [self.__connectivity] + if previous and previous.gid.hex != next_value.gid.hex: if not any([conn.gid == next_value.gid for conn in self.connectivities_history]): self.connectivities_history.append(next_value) + self.__notify_observers(ObservableAttrs.CONNECTIVITY, next_value) def __notify_observers(self, observed_attribute, next_value): # type: (str, any) -> None diff --git a/tvbwidgets/ui/connectivity_ipy/operations.py b/tvbwidgets/ui/connectivity_ipy/operations.py index b6180d2..988242a 100644 --- a/tvbwidgets/ui/connectivity_ipy/operations.py +++ b/tvbwidgets/ui/connectivity_ipy/operations.py @@ -39,8 +39,15 @@ def __init__(self, **kwargs): def update_history_ui(*args): children[1] = self.get_history_dropdown() self.children = children + self.send_state(self.keys) # trigger ui update + + def update_selector_ui(*args): + children[2] = self.__get_node_selector() + self.children = children + self.send_state(self.keys) # trigger ui update CONTEXT.observe(update_history_ui, 'connectivity') + CONTEXT.observe(update_selector_ui, 'connectivity') self.children = children @property @@ -73,24 +80,8 @@ def __get_node_selector(self): )), layout={'width': '50%', 'align-items': 'start'} ) - matrix_dropdown = ipywidgets.Dropdown( - options=CONTEXT.MATRIX_OPTIONS, - value=CONTEXT.matrix, - description='Matrix:' - ) - - def on_change(value): - CONTEXT.matrix = value['new'] - - matrix_dropdown.observe(on_change, 'value') - - def on_ctx_change(value): - matrix_dropdown.value = value - CONTEXT.observe(on_ctx_change, 'matrix') - - container = ipywidgets.VBox(children=(matrix_dropdown, - ipywidgets.HBox(children=(left, right)))) + container = ipywidgets.HBox(children=(left, right)) accordion = ipywidgets.Accordion(children=[container], selected_index=None, layout={'max-height': '50vh'}) accordion.set_title(0, 'Regions selector') @@ -133,22 +124,15 @@ def __cut_selected_nodes(self): """ Create a new connectivity using only the selected nodes """ - matrix = CONTEXT.matrix - new_weights = CONTEXT.connectivity.weights - new_tracts = CONTEXT.connectivity.tract_lengths if matrix == 'tracts' else None regions = CONTEXT.connectivity.region_labels selected_regions = [numpy.where(regions == label)[0][0] for label in self.selected_regions] - new_conn = self._cut_connectivity(CONTEXT.connectivity, new_weights, selected_regions, new_tracts) + new_conn = self._cut_connectivity(selected_regions) CONTEXT.connectivity = new_conn def __cut_edges(self, selected=False): - matrix = CONTEXT.matrix - new_weights = CONTEXT.connectivity.weights - new_tracts = CONTEXT.connectivity.tract_lengths if matrix == 'tracts' else None regions = CONTEXT.connectivity.region_labels selected_regions = [numpy.where(regions == label)[0][0] for label in self.selected_regions] - new_conn = self._cut_connectivity_edges(CONTEXT.connectivity, new_weights, selected_regions, new_tracts, - selected) + new_conn = self._cut_connectivity_edges(selected_regions, selected) CONTEXT.connectivity = new_conn def _reorder_arrays(self, original_conn, new_weights, interest_areas, new_tracts=None): @@ -168,33 +152,38 @@ def _reorder_arrays(self, original_conn, new_weights, interest_areas, new_tracts def get_history_dropdown(self): values = [(conn.gid.hex, conn) for conn in CONTEXT.connectivities_history] + default = len(values) and values[-1][1] or None dropdown = ipywidgets.Dropdown(options=values, description='View history', - disabled=False) - dropdown.observe(lambda x: print(x['new']), 'value') + disabled=False, + value=default) + + def on_connectivity_change(change): + CONTEXT.connectivity = change['new'] + + dropdown.observe(on_connectivity_change, 'value') return dropdown - def _cut_connectivity_edges(self, original_conn, new_weights, interest_areas, - new_tracts=None, selected=False): - # type: (Connectivity, numpy.array, numpy.array, numpy.array) -> Connectivity + def _cut_connectivity_edges(self, interest_areas, selected=False): + # type: (numpy.array, bool) -> Connectivity """ Generate new Connectivity based on a previous one, by changing weights (e.g. simulate lesion). The returned connectivity has the same number of nodes. The edges of unselected nodes will have weight 0. - :param original_conn: Original Connectivity, to copy from - :param new_weights: weights matrix for the new connectivity :param interest_areas: ndarray of the selected node id's - :param new_tracts: tracts matrix for the new connectivity :param selected: if true cuts out edges of selected areas else unselected edges """ + + original_conn = CONTEXT.connectivity + new_weights = CONTEXT.connectivity.weights + new_tracts = CONTEXT.connectivity.tract_lengths + if not len(interest_areas): - LOGGER.error('No intrest areas selected!') + LOGGER.error('No interest areas selected!') return CONTEXT.connectivity new_weights, interest_areas, new_tracts = self._reorder_arrays(original_conn, new_weights, interest_areas, new_tracts) - if new_tracts is None: - new_tracts = original_conn.tract_lengths for i in range(len(original_conn.weights)): for j in range(len(original_conn.weights)): @@ -218,28 +207,27 @@ def _cut_connectivity_edges(self, original_conn, new_weights, interest_areas, final_conn.configure() return final_conn - def _cut_connectivity(self, original_conn, new_weights, interest_areas, new_tracts=None): - # type: (Connectivity, numpy.array, numpy.array, numpy.array) -> Connectivity + def _cut_connectivity(self, interest_areas, selected=False): + # type: (numpy.array, bool) -> Connectivity """ Generate new Connectivity object based on current one, by removing nodes (e.g. simulate lesion). Only the selected nodes will get used in the result. The order of the indices in interest_areas matters. If indices are not sorted then the nodes will be permuted accordingly. - - :param original_conn: Original Connectivity(HasTraits), to cut nodes from - :param new_weights: weights matrix for the new connectivity :param interest_areas: ndarray with the selected node id's. - :param new_tracts: tracts matrix for the new connectivity """ + + original_conn = CONTEXT.connectivity + new_weights = CONTEXT.connectivity.weights + new_tracts = CONTEXT.connectivity.tract_lengths + if not len(interest_areas): LOGGER.error('No interest areas selected!') return CONTEXT.connectivity new_weights, interest_areas, new_tracts = self._reorder_arrays(original_conn, new_weights, interest_areas, new_tracts) - if new_tracts is None: - new_tracts = original_conn.tract_lengths[interest_areas, :][:, interest_areas] - else: - new_tracts = new_tracts[interest_areas, :][:, interest_areas] + + new_tracts = new_tracts[interest_areas, :][:, interest_areas] new_weights = new_weights[interest_areas, :][:, interest_areas] final_conn = Connectivity()