diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index dedccb3..6e4ac88 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -9,9 +9,11 @@ repos:
   rev: 22.8.0
   hooks:
   - id: black
+    args: [--line-length=79]
     exclude: ^docs/
-- repo: https://github.com/pycqa/isort
-  rev: 5.13.1
-  hooks:
-  - id: isort
-    name: isort (python)
+- repo: https://github.com/pycqa/isort
+  rev: 5.13.1
+  hooks:
+  - id: isort
+    name: isort (python)
+    args: [--profile=black, --line-length=79]
diff --git a/pyproject.toml b/pyproject.toml
index 3cae529..b302147 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,5 +1,5 @@
 [tool.black]
-line-length = 120
+line-length = 79
 target-version = ['py38', 'py39', 'py310']
 include = '\.pyi?$'
 extend-exclude = '''
diff --git a/requirements.txt b/requirements.txt
index da271ed..ee88472 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,4 @@
-neuroconv==0.4.7
+neuroconv==0.4.6
 spikeinterface==0.99.1
 nwbwidgets==0.11.3
 nwbinspector==0.4.31
diff --git a/src/jazayeri_lab_to_nwb/watters/display_interface.py b/src/jazayeri_lab_to_nwb/watters/display_interface.py
index f5ecadd..dc1a716 100644
--- a/src/jazayeri_lab_to_nwb/watters/display_interface.py
+++ b/src/jazayeri_lab_to_nwb/watters/display_interface.py
@@ -1,4 +1,5 @@
 """Class for converting data about display frames."""
+
 import itertools
 import json
 from pathlib import Path
@@ -6,7 +7,9 @@
 
 import numpy as np
 import pandas as pd
-from neuroconv.datainterfaces.text.timeintervalsinterface import TimeIntervalsInterface
+from neuroconv.datainterfaces.text.timeintervalsinterface import (
+    TimeIntervalsInterface,
+)
 from neuroconv.utils import FolderPathType
 from pynwb import NWBFile
@@ -40,7 +43,9 @@ def get_metadata(self) -> dict:
         return metadata
 
     def get_timestamps(self) -> np.ndarray:
-        return super(DisplayInterface, self).get_timestamps(column="start_time")
+        return super(DisplayInterface, self).get_timestamps(
+            column="start_time"
+        )
 
     def set_aligned_starting_time(self, aligned_starting_time: float) -> None:
         self.dataframe.start_time += aligned_starting_time
@@ -49,15 +54,23 @@ def _read_file(self, file_path: FolderPathType):
         # Create dataframe with data for each frame
         trials = json.load(open(Path(file_path) / "trials.json", "r"))
         frames = {
-            k_mapped: list(itertools.chain(*[d[k] for d in trials])) for k, k_mapped in DisplayInterface.KEY_MAP.items()
+            k_mapped: list(itertools.chain(*[d[k] for d in trials]))
+            for k, k_mapped in DisplayInterface.KEY_MAP.items()
         }
 
         # Serialize object_positions data for hdf5 conversion to work
-        frames["object_positions"] = [json.dumps(x) for x in frames["object_positions"]]
+        frames["object_positions"] = [
+            json.dumps(x) for x in frames["object_positions"]
+        ]
 
         return pd.DataFrame(frames)
 
-    def add_to_nwbfile(self, nwbfile: NWBFile, metadata: Optional[dict] = None, tag: str = "display"):
+    def add_to_nwbfile(
+        self,
+        nwbfile: NWBFile,
+        metadata: Optional[dict] = None,
+        tag: str = "display",
+    ):
         return super(DisplayInterface, self).add_to_nwbfile(
             nwbfile=nwbfile,
             metadata=metadata,
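
The dict comprehension reformatted in `_read_file` above is the heart of this interface: it flattens per-trial lists into one column per KEY_MAP entry, so each row of the resulting dataframe is a single display frame. A minimal runnable sketch of the pattern, with made-up trial records and a two-entry key map (the field names here are illustrative, not the real trials.json schema):

    import itertools

    # Hypothetical per-trial records, as if loaded from trials.json.
    trials = [
        {"t": [0.0, 0.1], "pos": [[1, 2], [3, 4]]},
        {"t": [0.2, 0.3], "pos": [[5, 6], [7, 8]]},
    ]
    key_map = {"t": "start_time", "pos": "object_positions"}

    # Chain the per-trial lists end to end: trial boundaries disappear
    # and each entry corresponds to one display frame.
    frames = {
        k_mapped: list(itertools.chain(*[d[k] for d in trials]))
        for k, k_mapped in key_map.items()
    }
    assert frames["start_time"] == [0.0, 0.1, 0.2, 0.3]
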
f"/om/user/nwatters/nwb_data_multi_prediction/staging/sub-{subject}" + output_path = ( + f"/om/user/nwatters/nwb_data_multi_prediction/staging/sub-{subject}" + ) # Path to the raw data. This is used for reading raw physiology data. - raw_data_path = f"/om4/group/jazlab/nwatters/multi_prediction/phys_data/{subject}/" f"{session}/raw_data" + raw_data_path = ( + f"/om4/group/jazlab/nwatters/multi_prediction/phys_data/{subject}/" + f"{session}/raw_data" + ) # Path to task and behavior data. task_behavior_data_path = ( - "/om4/group/jazlab/nwatters/multi_prediction/datasets/data_nwb_trials/" f"{subject}/{session}" + "/om4/group/jazlab/nwatters/multi_prediction/datasets/data_nwb_trials/" + f"{subject}/{session}" ) # Path to open-source data. This is used for reading behavior and task data. data_open_source_path = ( - "/om4/group/jazlab/nwatters/multi_prediction/datasets/data_open_source/" f"Subjects/{subject_id}/{session}/001" + "/om4/group/jazlab/nwatters/multi_prediction/datasets/data_open_source/" + f"Subjects/{subject_id}/{session}/001" ) # Path to sync pulses. This is used for reading timescale transformations # between physiology and mworks data streams. - sync_pulses_path = "/om4/group/jazlab/nwatters/multi_prediction/data_processed/" f"{subject}/{session}/sync_pulses" + sync_pulses_path = ( + "/om4/group/jazlab/nwatters/multi_prediction/data_processed/" + f"{subject}/{session}/sync_pulses" + ) # Path to spike sorting. This is used for reading spike sorted data. spike_sorting_raw_path = ( - f"/om4/group/jazlab/nwatters/multi_prediction/phys_data/{subject}/" f"{session}/spike_sorting" + f"/om4/group/jazlab/nwatters/multi_prediction/phys_data/{subject}/" + f"{session}/spike_sorting" ) session_paths = SessionPaths( diff --git a/src/jazayeri_lab_to_nwb/watters/main_convert_session.py b/src/jazayeri_lab_to_nwb/watters/main_convert_session.py index 41a85ab..aab00be 100644 --- a/src/jazayeri_lab_to_nwb/watters/main_convert_session.py +++ b/src/jazayeri_lab_to_nwb/watters/main_convert_session.py @@ -21,7 +21,6 @@ See comments below for descriptions of these variables. 
""" -import datetime import glob import json import logging @@ -89,7 +88,9 @@ def _add_v_probe_data( # Raw data recording_file = _get_single_file(probe_data_dir, suffix=".dat") - metadata_path = str(session_paths.data_open_source / "probes.metadata.json") + metadata_path = str( + session_paths.data_open_source / "probes.metadata.json" + ) raw_source_data[f"RecordingVP{probe_num}"] = dict( file_path=recording_file, probe_metadata_file=metadata_path, @@ -97,17 +98,28 @@ def _add_v_probe_data( probe_name=f"vprobe{probe_num}", es_key=f"ElectricalSeriesVP{probe_num}", ) - raw_conversion_options[f"RecordingVP{probe_num}"] = dict(stub_test=stub_test) + raw_conversion_options[f"RecordingVP{probe_num}"] = dict( + stub_test=stub_test + ) # Processed data - sorting_path = session_paths.spike_sorting_raw / f"v_probe_{probe_num}" / "ks_3_output_pre_v6_curated" - processed_source_data[f"RecordingVP{probe_num}"] = raw_source_data[f"RecordingVP{probe_num}"] + sorting_path = ( + session_paths.spike_sorting_raw + / f"v_probe_{probe_num}" + / "ks_3_output_pre_v6_curated" + ) + processed_source_data[f"RecordingVP{probe_num}"] = raw_source_data[ + f"RecordingVP{probe_num}" + ] processed_source_data[f"SortingVP{probe_num}"] = dict( - folder_path=str(sorting_path), - keep_good_only=False, + folder_path=str(sorting_path), keep_good_only=False + ) + processed_conversion_options[f"RecordingVP{probe_num}"] = dict( + stub_test=stub_test, write_electrical_series=False + ) + processed_conversion_options[f"SortingVP{probe_num}"] = dict( + stub_test=stub_test, write_as="processing" ) - processed_conversion_options[f"RecordingVP{probe_num}"] = dict(stub_test=stub_test, write_electrical_series=False) - processed_conversion_options[f"SortingVP{probe_num}"] = dict(stub_test=stub_test, write_as="processing") def _add_spikeglx_data( @@ -122,7 +134,11 @@ def _add_spikeglx_data( logging.info("Adding SpikeGLX data") # Raw data - spikeglx_dir = [x for x in (session_paths.raw_data / "spikeglx").iterdir() if "settling" not in str(x)] + spikeglx_dir = [ + x + for x in (session_paths.raw_data / "spikeglx").iterdir() + if "settling" not in str(x) + ] if len(spikeglx_dir) == 0: logging.info("Found no SpikeGLX data") elif len(spikeglx_dir) == 1: @@ -146,11 +162,17 @@ def _add_spikeglx_data( folder_path=str(sorting_path), keep_good_only=False, ) - processed_conversion_options["SortingNP"] = dict(stub_test=stub_test, write_as="processing") + processed_conversion_options["SortingNP"] = dict( + stub_test=stub_test, write_as="processing" + ) def session_to_nwb( - subject: str, session: str, stub_test: bool = False, overwrite: bool = True, dandiset_id: Union[str, None] = None + subject: str, + session: str, + stub_test: bool = False, + overwrite: bool = True, + dandiset_id: Union[str, None] = None, ): """ Convert a single session to an NWB file. 
@@ -190,7 +212,9 @@ def session_to_nwb(
     logging.info(f"dandiset_id = {dandiset_id}")
 
     # Get paths
-    session_paths = get_session_paths.get_session_paths(subject, session, repo=_REPO)
+    session_paths = get_session_paths.get_session_paths(
+        subject, session, repo=_REPO
+    )
     logging.info(f"session_paths: {session_paths}")
 
     # Get paths for nwb files to write
@@ -199,8 +223,13 @@ def session_to_nwb(
         session_id = f"{session}-stub"
     else:
         session_id = f"{session}"
-    raw_nwb_path = session_paths.output / f"sub-{subject}_ses-{session_id}_ecephys.nwb"
-    processed_nwb_path = session_paths.output / f"sub-{subject}_ses-{session_id}_behavior+ecephys.nwb"
+    raw_nwb_path = (
+        session_paths.output / f"sub-{subject}_ses-{session_id}_ecephys.nwb"
+    )
+    processed_nwb_path = (
+        session_paths.output
+        / f"sub-{subject}_ses-{session_id}_behavior+ecephys.nwb"
+    )
     logging.info(f"raw_nwb_path = {raw_nwb_path}")
     logging.info(f"processed_nwb_path = {processed_nwb_path}")
     logging.info("")
@@ -247,12 +276,16 @@ def session_to_nwb(
 
     # Add trials data
     logging.info("Adding trials data")
-    processed_source_data["Trials"] = dict(folder_path=str(session_paths.task_behavior_data))
+    processed_source_data["Trials"] = dict(
+        folder_path=str(session_paths.task_behavior_data)
+    )
     processed_conversion_options["Trials"] = dict()
 
     # Add display data
     logging.info("Adding display data")
-    processed_source_data["Display"] = dict(folder_path=str(session_paths.task_behavior_data))
+    processed_source_data["Display"] = dict(
+        folder_path=str(session_paths.task_behavior_data)
+    )
     processed_conversion_options["Display"] = dict()
 
     # Create processed data converter
@@ -269,10 +302,14 @@ def session_to_nwb(
     metadata["Subject"]["age"] = _SUBJECT_TO_AGE[subject]
 
     # EcePhys
-    probe_metadata_file = session_paths.data_open_source / "probes.metadata.json"
+    probe_metadata_file = (
+        session_paths.data_open_source / "probes.metadata.json"
+    )
     with open(probe_metadata_file, "r") as f:
         probe_metadata = json.load(f)
-    neuropixel_metadata = [x for x in probe_metadata if x["probe_type"] == "Neuropixels"][0]
+    neuropixel_metadata = [
+        x for x in probe_metadata if x["probe_type"] == "Neuropixels"
+    ][0]
     for entry in metadata["Ecephys"]["ElectrodeGroup"]:
         if entry["device"] == "Neuropixel-Imec":
             # TODO: uncomment when fixed in pynwb
@@ -291,10 +328,10 @@ def session_to_nwb(
 
     # Check if session_start_time was found/set
     if "session_start_time" not in metadata["NWBFile"]:
-        raise ValueError("Session start time was not auto-detected. Please provide it " "in `metadata.yaml`")
-    session_start_time = metadata["NWBFile"]["session_start_time"]
-    metadata["NWBFile"]["session_start_time"] = session_start_time.replace(
-        tzinfo=ZoneInfo("US/Eastern"))
+        raise ValueError(
+            "Session start time was not auto-detected. Please provide it "
+            "in `metadata.yaml`"
+        )
 
     # Run conversion
     logging.info("Running processed conversion")
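
One behavioral change above deserves a flag: the removed lines used to localize a naive `session_start_time` to US/Eastern after the check, and the code now only raises when the key is missing, so the value in `metadata.yaml` should already be timezone-aware. If manual localization is ever needed again, the deleted logic amounted to the following sketch (the example datetime is made up; `zoneinfo` is the standard library module the old code used):

    from datetime import datetime
    from zoneinfo import ZoneInfo

    # A naive timestamp, as it might be parsed from metadata.yaml.
    session_start_time = datetime(2022, 6, 1, 13, 30)

    # .replace() attaches the zone without shifting the clock time,
    # which is exactly what the removed lines did.
    aware = session_start_time.replace(tzinfo=ZoneInfo("US/Eastern"))
    assert aware.utcoffset() is not None
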
diff --git a/src/jazayeri_lab_to_nwb/watters/nwb_converter.py b/src/jazayeri_lab_to_nwb/watters/nwb_converter.py
index 5fa299f..488786d 100644
--- a/src/jazayeri_lab_to_nwb/watters/nwb_converter.py
+++ b/src/jazayeri_lab_to_nwb/watters/nwb_converter.py
@@ -1,30 +1,21 @@
 """Primary NWBConverter class for this dataset."""
+
 import json
 import logging
 from pathlib import Path
 from typing import Optional
 
+import display_interface
 import numpy as np
-from display_interface import DisplayInterface
-from neuroconv import NWBConverter
-from neuroconv.datainterfaces import (
-    KiloSortSortingInterface,
-    SpikeGLXRecordingInterface,
-)
+import timeseries_interface
+import trials_interface
+from neuroconv import NWBConverter, datainterfaces
 from neuroconv.datainterfaces.ecephys.basesortingextractorinterface import (
     BaseSortingExtractorInterface,
 )
 from neuroconv.utils import FolderPathType
 from recording_interface import DatRecordingInterface
-from spikeinterface.core.waveform_tools import has_exceeding_spikes
-from spikeinterface.curation import remove_excess_spikes
-from timeseries_interface import (
-    AudioInterface,
-    EyePositionInterface,
-    PupilSizeInterface,
-    RewardLineInterface,
-)
-from trials_interface import TrialsInterface
+from spikeinterface.core import waveform_tools
 
 
 class NWBConverter(NWBConverter):
@@ -32,27 +23,32 @@ class NWBConverter(NWBConverter):
 
     data_interface_classes = dict(
         RecordingVP0=DatRecordingInterface,
-        SortingVP0=KiloSortSortingInterface,
+        SortingVP0=datainterfaces.KiloSortSortingInterface,
         RecordingVP1=DatRecordingInterface,
-        SortingVP1=KiloSortSortingInterface,
-        RecordingNP=SpikeGLXRecordingInterface,
-        LF=SpikeGLXRecordingInterface,
-        SortingNP=KiloSortSortingInterface,
-        EyePosition=EyePositionInterface,
-        PupilSize=PupilSizeInterface,
-        RewardLine=RewardLineInterface,
-        Audio=AudioInterface,
-        Trials=TrialsInterface,
-        Display=DisplayInterface,
+        SortingVP1=datainterfaces.KiloSortSortingInterface,
+        RecordingNP=datainterfaces.SpikeGLXRecordingInterface,
+        LF=datainterfaces.SpikeGLXRecordingInterface,
+        SortingNP=datainterfaces.KiloSortSortingInterface,
+        EyePosition=timeseries_interface.EyePositionInterface,
+        PupilSize=timeseries_interface.PupilSizeInterface,
+        RewardLine=timeseries_interface.RewardLineInterface,
+        Audio=timeseries_interface.AudioInterface,
+        Trials=trials_interface.TrialsInterface,
+        Display=display_interface.DisplayInterface,
     )
 
-    def __init__(self, source_data: dict[str, dict], sync_dir: Optional[FolderPathType] = None, verbose: bool = True):
+    def __init__(
+        self,
+        source_data: dict[str, dict],
+        sync_dir: Optional[FolderPathType] = None,
+        verbose: bool = True,
+    ):
         """Validate source_data and initialize all data interfaces."""
         super().__init__(source_data=source_data, verbose=verbose)
         self.sync_dir = sync_dir
 
         unit_name_start = 0
-        for name, data_interface in self.data_interface_objects.items():
+        for data_interface in self.data_interface_objects.values():
             if isinstance(data_interface, BaseSortingExtractorInterface):
                 unit_ids = np.array(data_interface.sorting_extractor.unit_ids)
                 data_interface.sorting_extractor.set_property(
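
The loop above drops the unused `name` binding by iterating `.values()`; its body (continuing past this hunk) appears to give every sorted unit a globally unique `unit_name` by offsetting each sorter's ids by a running total. A self-contained sketch of that renumbering using a toy sorting from spikeinterface's generators; `generate_sorting` and the exact property values are assumptions, not code from this repo:

    import numpy as np
    from spikeinterface.core import generate_sorting

    sortings = [generate_sorting(num_units=3), generate_sorting(num_units=2)]
    unit_name_start = 0
    for sorting in sortings:
        unit_ids = np.array(sorting.unit_ids)
        # Offset by the running total so units from successive sorting
        # interfaces never collide in the NWB units table.
        sorting.set_property(
            key="unit_name",
            values=(unit_ids + unit_name_start).astype(str),
        )
        unit_name_start += len(unit_ids)
    # The second sorting's units are named "3" and "4", so they never
    # collide with the first sorting's "0".."2".
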
@@ -68,70 +64,69 @@ class NWBConverter(NWBConverter):
             return
         sync_dir = Path(self.sync_dir)
 
-        # openephys alignment
-        with open(sync_dir / "open_ephys" / "recording_start_time") as f:
-            open_ephys_start_time = float(f.read().strip())
-        with open(sync_dir / "open_ephys" / "transform", "r") as f:
-            open_ephys_transform = json.load(f)
-        for i in [0, 1]:
-            if f"RecordingVP{i}" in self.data_interface_objects:
-                orig_timestamps = self.data_interface_objects[f"RecordingVP{i}"].get_original_timestamps()
-                aligned_timestamps = open_ephys_transform["intercept"] + open_ephys_transform["coef"] * (
-                    open_ephys_start_time + orig_timestamps
-                )
-                self.data_interface_objects[f"RecordingVP{i}"].set_aligned_timestamps(aligned_timestamps)
-                # openephys sorting alignment
-                if f"SortingVP{i}" in self.data_interface_objects:
-                    if has_exceeding_spikes(
-                        recording=self.data_interface_objects[f"RecordingVP{i}"].recording_extractor,
-                        sorting=self.data_interface_objects[f"SortingVP{i}"].sorting_extractor,
-                    ):
-                        print(
-                            f"Spikes exceeding recording found in SortingVP{i}! "
-                            "Removing with `spikeinterface.curation.remove_excess_spikes()`"
-                        )
-                        self.data_interface_objects[f"SortingVP{i}"].sorting_extractor = remove_excess_spikes(
-                            recording=self.data_interface_objects[f"RecordingVP{i}"].recording_extractor,
-                            sorting=self.data_interface_objects[f"SortingVP{i}"].sorting_extractor,
-                        )
-                    self.data_interface_objects[f"SortingVP{i}"].register_recording(
-                        self.data_interface_objects[f"RecordingVP{i}"]
-                    )
+        # Align each recording
+        for name, recording_interface in self.data_interface_objects.items():
+            if "Recording" not in name:
+                continue
+            probe_name = name.split("Recording")[1]
 
-        # neuropixel alignment
-        orig_timestamps = self.data_interface_objects["RecordingNP"].get_original_timestamps()
-        with open(sync_dir / "spikeglx" / "transform", "r") as f:
-            spikeglx_transform = json.load(f)
-        aligned_timestamps = spikeglx_transform["intercept"] + spikeglx_transform["coef"] * orig_timestamps
-        self.data_interface_objects["RecordingNP"].set_aligned_timestamps(aligned_timestamps)
-        # neuropixel LFP alignment
-        orig_timestamps = self.data_interface_objects["LF"].get_original_timestamps()
-        aligned_timestamps = spikeglx_transform["intercept"] + spikeglx_transform["coef"] * orig_timestamps
-        self.data_interface_objects["LF"].set_aligned_timestamps(aligned_timestamps)
-        # neuropixel sorting alignment
-        if "SortingNP" in self.data_interface_objects:
-            if has_exceeding_spikes(
-                recording=self.data_interface_objects["RecordingNP"].recording_extractor,
-                sorting=self.data_interface_objects["SortingNP"].sorting_extractor,
-            ):
-                print(
-                    "Spikes exceeding recording found in SortingNP! "
" - "Removing with `spikeinterface.curation.remove_excess_spikes()`" + # Load timescale transform + if "VP" in probe_name: + start_path = sync_dir / "open_ephys" / "recording_start_time" + start = float(open(start_path).read().strip()) + transform_path = sync_dir / "open_ephys" / "transform" + transform = json.load(open(transform_path, "r")) + lf_interface = None + elif "NP" in probe_name: + start = 0.0 + transform_path = sync_dir / "spikeglx" / "transform" + transform = json.load(open(transform_path, "r")) + lf_interface = self.data_interface_objects["LF"] + intercept = transform["intercept"] + coef = transform["coef"] + + # Align recording timestamps + orig_timestamps = recording_interface.get_original_timestamps() + aligned_timestamps = intercept + coef * (start + orig_timestamps) + recording_interface.set_aligned_timestamps(aligned_timestamps) + + # Align LFP timestamps + if lf_interface is not None: + orig_timestamps = lf_interface.get_original_timestamps() + aligned_timestamps = intercept + coef * ( + start + orig_timestamps ) - self.data_interface_objects["SortingNP"].sorting_extractor = remove_excess_spikes( - recording=self.data_interface_objects["RecordingNP"].recording_extractor, - sorting=self.data_interface_objects["SortingNP"].sorting_extractor, + lf_interface.set_aligned_timestamps(aligned_timestamps) + + # If sorting exists, register recording to it + if f"Sorting{probe_name}" in self.data_interface_objects: + sorting_interface = self.data_interface_objects[ + f"Sorting{probe_name}" + ] + + # Sanity check no sorted spikes are outside recording range + exceeded_spikes = waveform_tools.has_exceeding_spikes( + recording=recording_interface.recording_extractor, + sorting=sorting_interface.sorting_extractor, ) - self.data_interface_objects["SortingNP"].register_recording(self.data_interface_objects["RecordingNP"]) + if exceeded_spikes: + raise ValueError( + f"Spikes exceeding recording found in Sorting{probe_name}!" 
diff --git a/src/jazayeri_lab_to_nwb/watters/recording_interface.py b/src/jazayeri_lab_to_nwb/watters/recording_interface.py
index 89523d2..5ee3619 100644
--- a/src/jazayeri_lab_to_nwb/watters/recording_interface.py
+++ b/src/jazayeri_lab_to_nwb/watters/recording_interface.py
@@ -1,5 +1,5 @@
 """Primary class for recording data."""
-import json
+
 from typing import Optional
 
 import numpy as np
@@ -8,7 +8,6 @@
     BaseRecordingExtractorInterface,
 )
 from neuroconv.utils import FilePathType
-from spikeinterface import BaseRecording
 
 
 class DatRecordingInterface(BaseRecordingExtractorInterface):
@@ -30,6 +29,8 @@ def __init__(
         probe_name: str = "vprobe",
         probe_key: Optional[str] = None,
     ):
+        del probe_metadata_file
+        del probe_key
         source_data = {
             "file_paths": [file_path],
             "sampling_frequency": sampling_frequency,
@@ -45,15 +46,11 @@ def __init__(
         # this is used for metadata naming
         self.probe_name = probe_name
 
-        # add probe information
-        with open(probe_metadata_file, "r") as f:
-            all_probe_metadata = json.load(f)
-        for entry in all_probe_metadata:
-            if entry["label"] == probe_key:
-                probe_metadata = entry
-
-        # Generate V-probe geometry: 64 channels arranged vertically with 50 um spacing
-        probe = probeinterface.generate_linear_probe(num_elec=channel_count, ypitch=50)
+        # Generate V-probe geometry: 64 channels arranged vertically with 50 um
+        # spacing
+        probe = probeinterface.generate_linear_probe(
+            num_elec=channel_count, ypitch=50
+        )
         probe.set_device_channel_indices(np.arange(channel_count))
         probe.name = probe_name
 
@@ -75,10 +72,11 @@ def get_metadata(self) -> dict:
                 manufacturer="Plexon",
             )
         ]
+        description = f"a group representing electrodes on {self.probe_name}"
         electrode_groups = [
             dict(
                 name=self.probe_name,
-                description=f"a group representing electrodes on {self.probe_name}",
+                description=description,
                 location="unknown",
                 device=self.probe_name,
             )
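
With the dead metadata lookup deleted, probe geometry now comes entirely from probeinterface. A standalone sketch of the generated layout: 64 contacts in a vertical line, 50 um apart, matching the comment in the hunk above (the probe name is hypothetical):

    import numpy as np
    import probeinterface

    channel_count = 64
    probe = probeinterface.generate_linear_probe(
        num_elec=channel_count, ypitch=50
    )
    probe.set_device_channel_indices(np.arange(channel_count))
    probe.name = "vprobe0"  # hypothetical

    # Contacts are stacked vertically: constant x, y rising 50 um per step.
    positions = probe.contact_positions
    assert positions.shape == (channel_count, 2)
    assert np.allclose(np.diff(positions[:, 1]), 50.0)
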
""" + import json from pathlib import Path import numpy as np from hdmf.backends.hdf5 import H5DataIO from ndx_events import LabeledEvents -from neuroconv.basetemporalalignmentinterface import BaseTemporalAlignmentInterface +from neuroconv.basetemporalalignmentinterface import ( + BaseTemporalAlignmentInterface, +) from neuroconv.tools.nwb_helpers import get_module from neuroconv.utils import FolderPathType from pynwb import NWBFile, TimeSeries @@ -61,10 +64,15 @@ def __init__(self, folder_path: FolderPathType): # Check eye_h and eye_v have the same number of samples if len(eye_h_times) != len(eye_v_times): - raise ValueError(f"len(eye_h_times) = {len(eye_h_times)}, but len(eye_v_times) " f"= {len(eye_v_times)}") + raise ValueError( + f"len(eye_h_times) = {len(eye_h_times)}, but len(eye_v_times) " + f"= {len(eye_v_times)}" + ) # Check that eye_h_times and eye_v_times are similar to within 0.5ms if not np.allclose(eye_h_times, eye_v_times, atol=0.0005): - raise ValueError("eye_h_times and eye_v_times are not sufficiently similar") + raise ValueError( + "eye_h_times and eye_v_times are not sufficiently similar" + ) # Set data attributes self.set_original_timestamps(eye_h_times) @@ -72,7 +80,7 @@ def __init__(self, folder_path: FolderPathType): def add_to_nwbfile(self, nwbfile: NWBFile, metadata: dict): del metadata - + # Make SpatialSeries eye_position = SpatialSeries( name="eye_position", @@ -85,8 +93,12 @@ def add_to_nwbfile(self, nwbfile: NWBFile, metadata: dict): ) # Get processing module - module_description = "Contains behavior, audio, and reward data from experiment." - processing_module = get_module(nwbfile=nwbfile, name="behavior", description=module_description) + module_description = ( + "Contains behavior, audio, and reward data from experiment." + ) + processing_module = get_module( + nwbfile=nwbfile, name="behavior", description=module_description + ) # Add data to module processing_module.add_data_interface(eye_position) @@ -122,8 +134,12 @@ def add_to_nwbfile(self, nwbfile: NWBFile, metadata: dict): ) # Get processing module - module_description = "Contains behavior, audio, and reward data from experiment." - processing_module = get_module(nwbfile=nwbfile, name="behavior", description=module_description) + module_description = ( + "Contains behavior, audio, and reward data from experiment." + ) + processing_module = get_module( + nwbfile=nwbfile, name="behavior", description=module_description + ) # Add data to module processing_module.add_data_interface(pupil_size) @@ -151,15 +167,21 @@ def add_to_nwbfile(self, nwbfile: NWBFile, metadata: dict): # Make LabeledEvents reward_line = LabeledEvents( name="reward_line", - description=("Reward line data representing events of reward dispenser"), + description=( + "Reward line data representing events of reward dispenser" + ), timestamps=H5DataIO(self._timestamps, compression="gzip"), data=self._reward_line, labels=["closed", "open"], ) # Get processing module - module_description = "Contains behavior, audio, and reward data from experiment." - processing_module = get_module(nwbfile=nwbfile, name="behavior", description=module_description) + module_description = ( + "Contains behavior, audio, and reward data from experiment." 
diff --git a/src/jazayeri_lab_to_nwb/watters/trials_interface.py b/src/jazayeri_lab_to_nwb/watters/trials_interface.py
index c030938..d32b07d 100644
--- a/src/jazayeri_lab_to_nwb/watters/trials_interface.py
+++ b/src/jazayeri_lab_to_nwb/watters/trials_interface.py
@@ -1,11 +1,14 @@
 """Class for converting trial-structured data."""
+
 import json
 from pathlib import Path
 from typing import Optional
 
 import numpy as np
 import pandas as pd
-from neuroconv.datainterfaces.text.timeintervalsinterface import TimeIntervalsInterface
+from neuroconv.datainterfaces.text.timeintervalsinterface import (
+    TimeIntervalsInterface,
+)
 from neuroconv.utils import FolderPathType
 from pynwb import NWBFile
@@ -73,12 +76,16 @@ def set_aligned_starting_time(self, aligned_starting_time: float) -> None:
     def _read_file(self, file_path: FolderPathType):
         # Create dataframe with data for each trial
         trials = json.load(open(Path(file_path) / "trials.json", "r"))
-        trials = {k_mapped: [d[k] for d in trials] for k, k_mapped in TrialsInterface.KEY_MAP.items()}
+        trials = {
+            k_mapped: [d[k] for d in trials]
+            for k, k_mapped in TrialsInterface.KEY_MAP.items()
+        }
 
         # Field closed_loop_response_position may have None values, so replace
         # those with NaN to make hdf5 conversion work
         trials["closed_loop_response_position"] = [
-            [np.nan, np.nan] if x is None else x for x in trials["closed_loop_response_position"]
+            [np.nan, np.nan] if x is None else x
+            for x in trials["closed_loop_response_position"]
         ]
 
         # Serialize fields with variable-length lists for hdf5 conversion
@@ -92,7 +99,12 @@
 
         return pd.DataFrame(trials)
 
-    def add_to_nwbfile(self, nwbfile: NWBFile, metadata: Optional[dict] = None, tag: str = "trials"):
+    def add_to_nwbfile(
+        self,
+        nwbfile: NWBFile,
+        metadata: Optional[dict] = None,
+        tag: str = "trials",
+    ):
         return super(TrialsInterface, self).add_to_nwbfile(
             nwbfile=nwbfile,
             metadata=metadata,
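
The comprehension reformatted in `_read_file` above backfills missing responses so the column stays rectangular for hdf5: every trial contributes an [x, y] pair, with NaNs standing in when the subject never responded. A small sketch of the replacement (the sample values are made up):

    import numpy as np

    # Hypothetical column: one [x, y] pair per trial, None if no response.
    closed_loop_response_position = [[0.1, 0.9], None, [0.5, 0.5]]

    cleaned = [
        [np.nan, np.nan] if x is None else x
        for x in closed_loop_response_position
    ]
    # Uniform shape: hdf5/pandas now see a plain (n_trials, 2) float array.
    assert np.asarray(cleaned).shape == (3, 2)
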
@@ -103,8 +115,14 @@ def add_to_nwbfile(self, nwbfile: NWBFile, metadata: Optional[dict] = None, tag:
     @property
     def column_descriptions(self):
         column_descriptions = {
-            "background_indices": ("For each trial, the indices of the background noise pattern " "patch."),
-            "broke_fixation": ("For each trial, whether the subject broke fixation and the " "trial was aborted"),
+            "background_indices": (
+                "For each trial, the indices of the background noise pattern "
+                "patch."
+            ),
+            "broke_fixation": (
+                "For each trial, whether the subject broke fixation and the "
+                "trial was aborted"
+            ),
             "stimulus_object_identities": (
                 "For each trial, a serialized list with one element for each "
                 'object. Each element is the identity symbol (e.g. "a", "b", '
@@ -141,13 +159,21 @@ def column_descriptions(self):
                 "reward delivery."
             ),
             "start_time": "Start time of each trial.",
-            "phase_fixation_time": ("Time of fixation phase onset for each trial."),
-            "phase_stimulus_time": ("Time of stimulus phase onset for each trial."),
+            "phase_fixation_time": (
+                "Time of fixation phase onset for each trial."
+            ),
+            "phase_stimulus_time": (
+                "Time of stimulus phase onset for each trial."
+            ),
             "phase_delay_time": "Time of delay phase onset for each trial.",
             "phase_cue_time": "Time of cue phase onset for each trial.",
-            "phase_response_time": ("Time of response phase onset for each trial."),
+            "phase_response_time": (
+                "Time of response phase onset for each trial."
+            ),
             "phase_reveal_time": "Time of reveal phase onset for each trial.",
-            "phase_iti_time": ("Time of inter-trial interval onset for each trial."),
+            "phase_iti_time": (
+                "Time of inter-trial interval onset for each trial."
+            ),
             "reward_time": "Time of reward delivery onset for each trial.",
             "reward_duration": "Reward duration for each trial",
             "response_position": (