Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Behavior GUI #44

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
181 changes: 143 additions & 38 deletions src/gui/rawdata_visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import iblphotometry.preprocessing as ffpr
import numpy as np
import pandas as pd


class DataFrameVisualizerApp(QWidget):
Expand All @@ -30,6 +31,8 @@ def __init__(self):

self.plot_time_index = None

self.behavior_gui = None

self.filtered_df = None # Filtered DataFrame used for plotting only
self.init_ui()

Expand All @@ -40,7 +43,7 @@ def init_ui(self):
# Layout for file loading and selection
file_layout = QHBoxLayout()
self.load_button = QPushButton('Load File', self)
self.load_button.clicked.connect(self.load_file)
self.load_button.clicked.connect(self.open_dialog)
file_layout.addWidget(self.load_button)

self.column_selector = QComboBox(self)
Expand All @@ -56,6 +59,10 @@ def init_ui(self):
self.filter_selector.currentIndexChanged.connect(self.apply_filter)
file_layout.addWidget(self.filter_selector)

self.behavior_button = QPushButton('Open behavior GUI', self)
self.behavior_button.clicked.connect(self.open_behavior_gui)
file_layout.addWidget(self.behavior_button)

# # Table widget to display DataFrame
# self.table = QTableWidget(self)
# self.table.setSelectionMode(QTableWidget.SingleSelection)
Expand Down Expand Up @@ -95,47 +102,50 @@ def init_ui(self):
self.setWindowTitle('DataFrame Plotter')
self.setGeometry(300, 100, 800, 600)

def load_file(self):
def load_file(self, file_path):
try:
if (
file_path.endswith('.csv')
or file_path.endswith('.pqt')
or file_path.endswith('.parquet')
):
self.dfs = from_raw_neurophotometrics_file(file_path)
else:
raise ValueError('Unsupported file format')

if 'GCaMP' in self.dfs.keys():
self.df = self.dfs['GCaMP']
self.times = self.dfs['GCaMP'].index.values
self.plot_time_index = np.arange(0, len(self.times))
self.filtered_df = None
else:
raise ValueError('No GCaMP found')

if 'Isosbestic' in self.dfs.keys():
self.dfiso = self.dfs['Isosbestic']

# Display the dataframe in the table
# self.display_dataframe()
# Update the column selector
self.update_column_selector()

# Load into Pynapple dataframe
self.dfs = from_raw_neurophotometrics_file(file_path)

# Set filter combo box
self.filter_selector.setCurrentIndex(0) # Reset to "Select Filter"

except Exception as e:
print(f'Error loading file: {e}')

def open_dialog(self):
# Open a file dialog to choose the CSV or PQT file
file_path, _ = QFileDialog.getOpenFileName(
self, 'Open File', '', 'CSV and PQT Files (*.csv *.pqt);;All Files (*)'
)
if file_path:
# Load the file into a DataFrame based on its extension
try:
if (
file_path.endswith('.csv')
or file_path.endswith('.pqt')
or file_path.endswith('.parquet')
):
self.dfs = from_raw_neurophotometrics_file(file_path)
else:
raise ValueError('Unsupported file format')

if 'GCaMP' in self.dfs.keys():
self.df = self.dfs['GCaMP']
self.times = self.dfs['GCaMP'].index.values
self.plot_time_index = np.arange(0, len(self.times))
self.filtered_df = None
else:
raise ValueError('No GCaMP found')

if 'Isosbestic' in self.dfs.keys():
self.dfiso = self.dfs['Isosbestic']

# Display the dataframe in the table
# self.display_dataframe()
# Update the column selector
self.update_column_selector()

# Load into Pynapple dataframe
self.dfs = from_raw_neurophotometrics_file(file_path)

# Set filter combo box
self.filter_selector.setCurrentIndex(0) # Reset to "Select Filter"

except Exception as e:
print(f'Error loading file: {e}')
self.load_file(file_path)

# TODO this does not work with pynapple as format, convert back to pandas DF
# def display_dataframe(self):
Expand Down Expand Up @@ -235,9 +245,9 @@ def on_column_header_clicked(self, logical_index):
# Update the plots based on the selected column
self.update_plots()

def apply_filter(self):
def apply_filter(self, filter_idx, filter_option=None):
# Get the selected filter option from the filter dropdown
filter_option = self.filter_selector.currentText()
filter_option = filter_option or self.filter_selector.currentText()

if filter_option == 'Select Filter':
self.filtered_df = None
Expand Down Expand Up @@ -282,9 +292,104 @@ def filter_mad(self, df):
# filtered_df[col] = filtered_df[col].apply(lambda x: x if x <= 100 else None)
# return filtered_df

def open_behavior_gui(self):
signal = self.plotobj.processed_signal

if self.behavior_gui is None:
self.behavior_gui = BehaviorVisualizerGUI()
assert self.behavior_gui is not None

if signal is None:
print('Apply a filter before opening the Behavior GUI')
else:
print('Opening Behavior GUI')
self.behavior_gui.set_data(signal, self.times)
self.behavior_gui.show()


class BehaviorVisualizerGUI(QWidget):
def __init__(
self,
):
super().__init__()
self.trials = None
self.init_ui()

def set_data(self, processed_signal, times):
assert processed_signal is not None
assert times is not None
self.processed_signal = processed_signal
self.times = times

def init_ui(self):
# Create layout
main_layout = QVBoxLayout()

# Layout for file loading and selection
file_layout = QHBoxLayout()
self.load_button = QPushButton('Load File', self)
self.load_button.clicked.connect(self.open_dialog)
file_layout.addWidget(self.load_button)

main_layout.addLayout(file_layout)

# Set up plots layout
self.plot_layout = QGridLayout()
self.plotobj = plots.PlotSignalResponse()
self.figure, self.axes = self.plotobj.set_fig_layout()
self.canvas = FigureCanvas(self.figure)
self.plot_layout.addWidget(self.canvas, 0, 0, 1, 3)

# Create a NavigationToolbar
self.toolbar = NavigationToolbar(self.canvas, self)

main_layout.addLayout(self.plot_layout)
self.setLayout(main_layout)

self.setWindowTitle('Behavior Visualizer')
self.setGeometry(300, 100, 800, 600)

def load_trials(self, trials):
assert trials is not None
self.trials = trials
self.update_plots()

def load_file(self, file_path):
# load a trial file
try:
if file_path.endswith('.pqt') or file_path.endswith('.parquet'):
self.load_trials(pd.read_parquet(file_path))
else:
raise ValueError('Unsupported file format')
except Exception as e:
print(f'Error loading file: {e}')

def open_dialog(self):
file_path, _ = QFileDialog.getOpenFileName(
self, 'Open File', '', 'CSV and PQT Files (*.csv *.pqt);;All Files (*)'
)
if file_path:
self.load_file(file_path)

def update_plots(self):
self.figure.clear()

self.plotobj.set_data(
self.trials,
self.processed_signal,
self.times,
)
# NOTE: we need to update the layout as it depends on the data
self.figure, self.axes = self.plotobj.set_fig_layout(figure=self.figure)
self.plotobj.plot_trialsort_psth(self.axes)

self.canvas.draw()


if __name__ == '__main__':
app = QApplication(sys.argv)
window = DataFrameVisualizerApp()
if len(sys.argv) >= 2:
window.load_file(sys.argv[1])
window.show()
sys.exit(app.exec_())
29 changes: 17 additions & 12 deletions src/iblphotometry/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,6 @@ def set_axis_style(ax, fontsize=10, **kwargs):


class PlotSignal:
# def __init__(self, *args, **kwargs):
# self.set_data(*args, **kwargs)

def set_data(
self, raw_signal, times, raw_isosbestic=None, processed_signal=None, fs=None
):
Expand Down Expand Up @@ -143,7 +140,10 @@ def raw_processed_figure2(self, axd):


class PlotSignalResponse:
def __init__(
def __init__(self):
self.psth_dict = {}

def set_data(
self, trials, processed_signal, times, fs=None, event_window=np.array([-1, 2])
):
self.trials = trials
Expand Down Expand Up @@ -187,10 +187,20 @@ def update_psth_dict(self, event):
except KeyError:
warnings.warn(f'Event {event} not found in trials table.')

def plot_trialsort_psth(self):
fig, axs = plt.subplots(2, len(self.psth_dict.keys()) - 1)
def set_fig_layout(self, figure=None):
n = max(1, len(self.psth_dict.keys()) - 1)
if figure is None:
figure, axs = plt.subplots(2, n, squeeze=False)
else:
axs = figure.subplots(2, n, squeeze=False)
figure.tight_layout()
return figure, axs

def plot_trialsort_psth(self, axs):
signal_keys = [k for k in self.psth_dict.keys() if k != 'times']
if axs.shape[1] < len(signal_keys):
raise ValueError('Error, skipping PSTH plotting')

for iaxs, event in enumerate(signal_keys):
axs_plt = [axs[0, iaxs], axs[1, iaxs]]
plot_psth(self.psth_dict[event], self.psth_dict['times'], axs=axs_plt)
Expand All @@ -206,17 +216,12 @@ def plot_trialsort_psth(self):
if iaxs > 0:
axs[0, iaxs].axis('off')
axs[1, iaxs].set_yticks([])
fig.tight_layout()
return fig, axs

def plot_processed_trialtick(self, event_key='stimOn_times'):
fig, ax = plt.subplots(1, 1)
plt.figure(figsize=(10, 6))
def plot_processed_trialtick(self, ax, event_key='stimOn_times'):
events = self.trials[event_key]
ax.set_ylim([-0.2, 0.1])
plot_event_tick(events, ax=ax, color='#FFC0CB', ls='-')
plot_processed_signal(self.processed_signal, self.times, ax=ax)
return fig, ax


"""
Expand Down
36 changes: 33 additions & 3 deletions src/iblphotometry_tests/test_plots.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import unittest
import pandas as pd
from pathlib import Path

Expand All @@ -6,6 +7,8 @@

from iblphotometry.behavior import psth, psth_times
import iblphotometry.plots as plots

# from gui.rawdata_visualizer import BehaviorVisualizerGUI
from iblphotometry.synthetic import synthetic101
import iblphotometry.preprocessing as ffpr
from iblphotometry_tests.base_tests import PhotometryDataTestCase
Expand Down Expand Up @@ -94,9 +97,13 @@ def test_class_plotsignalresponse(self):
# eid = '77a6741c-81cc-475f-9454-a9b997be02a4'
# trials = one.load_object(eid, 'trials')
trials = pd.read_parquet(self.paths['trials_table_kcenia_pqt'])
plotobj = plots.PlotSignalResponse(trials, processed_signal, times)
plotobj.plot_trialsort_psth()
plotobj.plot_processed_trialtick()
plotobj = plots.PlotSignalResponse()
plotobj.set_data(trials, processed_signal, times)
_, axs = plotobj.set_fig_layout()
plotobj.plot_trialsort_psth(axs)
_, ax = plt.subplots(1, 1)
plotobj.plot_processed_trialtick(ax)
# plt.show()
plt.close('all')

"""
Expand Down Expand Up @@ -186,3 +193,26 @@ def test_plot_event_tick(self):
df_nph, t_events, fs = self.get_synthetic_data()
plots.plot_event_tick(t_events)
plt.close('all')

# def test_gui(self):
# df_nph, _, fs = self.get_test_data()
# processed_signal = df_nph['signal_processed'].values
# times = df_nph['times'].values
# trials = pd.read_parquet(self.paths['trials_table_kcenia_pqt'])

# from PyQt5.QtWidgets import QApplication
# app = QApplication(sys.argv)
# window = BehaviorVisualizerGUI()
# window.set_data(processed_signal, times)
# window.load_trials(trials)
# window.show()
# # Uncomment to debug
# app.exec_()


if __name__ == '__main__':
unittest.main()
# suite = unittest.TestSuite()
# suite.addTest(TestPlotters())
# runner = unittest.TextTestRunner()
# runner.run(suite)
Loading