Skip to content

Commit

Permalink
WID-223: add cut selected/unselected edges, synchronize viewers with …
Browse files Browse the repository at this point in the history
…result
  • Loading branch information
davidbacter01 committed Sep 14, 2023
1 parent c048efa commit f9e5f7b
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 37 deletions.
4 changes: 3 additions & 1 deletion tvbwidgets/ui/connectivity_ipy/connectivity_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def __init__(self, **kwargs):

self.__draw_connectivity()
self.__show_plot()
CONTEXT.observe(lambda *args: self.__show_plot(), 'connectivity')

def add_datatype(self, datatype): # type: (HasTraits) -> None
pass
Expand Down Expand Up @@ -104,16 +105,17 @@ def on_ctx_change(value):


class Connectivity3DViewer(ipywidgets.VBox):
PYVISTA = 'PyVista'

def __init__(self, **kwargs):
self.output = PyVistaOutput()

super(Connectivity3DViewer, self).__init__([self.output], *kwargs)

self.init_view_connectivity()
CONTEXT.observe(lambda *args: self.init_view_connectivity(), 'connectivity')

def init_view_connectivity(self):
self.output.plotter.clear()
points, edges, labels = self.add_actors()
points_toggle, edges_toggle, labels_toggle = self._init_controls()
if not labels_toggle.value:
Expand Down
28 changes: 14 additions & 14 deletions tvbwidgets/ui/connectivity_ipy/global_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,34 +19,34 @@ def __call__(cls, *args, **kwargs):

class GlobalContext(metaclass=SingletonMeta):
MATRIX_OPTIONS = [('Tracts', 'tracts'), ('Weights', 'weights')]
_observed_state = dict()

def __init__(self):
self._matrix = 'weights'
self._connectivity = None
self.__matrix = 'weights'
self.__connectivity = None
self.__observed_attributes = dict()
self.connectivities_history = [] # list of connectivities previously used

@property
def matrix(self):
return self._matrix
return self.__matrix

@matrix.setter
def matrix(self, next_value):
prev_value = self._matrix
self._matrix = next_value
prev_value = self.__matrix
self.__matrix = next_value
if prev_value != next_value:
self.__notify_observers('matrix', next_value)

@property
def connectivity(self):
# type: () -> Connectivity
return self._connectivity
return self.__connectivity

@connectivity.setter
def connectivity(self, next_value):
# type: (Connectivity) -> None
previous = self._connectivity
self._connectivity = next_value
previous = self.__connectivity
self.__connectivity = next_value
if previous != next_value:
self.__notify_observers('connectivity', next_value)
if not any([conn.gid == next_value.gid for conn in self.connectivities_history]):
Expand All @@ -59,13 +59,13 @@ def __notify_observers(self, observed_attribute, next_value):
passing as argument the next value for the attribute
"""
try:
observers = self._observed_state[observed_attribute]
observers = self.__observed_attributes[observed_attribute]
for obs in observers:
obs(next_value)
except KeyError:
pass

def observe(self, observer_func, value_observed):
def observe(self, observer_func, attribute_observed):
# type: (Callable[[any], any], str) -> None
"""
Method to register an observer for the specified value.
Expand All @@ -74,18 +74,18 @@ def observe(self, observer_func, value_observed):
with the new value passed as param.
"""
try:
observers_list = self._observed_state[value_observed]
observers_list = self.__observed_attributes[attribute_observed]
observers_list.append(observer_func)
except KeyError:
self._observed_state[value_observed] = [observer_func]
self.__observed_attributes[attribute_observed] = [observer_func]

def remove_observer(self, observer_func, value_observed):
# type: (Callable[[any], any], str) -> None
"""
Unregister a registered observer.
"""
try:
observers_list = self._observed_state[value_observed]
observers_list = self.__observed_attributes[value_observed]
observers_list.remove(observer_func)
except KeyError:
pass
Expand Down
55 changes: 33 additions & 22 deletions tvbwidgets/ui/connectivity_ipy/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,6 @@ def on_ctx_change(value):

@property
def selected_regions(self):
print(self.regions_checkboxes)
return list(map(lambda x: x.description, filter(lambda x: x.value, self.regions_checkboxes)))

def __get_operations_buttons(self):
Expand All @@ -93,22 +92,29 @@ def __get_operations_buttons(self):
Create a new connectivity cutting the edges of selected nodes.
Check the selected nodes in the above dropdown to see what it is included
"""
cut_selected = ipywidgets.Button(description='Cut selected regions',
disabled=False,
button_style='success',
tooltip=cut_selected_tooltip,
icon='scissors')

cut_edges_of_selected = ipywidgets.Button(description='Cut edges of selected',
disabled=False,
button_style='warning',
tooltip=cut_edges_tooltip,
icon='scissors')

cut_selected.on_click(lambda *args: self.__cut_selected_nodes())
cut_edges_of_selected.on_click(lambda *args: self.__cut_selected_edges())

return ipywidgets.HBox(children=[cut_selected, cut_edges_of_selected])
cut_unselected_nodes = ipywidgets.Button(description='Cut selected regions',
disabled=False,
button_style='success',
tooltip=cut_selected_tooltip,
icon='scissors')

cut_unselected_edges = ipywidgets.Button(description='Cut unselected edges',
disabled=False,
button_style='warning',
tooltip=cut_edges_tooltip,
icon='scissors')

cut_selected_edges = ipywidgets.Button(description='Cut selected edges',
disabled=False,
button_style='warning',
tooltip=cut_edges_tooltip,
icon='scissors')

cut_unselected_nodes.on_click(lambda *args: self.__cut_selected_nodes())
cut_unselected_edges.on_click(lambda *args: self.__cut_edges())
cut_selected_edges.on_click(lambda *args: self.__cut_edges(selected=True))

return ipywidgets.HBox(children=[cut_unselected_nodes, cut_unselected_edges, cut_selected_edges])

def __cut_selected_nodes(self):
"""
Expand All @@ -120,16 +126,18 @@ def __cut_selected_nodes(self):
regions = CONTEXT.connectivity.region_labels
selected_regions = [numpy.where(regions == label)[0][0] for label in self.selected_regions]
new_conn = self._cut_connectivity(CONTEXT.connectivity, new_weights, selected_regions, new_tracts)
print('centres of new conn: ', new_conn.centres)
CONTEXT.connectivity = new_conn

def __cut_selected_edges(self):
def __cut_edges(self, selected=False):
print('cutting edges: ', self.selected_regions)
matrix = CONTEXT.matrix
new_weights = CONTEXT.connectivity.weights
new_tracts = CONTEXT.connectivity.tract_lengths if matrix == 'tracts' else None
regions = CONTEXT.connectivity.region_labels
selected_regions = [numpy.where(regions == label)[0][0] for label in self.selected_regions]
new_conn = self._branch_connectivity(CONTEXT.connectivity, new_weights, selected_regions, new_tracts)
new_conn = self._cut_connectivity_edges(CONTEXT.connectivity, new_weights, selected_regions, new_tracts,
selected)
CONTEXT.connectivity = new_conn

def _reorder_arrays(self, original_conn, new_weights, interest_areas, new_tracts=None):
Expand All @@ -147,8 +155,8 @@ def _reorder_arrays(self, original_conn, new_weights, interest_areas, new_tracts

return new_weights, interest_areas, new_tracts

def _branch_connectivity(self, original_conn, new_weights, interest_areas,
new_tracts=None):
def _cut_connectivity_edges(self, original_conn, new_weights, interest_areas,
new_tracts=None, selected=False):
# type: (Connectivity, numpy.array, numpy.array, numpy.array) -> Connectivity
"""
Generate new Connectivity based on a previous one, by changing weights (e.g. simulate lesion).
Expand All @@ -157,8 +165,8 @@ def _branch_connectivity(self, original_conn, new_weights, interest_areas,
:param new_weights: weights matrix for the new connectivity
:param interest_areas: ndarray of the selected node id's
:param new_tracts: tracts matrix for the new connectivity
:param selected: if true cuts out edges of selected areas else unselected edges
"""
print('herre')
new_weights, interest_areas, new_tracts = self._reorder_arrays(original_conn, new_weights,
interest_areas, new_tracts)
if new_tracts is None:
Expand All @@ -167,6 +175,9 @@ def _branch_connectivity(self, original_conn, new_weights, interest_areas,
for i in range(len(original_conn.weights)):
for j in range(len(original_conn.weights)):
if i not in interest_areas or j not in interest_areas:
if not selected:
new_weights[i][j] = 0
elif i in interest_areas or j in interest_areas and selected:
new_weights[i][j] = 0

final_conn = Connectivity()
Expand Down

0 comments on commit f9e5f7b

Please sign in to comment.