diff --git a/tvbwidgets/ui/connectivity_ipy/connectivity_widget.py b/tvbwidgets/ui/connectivity_ipy/connectivity_widget.py index cf75f18..31756dd 100644 --- a/tvbwidgets/ui/connectivity_ipy/connectivity_widget.py +++ b/tvbwidgets/ui/connectivity_ipy/connectivity_widget.py @@ -45,6 +45,7 @@ def __init__(self, **kwargs): self.__draw_connectivity() self.__show_plot() + CONTEXT.observe(lambda *args: self.__show_plot(), 'connectivity') def add_datatype(self, datatype): # type: (HasTraits) -> None pass @@ -104,7 +105,6 @@ def on_ctx_change(value): class Connectivity3DViewer(ipywidgets.VBox): - PYVISTA = 'PyVista' def __init__(self, **kwargs): self.output = PyVistaOutput() @@ -112,8 +112,10 @@ def __init__(self, **kwargs): super(Connectivity3DViewer, self).__init__([self.output], *kwargs) self.init_view_connectivity() + CONTEXT.observe(lambda *args: self.init_view_connectivity(), 'connectivity') def init_view_connectivity(self): + self.output.plotter.clear() points, edges, labels = self.add_actors() points_toggle, edges_toggle, labels_toggle = self._init_controls() if not labels_toggle.value: diff --git a/tvbwidgets/ui/connectivity_ipy/global_context.py b/tvbwidgets/ui/connectivity_ipy/global_context.py index 42c08a2..74767e6 100644 --- a/tvbwidgets/ui/connectivity_ipy/global_context.py +++ b/tvbwidgets/ui/connectivity_ipy/global_context.py @@ -19,34 +19,34 @@ def __call__(cls, *args, **kwargs): class GlobalContext(metaclass=SingletonMeta): MATRIX_OPTIONS = [('Tracts', 'tracts'), ('Weights', 'weights')] - _observed_state = dict() def __init__(self): - self._matrix = 'weights' - self._connectivity = None + self.__matrix = 'weights' + self.__connectivity = None + self.__observed_attributes = dict() self.connectivities_history = [] # list of connectivities previously used @property def matrix(self): - return self._matrix + return self.__matrix @matrix.setter def matrix(self, next_value): - prev_value = self._matrix - self._matrix = next_value + prev_value = self.__matrix + self.__matrix = next_value if prev_value != next_value: self.__notify_observers('matrix', next_value) @property def connectivity(self): # type: () -> Connectivity - return self._connectivity + return self.__connectivity @connectivity.setter def connectivity(self, next_value): # type: (Connectivity) -> None - previous = self._connectivity - self._connectivity = next_value + previous = self.__connectivity + self.__connectivity = next_value if previous != next_value: self.__notify_observers('connectivity', next_value) if not any([conn.gid == next_value.gid for conn in self.connectivities_history]): @@ -59,13 +59,13 @@ def __notify_observers(self, observed_attribute, next_value): passing as argument the next value for the attribute """ try: - observers = self._observed_state[observed_attribute] + observers = self.__observed_attributes[observed_attribute] for obs in observers: obs(next_value) except KeyError: pass - def observe(self, observer_func, value_observed): + def observe(self, observer_func, attribute_observed): # type: (Callable[[any], any], str) -> None """ Method to register an observer for the specified value. @@ -74,10 +74,10 @@ def observe(self, observer_func, value_observed): with the new value passed as param. """ try: - observers_list = self._observed_state[value_observed] + observers_list = self.__observed_attributes[attribute_observed] observers_list.append(observer_func) except KeyError: - self._observed_state[value_observed] = [observer_func] + self.__observed_attributes[attribute_observed] = [observer_func] def remove_observer(self, observer_func, value_observed): # type: (Callable[[any], any], str) -> None @@ -85,7 +85,7 @@ def remove_observer(self, observer_func, value_observed): Unregister a registered observer. """ try: - observers_list = self._observed_state[value_observed] + observers_list = self.__observed_attributes[value_observed] observers_list.remove(observer_func) except KeyError: pass diff --git a/tvbwidgets/ui/connectivity_ipy/operations.py b/tvbwidgets/ui/connectivity_ipy/operations.py index f9f441c..acd4f58 100644 --- a/tvbwidgets/ui/connectivity_ipy/operations.py +++ b/tvbwidgets/ui/connectivity_ipy/operations.py @@ -81,7 +81,6 @@ def on_ctx_change(value): @property def selected_regions(self): - print(self.regions_checkboxes) return list(map(lambda x: x.description, filter(lambda x: x.value, self.regions_checkboxes))) def __get_operations_buttons(self): @@ -93,22 +92,29 @@ def __get_operations_buttons(self): Create a new connectivity cutting the edges of selected nodes. Check the selected nodes in the above dropdown to see what it is included """ - cut_selected = ipywidgets.Button(description='Cut selected regions', - disabled=False, - button_style='success', - tooltip=cut_selected_tooltip, - icon='scissors') - - cut_edges_of_selected = ipywidgets.Button(description='Cut edges of selected', - disabled=False, - button_style='warning', - tooltip=cut_edges_tooltip, - icon='scissors') - - cut_selected.on_click(lambda *args: self.__cut_selected_nodes()) - cut_edges_of_selected.on_click(lambda *args: self.__cut_selected_edges()) - - return ipywidgets.HBox(children=[cut_selected, cut_edges_of_selected]) + cut_unselected_nodes = ipywidgets.Button(description='Cut selected regions', + disabled=False, + button_style='success', + tooltip=cut_selected_tooltip, + icon='scissors') + + cut_unselected_edges = ipywidgets.Button(description='Cut unselected edges', + disabled=False, + button_style='warning', + tooltip=cut_edges_tooltip, + icon='scissors') + + cut_selected_edges = ipywidgets.Button(description='Cut selected edges', + disabled=False, + button_style='warning', + tooltip=cut_edges_tooltip, + icon='scissors') + + cut_unselected_nodes.on_click(lambda *args: self.__cut_selected_nodes()) + cut_unselected_edges.on_click(lambda *args: self.__cut_edges()) + cut_selected_edges.on_click(lambda *args: self.__cut_edges(selected=True)) + + return ipywidgets.HBox(children=[cut_unselected_nodes, cut_unselected_edges, cut_selected_edges]) def __cut_selected_nodes(self): """ @@ -120,16 +126,18 @@ def __cut_selected_nodes(self): regions = CONTEXT.connectivity.region_labels selected_regions = [numpy.where(regions == label)[0][0] for label in self.selected_regions] new_conn = self._cut_connectivity(CONTEXT.connectivity, new_weights, selected_regions, new_tracts) + print('centres of new conn: ', new_conn.centres) CONTEXT.connectivity = new_conn - def __cut_selected_edges(self): + def __cut_edges(self, selected=False): print('cutting edges: ', self.selected_regions) matrix = CONTEXT.matrix new_weights = CONTEXT.connectivity.weights new_tracts = CONTEXT.connectivity.tract_lengths if matrix == 'tracts' else None regions = CONTEXT.connectivity.region_labels selected_regions = [numpy.where(regions == label)[0][0] for label in self.selected_regions] - new_conn = self._branch_connectivity(CONTEXT.connectivity, new_weights, selected_regions, new_tracts) + new_conn = self._cut_connectivity_edges(CONTEXT.connectivity, new_weights, selected_regions, new_tracts, + selected) CONTEXT.connectivity = new_conn def _reorder_arrays(self, original_conn, new_weights, interest_areas, new_tracts=None): @@ -147,8 +155,8 @@ def _reorder_arrays(self, original_conn, new_weights, interest_areas, new_tracts return new_weights, interest_areas, new_tracts - def _branch_connectivity(self, original_conn, new_weights, interest_areas, - new_tracts=None): + def _cut_connectivity_edges(self, original_conn, new_weights, interest_areas, + new_tracts=None, selected=False): # type: (Connectivity, numpy.array, numpy.array, numpy.array) -> Connectivity """ Generate new Connectivity based on a previous one, by changing weights (e.g. simulate lesion). @@ -157,8 +165,8 @@ def _branch_connectivity(self, original_conn, new_weights, interest_areas, :param new_weights: weights matrix for the new connectivity :param interest_areas: ndarray of the selected node id's :param new_tracts: tracts matrix for the new connectivity + :param selected: if true cuts out edges of selected areas else unselected edges """ - print('herre') new_weights, interest_areas, new_tracts = self._reorder_arrays(original_conn, new_weights, interest_areas, new_tracts) if new_tracts is None: @@ -167,6 +175,9 @@ def _branch_connectivity(self, original_conn, new_weights, interest_areas, for i in range(len(original_conn.weights)): for j in range(len(original_conn.weights)): if i not in interest_areas or j not in interest_areas: + if not selected: + new_weights[i][j] = 0 + elif i in interest_areas or j in interest_areas and selected: new_weights[i][j] = 0 final_conn = Connectivity()