Skip to content

Commit

Permalink
WID-223: sync ui with data and some refactors
Browse files Browse the repository at this point in the history
  • Loading branch information
davidbacter01 committed Sep 14, 2023
1 parent ab46fa1 commit bbbc7a9
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 54 deletions.
21 changes: 17 additions & 4 deletions tvbwidgets/ui/connectivity_ipy/connectivity_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from tvbwidgets.ui.connectivity_ipy.outputs_3d import PyVistaOutput
from tvbwidgets.ui.connectivity_ipy.operations import ConnectivityOperations
from tvbwidgets.ui.connectivity_ipy.config import ConnectivityConfig
from tvbwidgets.ui.connectivity_ipy.global_context import CONTEXT
from tvbwidgets.ui.connectivity_ipy.global_context import CONTEXT, ObservableAttrs

DROPDOWN_KEY = 'dropdown'

Expand Down Expand Up @@ -45,7 +45,7 @@ def __init__(self, **kwargs):

self.__draw_connectivity()
self.__show_plot()
CONTEXT.observe(lambda *args: self.__show_plot(), 'connectivity')
CONTEXT.observe(lambda *args: self.__show_plot(), ObservableAttrs.CONNECTIVITY)

def add_datatype(self, datatype): # type: (HasTraits) -> None
pass
Expand Down Expand Up @@ -99,7 +99,7 @@ def on_change(change):
def on_ctx_change(value):
dropdown.value = value

CONTEXT.observe(on_ctx_change, 'matrix')
CONTEXT.observe(on_ctx_change, ObservableAttrs.MATRIX)
self.widgets_map[DROPDOWN_KEY] = dropdown
self.children = (dropdown, *self.children)

Expand All @@ -112,7 +112,7 @@ def __init__(self, **kwargs):
super(Connectivity3DViewer, self).__init__([self.output], *kwargs)

self.init_view_connectivity()
CONTEXT.observe(lambda *args: self.init_view_connectivity(), 'connectivity')
CONTEXT.observe(lambda *args: self.init_view_connectivity(), ObservableAttrs.CONNECTIVITY)

def init_view_connectivity(self):
self.output.plotter.clear()
Expand Down Expand Up @@ -226,6 +226,19 @@ def add_datatype(self, datatype):
"""
pass

def get_connectivity(self, gid=None):
# type: (str|None) -> Connectivity
"""
Get a connectivity with the gid provided from the context history.
if gid=None return the current connectivity set on CONTEXT
"""
if gid is None:
return CONTEXT.connectivity
conn = list(filter(lambda c: c.gid.hex == gid, CONTEXT.connectivities_history))
if not len(conn):
return None
return conn[0]

def __init__(self, connectivity, **kwargs):
style = self.DEFAULT_BORDER
super().__init__(**kwargs, layout=style)
Expand Down
17 changes: 14 additions & 3 deletions tvbwidgets/ui/connectivity_ipy/global_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,19 @@
#
# (c) 2022-2023, TVB Widgets Team
#
import enum
from typing import Callable
from tvb.datatypes.connectivity import Connectivity


class ObservableAttrs(str, enum.Enum):
"""
Enum representing observable attributes of the GlobalContext singleton
"""
MATRIX = 'matrix'
CONNECTIVITY = 'connectivity'


class SingletonMeta(type):
_instances = {}

Expand Down Expand Up @@ -35,7 +44,7 @@ def matrix(self, next_value):
prev_value = self.__matrix
self.__matrix = next_value
if prev_value != next_value:
self.__notify_observers('matrix', next_value)
self.__notify_observers(ObservableAttrs.MATRIX, next_value)

@property
def connectivity(self):
Expand All @@ -47,10 +56,12 @@ def connectivity(self, next_value):
# type: (Connectivity) -> None
previous = self.__connectivity
self.__connectivity = next_value
if previous != next_value:
self.__notify_observers('connectivity', next_value)
if not len(self.connectivities_history):
self.connectivities_history = [self.__connectivity]
if previous and previous.gid.hex != next_value.gid.hex:
if not any([conn.gid == next_value.gid for conn in self.connectivities_history]):
self.connectivities_history.append(next_value)
self.__notify_observers(ObservableAttrs.CONNECTIVITY, next_value)

def __notify_observers(self, observed_attribute, next_value):
# type: (str, any) -> None
Expand Down
82 changes: 35 additions & 47 deletions tvbwidgets/ui/connectivity_ipy/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,15 @@ def __init__(self, **kwargs):
def update_history_ui(*args):
children[1] = self.get_history_dropdown()
self.children = children
self.send_state(self.keys) # trigger ui update

def update_selector_ui(*args):
children[2] = self.__get_node_selector()
self.children = children
self.send_state(self.keys) # trigger ui update

CONTEXT.observe(update_history_ui, 'connectivity')
CONTEXT.observe(update_selector_ui, 'connectivity')
self.children = children

@property
Expand Down Expand Up @@ -73,24 +80,8 @@ def __get_node_selector(self):
)),
layout={'width': '50%', 'align-items': 'start'}
)
matrix_dropdown = ipywidgets.Dropdown(
options=CONTEXT.MATRIX_OPTIONS,
value=CONTEXT.matrix,
description='Matrix:'
)

def on_change(value):
CONTEXT.matrix = value['new']

matrix_dropdown.observe(on_change, 'value')

def on_ctx_change(value):
matrix_dropdown.value = value

CONTEXT.observe(on_ctx_change, 'matrix')

container = ipywidgets.VBox(children=(matrix_dropdown,
ipywidgets.HBox(children=(left, right))))
container = ipywidgets.HBox(children=(left, right))
accordion = ipywidgets.Accordion(children=[container], selected_index=None,
layout={'max-height': '50vh'})
accordion.set_title(0, 'Regions selector')
Expand Down Expand Up @@ -133,22 +124,15 @@ def __cut_selected_nodes(self):
"""
Create a new connectivity using only the selected nodes
"""
matrix = CONTEXT.matrix
new_weights = CONTEXT.connectivity.weights
new_tracts = CONTEXT.connectivity.tract_lengths if matrix == 'tracts' else None
regions = CONTEXT.connectivity.region_labels
selected_regions = [numpy.where(regions == label)[0][0] for label in self.selected_regions]
new_conn = self._cut_connectivity(CONTEXT.connectivity, new_weights, selected_regions, new_tracts)
new_conn = self._cut_connectivity(selected_regions)
CONTEXT.connectivity = new_conn

def __cut_edges(self, selected=False):
matrix = CONTEXT.matrix
new_weights = CONTEXT.connectivity.weights
new_tracts = CONTEXT.connectivity.tract_lengths if matrix == 'tracts' else None
regions = CONTEXT.connectivity.region_labels
selected_regions = [numpy.where(regions == label)[0][0] for label in self.selected_regions]
new_conn = self._cut_connectivity_edges(CONTEXT.connectivity, new_weights, selected_regions, new_tracts,
selected)
new_conn = self._cut_connectivity_edges(selected_regions, selected)
CONTEXT.connectivity = new_conn

def _reorder_arrays(self, original_conn, new_weights, interest_areas, new_tracts=None):
Expand All @@ -168,33 +152,38 @@ def _reorder_arrays(self, original_conn, new_weights, interest_areas, new_tracts

def get_history_dropdown(self):
values = [(conn.gid.hex, conn) for conn in CONTEXT.connectivities_history]
default = len(values) and values[-1][1] or None

dropdown = ipywidgets.Dropdown(options=values,
description='View history',
disabled=False)
dropdown.observe(lambda x: print(x['new']), 'value')
disabled=False,
value=default)

def on_connectivity_change(change):
CONTEXT.connectivity = change['new']

dropdown.observe(on_connectivity_change, 'value')
return dropdown

def _cut_connectivity_edges(self, original_conn, new_weights, interest_areas,
new_tracts=None, selected=False):
# type: (Connectivity, numpy.array, numpy.array, numpy.array) -> Connectivity
def _cut_connectivity_edges(self, interest_areas, selected=False):
# type: (numpy.array, bool) -> Connectivity
"""
Generate new Connectivity based on a previous one, by changing weights (e.g. simulate lesion).
The returned connectivity has the same number of nodes. The edges of unselected nodes will have weight 0.
:param original_conn: Original Connectivity, to copy from
:param new_weights: weights matrix for the new connectivity
:param interest_areas: ndarray of the selected node id's
:param new_tracts: tracts matrix for the new connectivity
:param selected: if true cuts out edges of selected areas else unselected edges
"""

original_conn = CONTEXT.connectivity
new_weights = CONTEXT.connectivity.weights
new_tracts = CONTEXT.connectivity.tract_lengths

if not len(interest_areas):
LOGGER.error('No intrest areas selected!')
LOGGER.error('No interest areas selected!')
return CONTEXT.connectivity

new_weights, interest_areas, new_tracts = self._reorder_arrays(original_conn, new_weights,
interest_areas, new_tracts)
if new_tracts is None:
new_tracts = original_conn.tract_lengths

for i in range(len(original_conn.weights)):
for j in range(len(original_conn.weights)):
Expand All @@ -218,28 +207,27 @@ def _cut_connectivity_edges(self, original_conn, new_weights, interest_areas,
final_conn.configure()
return final_conn

def _cut_connectivity(self, original_conn, new_weights, interest_areas, new_tracts=None):
# type: (Connectivity, numpy.array, numpy.array, numpy.array) -> Connectivity
def _cut_connectivity(self, interest_areas, selected=False):
# type: (numpy.array, bool) -> Connectivity
"""
Generate new Connectivity object based on current one, by removing nodes (e.g. simulate lesion).
Only the selected nodes will get used in the result. The order of the indices in interest_areas matters.
If indices are not sorted then the nodes will be permuted accordingly.
:param original_conn: Original Connectivity(HasTraits), to cut nodes from
:param new_weights: weights matrix for the new connectivity
:param interest_areas: ndarray with the selected node id's.
:param new_tracts: tracts matrix for the new connectivity
"""

original_conn = CONTEXT.connectivity
new_weights = CONTEXT.connectivity.weights
new_tracts = CONTEXT.connectivity.tract_lengths

if not len(interest_areas):
LOGGER.error('No interest areas selected!')
return CONTEXT.connectivity

new_weights, interest_areas, new_tracts = self._reorder_arrays(original_conn, new_weights,
interest_areas, new_tracts)
if new_tracts is None:
new_tracts = original_conn.tract_lengths[interest_areas, :][:, interest_areas]
else:
new_tracts = new_tracts[interest_areas, :][:, interest_areas]

new_tracts = new_tracts[interest_areas, :][:, interest_areas]
new_weights = new_weights[interest_areas, :][:, interest_areas]

final_conn = Connectivity()
Expand Down

0 comments on commit bbbc7a9

Please sign in to comment.