diff --git a/src/constantinople_lab_to_nwb/schierek_embargo_2024/interfaces/schierek_embargo_2024_processedbehaviorinterface.py b/src/constantinople_lab_to_nwb/schierek_embargo_2024/interfaces/schierek_embargo_2024_processedbehaviorinterface.py index 2ae51db..2513549 100644 --- a/src/constantinople_lab_to_nwb/schierek_embargo_2024/interfaces/schierek_embargo_2024_processedbehaviorinterface.py +++ b/src/constantinople_lab_to_nwb/schierek_embargo_2024/interfaces/schierek_embargo_2024_processedbehaviorinterface.py @@ -2,6 +2,7 @@ from pathlib import Path from typing import Optional, Union +from warnings import warn import numpy as np from ndx_structured_behavior.utils import loadmat @@ -113,16 +114,25 @@ def add_to_nwbfile( if nwbfile.trials is None: assert trial_start_times is not None, "'trial_start_times' must be provided if trials table is not added." assert trial_stop_times is not None, "'trial_stop_times' must be provided if trials table is not added." - assert ( - len(trial_start_times) == num_trials - ), f"Length of 'trial_start_times' ({len(trial_start_times)}) must match the number of trials ({num_trials})." - assert ( - len(trial_stop_times) == num_trials - ), f"Length of 'trial_stop_times' ({len(trial_stop_times)}) must match the number of trials ({num_trials})." else: trial_start_times = nwbfile.trials["start_time"][:] trial_stop_times = nwbfile.trials["stop_time"][:] + if len(trial_start_times) > num_trials: + warn( + f"The length of 'trial_start_times' ({len(trial_start_times)}) from Bpod doesn't match the number " + f"of trials ({num_trials}) in '{self.default_struct_name}' struct data." + ) + trial_start_times = trial_start_times[:num_trials] + trial_stop_times = trial_stop_times[:num_trials] + + assert ( + len(trial_start_times) == num_trials + ), f"Length of 'trial_start_times' ({len(trial_start_times)}) must match the number of trials ({num_trials})." + assert ( + len(trial_stop_times) == num_trials + ), f"Length of 'trial_stop_times' ({len(trial_stop_times)}) must match the number of trials ({num_trials})." + for start_time, stop_time in zip(trial_start_times, trial_stop_times): trials_table.add_row( start_time=start_time,