diff --git a/.git-blame-ignore-revs b/.git-blame-ignore-revs index 88874196..851324ec 100644 --- a/.git-blame-ignore-revs +++ b/.git-blame-ignore-revs @@ -1,2 +1,3 @@ # Normalize line endings -89ea7df05f77475bcfd874f3bccab878d653af6a \ No newline at end of file +89ea7df05f77475bcfd874f3bccab878d653af6a +947eb6c1f701050a03d319feee168260f2a485a0 \ No newline at end of file diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index c2d9d2a3..1610386c 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -3,7 +3,7 @@ default_language_version: python: python3.11 -files: "^(docker|aeon\/dj_pipeline)\/.*$" +files: "^(test|aeon(?!\/dj_pipeline\/).*)$" repos: - repo: meta hooks: @@ -30,7 +30,7 @@ repos: hooks: - id: black args: [--check, --config, ./pyproject.toml] - + - repo: https://github.com/astral-sh/ruff-pre-commit rev: v0.0.286 hooks: diff --git a/aeon/analysis/movies.py b/aeon/analysis/movies.py index 3092105b..f71a0c3e 100644 --- a/aeon/analysis/movies.py +++ b/aeon/analysis/movies.py @@ -1,12 +1,14 @@ -import cv2 import math + +import cv2 import numpy as np import pandas as pd -import aeon.io.video as video + +from aeon.io import video + def gridframes(frames, width, height, shape=None): - ''' - Arranges a set of frames into a grid layout with the specified + """Arranges a set of frames into a grid layout with the specified pixel dimensions and shape. :param list frames: A list of frames to include in the grid layout. @@ -16,15 +18,15 @@ def gridframes(frames, width, height, shape=None): Either the number of frames to include, or the number of rows and columns in the output grid image layout. :return: A new image containing the arrangement of the frames in a grid. - ''' + """ if shape is None: shape = len(frames) - if type(shape) not in [list,tuple]: + if type(shape) not in [list, tuple]: shape = math.ceil(math.sqrt(shape)) shape = (shape, shape) dsize = (height, width, 3) - cellsize = (height // shape[0], width // shape[1],3) + cellsize = (height // shape[0], width // shape[1], 3) grid = np.zeros(dsize, dtype=np.uint8) for i in range(shape[0]): for j in range(shape[1]): @@ -39,19 +41,20 @@ def gridframes(frames, width, height, shape=None): grid[i0:i1, j0:j1] = cv2.resize(frame, (cellsize[1], cellsize[0])) return grid + def averageframes(frames): """Returns the average of the specified collection of frames.""" return cv2.convertScaleAbs(sum(np.multiply(1 / len(frames), frames))) + def groupframes(frames, n, fun): - ''' - Applies the specified function to each group of n-frames. + """Applies the specified function to each group of n-frames. :param iterable frames: A sequence of frames to process. :param int n: The number of frames in each group. :param callable fun: The function used to process each group of frames. :return: An iterable returning the results of applying the function to each group. - ''' + """ i = 0 group = [] for frame in frames: @@ -61,9 +64,9 @@ def groupframes(frames, n, fun): group.clear() i = i + 1 + def triggerclip(data, events, before=pd.Timedelta(0), after=pd.Timedelta(0)): - ''' - Split video data around the specified sequence of event timestamps. + """Split video data around the specified sequence of event timestamps. :param DataFrame data: A pandas DataFrame where each row specifies video acquisition path and frame number. @@ -72,7 +75,7 @@ def triggerclip(data, events, before=pd.Timedelta(0), after=pd.Timedelta(0)): :param Timedelta after: The right offset from each timestamp used to clip the data. 
:return: A pandas DataFrame containing the frames, clip and sequence numbers for each event timestamp. - ''' + """ if before is not pd.Timedelta: before = pd.Timedelta(before) if after is not pd.Timedelta: @@ -81,30 +84,30 @@ def triggerclip(data, events, before=pd.Timedelta(0), after=pd.Timedelta(0)): events = events.index clips = [] - for i,index in enumerate(events): - clip = data.loc[(index-before):(index+after)].copy() - clip['frame_sequence'] = list(range(len(clip))) - clip['clip_sequence'] = i + for i, index in enumerate(events): + clip = data.loc[(index - before) : (index + after)].copy() + clip["frame_sequence"] = list(range(len(clip))) + clip["clip_sequence"] = i clips.append(clip) return pd.concat(clips) + def collatemovie(clipdata, fun): - ''' - Collates a set of video clips into a single movie using the specified aggregation function. + """Collates a set of video clips into a single movie using the specified aggregation function. :param DataFrame clipdata: A pandas DataFrame where each row specifies video path, frame number, clip and sequence number. This DataFrame can be obtained from the output of the triggerclip function. :param callable fun: The aggregation function used to process the frames in each clip. :return: The sequence of processed frames representing the collated movie. - ''' - clipcount = len(clipdata.groupby('clip_sequence').frame_sequence.count()) - allframes = video.frames(clipdata.sort_values(by=['frame_sequence', 'clip_sequence'])) + """ + clipcount = len(clipdata.groupby("clip_sequence").frame_sequence.count()) + allframes = video.frames(clipdata.sort_values(by=["frame_sequence", "clip_sequence"])) return groupframes(allframes, clipcount, fun) + def gridmovie(clipdata, width, height, shape=None): - ''' - Collates a set of video clips into a grid movie with the specified pixel dimensions + """Collates a set of video clips into a grid movie with the specified pixel dimensions and grid layout. :param DataFrame clipdata: @@ -116,5 +119,5 @@ def gridmovie(clipdata, width, height, shape=None): Either the number of frames to include, or the number of rows and columns in the output grid movie layout. :return: The sequence of processed frames representing the collated grid movie. - ''' - return collatemovie(clipdata, lambda g:gridframes(g, width, height, shape)) \ No newline at end of file + """ + return collatemovie(clipdata, lambda g: gridframes(g, width, height, shape)) diff --git a/aeon/analysis/plotting.py b/aeon/analysis/plotting.py index cf05d075..f89d78f5 100644 --- a/aeon/analysis/plotting.py +++ b/aeon/analysis/plotting.py @@ -1,17 +1,16 @@ import math -import matplotlib.colors as colors import matplotlib.pyplot as plt import numpy as np import pandas as pd +from matplotlib import colors from matplotlib.collections import LineCollection from aeon.analysis.utils import * def heatmap(position, frequency, ax=None, **kwargs): - """ - Draw a heatmap of time spent in each location from specified position data and sampling frequency. + """Draw a heatmap of time spent in each location from specified position data and sampling frequency. :param Series position: A series of position data containing x and y coordinates. :param number frequency: The sampling frequency for the position data. @@ -33,8 +32,7 @@ def heatmap(position, frequency, ax=None, **kwargs): def circle(x, y, radius, fmt=None, ax=None, **kwargs): - """ - Plot a circle centered at the given x, y position with the specified radius. 
+ """Plot a circle centered at the given x, y position with the specified radius. :param number x: The x-component of the circle center. :param number y: The y-component of the circle center. @@ -62,8 +60,7 @@ def rateplot( ax=None, **kwargs, ): - """ - Plot the continuous event rate and raster of a discrete event sequence, given the specified + """Plot the continuous event rate and raster of a discrete event sequence, given the specified window size and sampling frequency. :param Series events: The discrete sequence of events. @@ -79,9 +76,7 @@ def rateplot( :param Axes, optional ax: The Axes on which to draw the rate plot and raster. """ label = kwargs.pop("label", None) - eventrate = rate( - events, window, frequency, weight, start, end, smooth=smooth, center=center - ) + eventrate = rate(events, window, frequency, weight, start, end, smooth=smooth, center=center) if ax is None: ax = plt.gca() ax.plot( @@ -90,14 +85,11 @@ def rateplot( label=label, **kwargs, ) - ax.vlines( - sessiontime(events.index, eventrate.index[0]), -0.2, -0.1, linewidth=1, **kwargs - ) + ax.vlines(sessiontime(events.index, eventrate.index[0]), -0.2, -0.1, linewidth=1, **kwargs) def set_ymargin(ax, bottom, top): - """ - Set the vertical margins of the specified Axes. + """Set the vertical margins of the specified Axes. :param Axes ax: The Axes for which to specify the vertical margin. :param number bottom: The size of the bottom margin. @@ -121,8 +113,7 @@ def colorline( ax=None, **kwargs, ): - """ - Plot a dynamically colored line on the specified Axes. + """Plot a dynamically colored line on the specified Axes. :param array-like x, y: The horizontal / vertical coordinates of the data points. :param array-like, optional z: diff --git a/aeon/analysis/utils.py b/aeon/analysis/utils.py index d443f542..eb738106 100644 --- a/aeon/analysis/utils.py +++ b/aeon/analysis/utils.py @@ -1,15 +1,15 @@ import numpy as np import pandas as pd + def distancetravelled(angle, radius=4.0): - ''' - Calculates the total distance travelled on the wheel, by taking into account + """Calculates the total distance travelled on the wheel, by taking into account its radius and the total number of turns in both directions across time. :param Series angle: A series of magnetic encoder measurements. :param float radius: The radius of the wheel, in metric units. :return: The total distance travelled on the wheel, in metric units. - ''' + """ maxvalue = int(np.iinfo(np.uint16).max >> 2) jumpthreshold = maxvalue // 2 turns = angle.astype(int).diff() @@ -20,9 +20,9 @@ def distancetravelled(angle, radius=4.0): distance = distance - distance[0] return distance -def visits(data, onset='Enter', offset='Exit'): - ''' - Computes duration, onset and offset times from paired events. Allows for missing data + +def visits(data, onset="Enter", offset="Exit"): + """Computes duration, onset and offset times from paired events. Allows for missing data by trying to match event onset times with subsequent offset times. If the match fails, event offset metadata is filled with NaN. Any additional metadata columns in the data frame will be paired and included in the output. @@ -31,45 +31,45 @@ def visits(data, onset='Enter', offset='Exit'): :param str, optional onset: The label used to identify event onsets. :param str, optional offset: The label used to identify event offsets. :return: A pandas data frame containing duration and metadata for each visit. 
- ''' + """ lonset = onset.lower() loffset = offset.lower() - lsuffix = '_{0}'.format(lonset) - rsuffix = '_{0}'.format(loffset) - id_onset = 'id' + lsuffix - event_onset = 'event' + lsuffix - event_offset = 'event' + rsuffix - time_onset = 'time' + lsuffix - time_offset = 'time' + rsuffix + lsuffix = f"_{lonset}" + rsuffix = f"_{loffset}" + id_onset = "id" + lsuffix + event_onset = "event" + lsuffix + event_offset = "event" + rsuffix + time_onset = "time" + lsuffix + time_offset = "time" + rsuffix # find all possible onset / offset pairs data = data.reset_index() data_onset = data[data.event == onset] data_offset = data[data.event == offset] - data = pd.merge(data_onset, data_offset, on='id', how='left', suffixes=[lsuffix, rsuffix]) + data = pd.merge(data_onset, data_offset, on="id", how="left", suffixes=[lsuffix, rsuffix]) # valid pairings have the smallest positive duration - data['duration'] = data[time_offset] - data[time_onset] + data["duration"] = data[time_offset] - data[time_onset] valid_visits = data[data.duration >= pd.Timedelta(0)] - data = data.iloc[valid_visits.groupby([time_onset, 'id']).duration.idxmin()] + data = data.iloc[valid_visits.groupby([time_onset, "id"]).duration.idxmin()] data = data[data.duration > pd.Timedelta(0)] # duplicate offsets indicate missing data from previous pairing - missing_data = data.duplicated(subset=time_offset, keep='last') + missing_data = data.duplicated(subset=time_offset, keep="last") if missing_data.any(): - data.loc[missing_data, ['duration'] + [name for name in data.columns if rsuffix in name]] = pd.NA + data.loc[missing_data, ["duration"] + [name for name in data.columns if rsuffix in name]] = pd.NA # rename columns and sort data - data.rename({ time_onset:lonset, id_onset:'id', time_offset:loffset}, axis=1, inplace=True) - data = data[['id'] + [name for name in data.columns if '_' in name] + [lonset, loffset, 'duration']] + data.rename({time_onset: lonset, id_onset: "id", time_offset: loffset}, axis=1, inplace=True) + data = data[["id"] + [name for name in data.columns if "_" in name] + [lonset, loffset, "duration"]] data.drop([event_onset, event_offset], axis=1, inplace=True) data.sort_index(inplace=True) data.reset_index(drop=True, inplace=True) return data + def rate(events, window, frequency, weight=1, start=None, end=None, smooth=None, center=False): - ''' - Computes the continuous event rate from a discrete event sequence, given the specified + """Computes the continuous event rate from a discrete event sequence, given the specified window size and sampling frequency. :param Series events: The discrete sequence of events. @@ -83,22 +83,25 @@ def rate(events, window, frequency, weight=1, start=None, end=None, smooth=None, The size of the smoothing kernel applied to the continuous rate output. :param bool, optional center: Specifies whether to center the convolution kernels. :return: A Series containing the continuous event rate over time. 
- ''' + """ counts = pd.Series(weight, events.index) if start is not None and start < events.index[0]: counts.loc[start] = 0 if end is not None and end > events.index[-1]: counts.loc[end] = 0 counts.sort_index(inplace=True) - counts = counts.resample(pd.Timedelta(1 / frequency, 's')).sum() - rate = counts.rolling(window,center=center).sum() - return rate.rolling(window if smooth is None else smooth,center=center).mean() + counts = counts.resample(pd.Timedelta(1 / frequency, "s")).sum() + rate = counts.rolling(window, center=center).sum() + return rate.rolling(window if smooth is None else smooth, center=center).mean() + -def get_events_rates(events, window_len_sec, frequency, unit_len_sec=60, start=None, end=None, smooth=None, center=False): +def get_events_rates( + events, window_len_sec, frequency, unit_len_sec=60, start=None, end=None, smooth=None, center=False +): # events is an array with the time (in seconds) of event occurence # window_len_sec is the size of the window over which the event rate is estimated # unit_len_sec is the length of one sample point - window_len_sec_str = "{:d}S".format(window_len_sec) + window_len_sec_str = f"{window_len_sec:d}S" counts = pd.Series(1.0, events.index) if start is not None and start < events.index[0]: counts.loc[start] = 0 @@ -106,29 +109,35 @@ def get_events_rates(events, window_len_sec, frequency, unit_len_sec=60, start=N counts.loc[end] = 0 counts.sort_index(inplace=True) counts_resampled = counts.resample(frequency).sum() - counts_rolled = counts_resampled.rolling(window_len_sec_str,center=center).sum()*unit_len_sec/window_len_sec - counts_rolled_smoothed = counts_rolled.rolling(window_len_sec_str if smooth is None else smooth, center=center).mean() + counts_rolled = ( + counts_resampled.rolling(window_len_sec_str, center=center).sum() * unit_len_sec / window_len_sec + ) + counts_rolled_smoothed = counts_rolled.rolling( + window_len_sec_str if smooth is None else smooth, center=center + ).mean() return counts_rolled_smoothed + def sessiontime(index, start=None): """Converts absolute to relative time, with optional reference starting time.""" - if (start is None): + if start is None: start = index[0] - return (index-start).total_seconds() / 60 + return (index - start).total_seconds() / 60 + def distance(position, target): """Computes the euclidean distance to a specified target.""" - return np.sqrt(np.square(position[['x','y']] - target).sum(axis=1)) + return np.sqrt(np.square(position[["x", "y"]] - target).sum(axis=1)) + def activepatch(wheel, in_patch): - ''' - Computes a decision boundary for when a patch is active based on wheel movement. - + """Computes a decision boundary for when a patch is active based on wheel movement. + :param Series wheel: A pandas Series containing the cumulative distance travelled on the wheel. :param Series in_patch: A Series of type bool containing whether the specified patch may be active. :return: A pandas Series specifying for each timepoint whether the patch is active. 
- ''' + """ exit_patch = in_patch.astype(np.int8).diff() < 0 - in_wheel = (wheel.diff().rolling('1s').sum() > 1).reindex(in_patch.index, method='pad') + in_wheel = (wheel.diff().rolling("1s").sum() > 1).reindex(in_patch.index, method="pad") epochs = exit_patch.cumsum() - return in_wheel.groupby(epochs).apply(lambda x:x.cumsum()) > 0 \ No newline at end of file + return in_wheel.groupby(epochs).apply(lambda x: x.cumsum()) > 0 diff --git a/aeon/io/__init__.py b/aeon/io/__init__.py index 4287ca86..792d6005 100644 --- a/aeon/io/__init__.py +++ b/aeon/io/__init__.py @@ -1 +1 @@ -# \ No newline at end of file +# diff --git a/aeon/io/api.py b/aeon/io/api.py index 88f64953..2b9bc745 100644 --- a/aeon/io/api.py +++ b/aeon/io/api.py @@ -1,40 +1,43 @@ import bisect import datetime -import pandas as pd -from pathlib import Path from os import PathLike +from pathlib import Path + +import pandas as pd """The duration of each acquisition chunk, in whole hours.""" CHUNK_DURATION = 1 + def aeon(seconds): """Converts a Harp timestamp, in seconds, to a datetime object.""" - return datetime.datetime(1904, 1, 1) + pd.to_timedelta(seconds, 's') + return datetime.datetime(1904, 1, 1) + pd.to_timedelta(seconds, "s") + def chunk(time): - ''' - Returns the whole hour acquisition chunk for a measurement timestamp. - + """Returns the whole hour acquisition chunk for a measurement timestamp. + :param datetime or Series time: An object or series specifying the measurement timestamps. :return: A datetime object or series specifying the acquisition chunk for the measurement timestamp. - ''' + """ if isinstance(time, pd.Series): hour = CHUNK_DURATION * (time.dt.hour // CHUNK_DURATION) - return pd.to_datetime(time.dt.date) + pd.to_timedelta(hour, 'h') + return pd.to_datetime(time.dt.date) + pd.to_timedelta(hour, "h") else: hour = CHUNK_DURATION * (time.hour // CHUNK_DURATION) return pd.to_datetime(datetime.datetime.combine(time.date(), datetime.time(hour=hour))) + def chunk_range(start, end): - ''' - Returns a range of whole hour acquisition chunks. + """Returns a range of whole hour acquisition chunks. :param datetime start: The left bound of the time range. :param datetime end: The right bound of the time range. :return: A DatetimeIndex representing the acquisition chunk range. - ''' + """ return pd.date_range(chunk(start), chunk(end), freq=pd.DateOffset(hours=CHUNK_DURATION)) + def chunk_key(file): """Returns the acquisition chunk key for the specified file name.""" epoch = file.parts[-3] @@ -46,17 +49,19 @@ def chunk_key(file): date_str, time_str = epoch.split("T") return epoch, datetime.datetime.fromisoformat(date_str + "T" + time_str.replace("-", ":")) + def _set_index(data): if not isinstance(data.index, pd.DatetimeIndex): data.index = aeon(data.index) - data.index.name = 'time' + data.index.name = "time" + def _empty(columns): - return pd.DataFrame(columns=columns, index=pd.DatetimeIndex([], name='time')) + return pd.DataFrame(columns=columns, index=pd.DatetimeIndex([], name="time")) + def load(root, reader, start=None, end=None, time=None, tolerance=None, epoch=None): - ''' - Extracts chunk data from the root path of an Aeon dataset using the specified data stream + """Extracts chunk data from the root path of an Aeon dataset using the specified data stream reader. A subset of the data can be loaded by specifying an optional time range, or a list of timestamps used to index the data on file. Returned data will be sorted chronologically. 
@@ -69,7 +74,7 @@ def load(root, reader, start=None, end=None, time=None, tolerance=None, epoch=No The maximum distance between original and new timestamps for inexact matches. :param str, optional epoch: A wildcard pattern to use when searching epoch data. :return: A pandas data frame containing epoch event metadata, sorted by time. - ''' + """ if isinstance(root, str): root = Path(root) if isinstance(root, PathLike): @@ -77,9 +82,10 @@ def load(root, reader, start=None, end=None, time=None, tolerance=None, epoch=No epoch_pattern = "**" if epoch is None else epoch fileset = { - chunk_key(fname):fname + chunk_key(fname): fname for path in root - for fname in path.glob(f"{epoch_pattern}/**/{reader.pattern}.{reader.extension}")} + for fname in path.glob(f"{epoch_pattern}/**/{reader.pattern}.{reader.extension}") + } files = sorted(fileset.items()) if time is not None: @@ -93,7 +99,7 @@ def load(root, reader, start=None, end=None, time=None, tolerance=None, epoch=No dataframes = [] filetimes = [chunk for (_, chunk), _ in files] files = [file for _, file in files] - for key,values in time.groupby(by=chunk): + for key, values in time.groupby(by=chunk): i = bisect.bisect_left(filetimes, key) if i < len(filetimes): frame = reader.read(files[i]) @@ -101,22 +107,22 @@ def load(root, reader, start=None, end=None, time=None, tolerance=None, epoch=No else: frame = _empty(reader.columns) data = frame.reset_index() - data.set_index('time', drop=False, inplace=True) - data = data.reindex(values, method='pad', tolerance=tolerance) + data.set_index("time", drop=False, inplace=True) + data = data.reindex(values, method="pad", tolerance=tolerance) missing = len(data.time) - data.time.count() if missing > 0 and i > 0: # expand reindex to allow adjacent chunks # to fill missing values - previous = reader.read(files[i-1]) + previous = reader.read(files[i - 1]) data = pd.concat([previous, frame]) - data = data.reindex(values, method='pad', tolerance=tolerance) + data = data.reindex(values, method="pad", tolerance=tolerance) else: - data.drop(columns='time', inplace=True) + data.drop(columns="time", inplace=True) dataframes.append(data) if len(dataframes) == 0: return _empty(reader.columns) - + return pd.concat(dataframes) if start is not None or end is not None: @@ -134,11 +140,12 @@ def load(root, reader, start=None, end=None, time=None, tolerance=None, epoch=No return data.loc[start:end] except KeyError: import warnings + if not data.index.has_duplicates: - warnings.warn('data index for {0} contains out-of-order timestamps!'.format(reader.pattern)) + warnings.warn(f"data index for {reader.pattern} contains out-of-order timestamps!") data = data.sort_index() else: - warnings.warn('data index for {0} contains duplicate keys!'.format(reader.pattern)) - data = data[~data.index.duplicated(keep='first')] + warnings.warn(f"data index for {reader.pattern} contains duplicate keys!") + data = data[~data.index.duplicated(keep="first")] return data.loc[start:end] return data diff --git a/aeon/io/device.py b/aeon/io/device.py index c889e5eb..1a4916e6 100644 --- a/aeon/io/device.py +++ b/aeon/io/device.py @@ -1,5 +1,6 @@ import inspect + def compositeStream(pattern, *args): """Merges multiple data streams into a single composite stream.""" composite = {} @@ -15,8 +16,7 @@ def compositeStream(pattern, *args): class Device: - """ - Groups multiple data streams into a logical device. + """Groups multiple data streams into a logical device. 
If a device contains a single stream with the same pattern as the device `name`, it will be considered a singleton, and the stream reader will be diff --git a/aeon/io/reader.py b/aeon/io/reader.py index a69c4169..67570608 100644 --- a/aeon/io/reader.py +++ b/aeon/io/reader.py @@ -1,34 +1,38 @@ -import os -import math -import json import datetime +import json +import math +import os + import numpy as np import pandas as pd -from aeon.io.api import chunk_key from dotmap import DotMap +from aeon.io.api import chunk_key + _SECONDS_PER_TICK = 32e-6 _payloadtypes = { - 1 : np.dtype(np.uint8), - 2 : np.dtype(np.uint16), - 4 : np.dtype(np.uint32), - 8 : np.dtype(np.uint64), - 129 : np.dtype(np.int8), - 130 : np.dtype(np.int16), - 132 : np.dtype(np.int32), - 136 : np.dtype(np.int64), - 68 : np.dtype(np.float32) + 1: np.dtype(np.uint8), + 2: np.dtype(np.uint16), + 4: np.dtype(np.uint32), + 8: np.dtype(np.uint64), + 129: np.dtype(np.int8), + 130: np.dtype(np.int16), + 132: np.dtype(np.int32), + 136: np.dtype(np.int64), + 68: np.dtype(np.float32), } + class Reader: """Extracts data from raw files in an Aeon dataset. - + Attributes: pattern (str): Pattern used to find raw files, usually in the format `_`. columns (str or array-like): Column labels to use for the data. extension (str): Extension of data file pathnames. """ + def __init__(self, pattern, columns, extension): self.pattern = pattern self.columns = columns @@ -38,8 +42,10 @@ def read(self, _): """Reads data from the specified file.""" return pd.DataFrame(columns=self.columns, index=pd.DatetimeIndex([])) + class Harp(Reader): """Extracts data from raw binary files encoded using the Harp protocol.""" + def __init__(self, pattern, columns, extension="bin"): super().__init__(pattern, columns, extension) @@ -59,36 +65,38 @@ def read(self, file): ticks = np.ndarray(length, dtype=np.uint16, buffer=data, offset=9, strides=stride) seconds = ticks * _SECONDS_PER_TICK + seconds payload = np.ndarray( - payloadshape, - dtype=payloadtype, - buffer=data, offset=11, - strides=(stride, elementsize)) + payloadshape, dtype=payloadtype, buffer=data, offset=11, strides=(stride, elementsize) + ) if self.columns is not None and payloadshape[1] < len(self.columns): - data = pd.DataFrame(payload, index=seconds, columns=self.columns[:payloadshape[1]]) - data[self.columns[payloadshape[1]:]] = math.nan + data = pd.DataFrame(payload, index=seconds, columns=self.columns[: payloadshape[1]]) + data[self.columns[payloadshape[1] :]] = math.nan return data else: return pd.DataFrame(payload, index=seconds, columns=self.columns) + class Chunk(Reader): """Extracts path and epoch information from chunk files in the dataset.""" + def __init__(self, reader=None, pattern=None, extension=None): if isinstance(reader, Reader): pattern = reader.pattern extension = reader.extension - super().__init__(pattern, columns=['path', 'epoch'], extension=extension) - + super().__init__(pattern, columns=["path", "epoch"], extension=extension) + def read(self, file): """Returns path and epoch information for the specified chunk.""" epoch, chunk = chunk_key(file) - data = { 'path': file, 'epoch': epoch } + data = {"path": file, "epoch": epoch} return pd.DataFrame(data, index=[chunk], columns=self.columns) + class Metadata(Reader): """Extracts metadata information from all epochs in the dataset.""" + def __init__(self, pattern="Metadata"): - super().__init__(pattern, columns=['workflow', 'commit', 'metadata'], extension="yml") + super().__init__(pattern, columns=["workflow", "commit", "metadata"], 
extension="yml") def read(self, file): """Returns metadata for the specified epoch.""" @@ -97,32 +105,28 @@ def read(self, file): time = datetime.datetime.fromisoformat(date_str + "T" + time_str.replace("-", ":")) with open(file) as fp: metadata = json.load(fp) - workflow = metadata.pop('Workflow') - commit = metadata.pop('Commit', pd.NA) - data = { 'workflow': workflow, 'commit': commit, 'metadata': [DotMap(metadata)] } + workflow = metadata.pop("Workflow") + commit = metadata.pop("Commit", pd.NA) + data = {"workflow": workflow, "commit": commit, "metadata": [DotMap(metadata)]} return pd.DataFrame(data, index=[time], columns=self.columns) + class Csv(Reader): - """ - Extracts data from comma-separated (csv) text files, where the first column + """Extracts data from comma-separated (csv) text files, where the first column stores the Aeon timestamp, in seconds. """ + def __init__(self, pattern, columns, dtype=None, extension="csv"): super().__init__(pattern, columns, extension) self.dtype = dtype def read(self, file): """Reads data from the specified CSV text file.""" - return pd.read_csv( - file, - header=0, - names=self.columns, - dtype=self.dtype, - index_col=0) + return pd.read_csv(file, header=0, names=self.columns, dtype=self.dtype, index_col=0) + class Subject(Csv): - """ - Extracts metadata for subjects entering and exiting the environment. + """Extracts metadata for subjects entering and exiting the environment. Columns: id (str): Unique identifier of a subject in the environment. @@ -130,46 +134,50 @@ class Subject(Csv): or exiting the environment. event (str): Event type. Can be one of `Enter`, `Exit` or `Remain`. """ + def __init__(self, pattern): - super().__init__(pattern, columns=['id', 'weight', 'event']) + super().__init__(pattern, columns=["id", "weight", "event"]) + class Log(Csv): - """ - Extracts message log data. - + """Extracts message log data. + Columns: priority (str): Priority level of the message. type (str): Type of the log message. message (str): Log message data. Can be structured using tab separated values. """ + def __init__(self, pattern): - super().__init__(pattern, columns=['priority', 'type', 'message']) + super().__init__(pattern, columns=["priority", "type", "message"]) + class Heartbeat(Harp): - """ - Extract periodic heartbeat event data. - + """Extract periodic heartbeat event data. + Columns: second (int): The whole second corresponding to the heartbeat, in seconds. """ + def __init__(self, pattern): - super().__init__(pattern, columns=['second']) + super().__init__(pattern, columns=["second"]) + class Encoder(Harp): - """ - Extract magnetic encoder data. - + """Extract magnetic encoder data. + Columns: angle (float): Absolute angular position, in radians, of the magnetic encoder. intensity (float): Intensity of the magnetic field. """ + def __init__(self, pattern): - super().__init__(pattern, columns=['angle', 'intensity']) + super().__init__(pattern, columns=["angle", "intensity"]) + class Position(Harp): - """ - Extract 2D position tracking data for a specific camera. + """Extract 2D position tracking data for a specific camera. Columns: x (float): x-coordinate of the object center of mass. @@ -182,89 +190,89 @@ class Position(Harp): area (float): number of pixels in the object mass. id (float): unique tracking ID of the object in a frame. 
""" + def __init__(self, pattern): - super().__init__(pattern, columns=['x', 'y', 'angle', 'major', 'minor', 'area', 'id']) + super().__init__(pattern, columns=["x", "y", "angle", "major", "minor", "area", "id"]) + class BitmaskEvent(Harp): - """ - Extracts event data matching a specific digital I/O bitmask. + """Extracts event data matching a specific digital I/O bitmask. Columns: event (str): Unique identifier for the event code. """ + def __init__(self, pattern, value, tag): - super().__init__(pattern, columns=['event']) + super().__init__(pattern, columns=["event"]) self.value = value self.tag = tag def read(self, file): - """ - Reads a specific event code from digital data and matches it to the + """Reads a specific event code from digital data and matches it to the specified unique identifier. """ data = super().read(file) data = data[data.event & self.value > 0] - data['event'] = self.tag + data["event"] = self.tag return data + class DigitalBitmask(Harp): - """ - Extracts event data matching a specific digital I/O bitmask. + """Extracts event data matching a specific digital I/O bitmask. Columns: event (str): Unique identifier for the event code. """ + def __init__(self, pattern, mask, columns): super().__init__(pattern, columns) self.mask = mask def read(self, file): - """ - Reads a specific event code from digital data and matches it to the + """Reads a specific event code from digital data and matches it to the specified unique identifier. """ data = super().read(file) state = data[self.columns] & self.mask return state[(state.diff() != 0).values] != 0 + class Video(Csv): - """ - Extracts video frame metadata. - + """Extracts video frame metadata. + Columns: hw_counter (int): Hardware frame counter value for the current frame. hw_timestamp (int): Internal camera timestamp for the current frame. 
""" + def __init__(self, pattern): - super().__init__(pattern, columns=['hw_counter', 'hw_timestamp', '_frame', '_path', '_epoch']) - self._rawcolumns = ['time'] + self.columns[0:2] + super().__init__(pattern, columns=["hw_counter", "hw_timestamp", "_frame", "_path", "_epoch"]) + self._rawcolumns = ["time"] + self.columns[0:2] def read(self, file): """Reads video metadata from the specified file.""" data = pd.read_csv(file, header=0, names=self._rawcolumns) - data['_frame'] = data.index - data['_path'] = os.path.splitext(file)[0] + '.avi' - data['_epoch'] = file.parts[-3] - data.set_index('time', inplace=True) + data["_frame"] = data.index + data["_path"] = os.path.splitext(file)[0] + ".avi" + data["_epoch"] = file.parts[-3] + data.set_index("time", inplace=True) return data + def from_dict(data, pattern=None): - reader_type = data.get('type', None) + reader_type = data.get("type", None) if reader_type is not None: - kwargs = {k:v for k,v in data.items() if k != 'type'} + kwargs = {k: v for k, v in data.items() if k != "type"} return globals()[reader_type](pattern=pattern, **kwargs) - return DotMap({ - k:from_dict(v, f"{pattern}_{k}" if pattern is not None else k) - for k,v in data.items() - }) + return DotMap( + {k: from_dict(v, f"{pattern}_{k}" if pattern is not None else k) for k, v in data.items()} + ) + def to_dict(dotmap): if isinstance(dotmap, Reader): - kwargs = { k:v for k,v in vars(dotmap).items() - if k not in ['pattern'] and not k.startswith('_') } - kwargs['type'] = type(dotmap).__name__ + kwargs = {k: v for k, v in vars(dotmap).items() if k not in ["pattern"] and not k.startswith("_")} + kwargs["type"] = type(dotmap).__name__ return kwargs - return { - k:to_dict(v) for k,v in dotmap.items() - } \ No newline at end of file + return {k: to_dict(v) for k, v in dotmap.items()} diff --git a/aeon/io/video.py b/aeon/io/video.py index 18a70d9c..79e43daa 100644 --- a/aeon/io/video.py +++ b/aeon/io/video.py @@ -1,15 +1,15 @@ import cv2 + def frames(data): - ''' - Extracts the raw frames corresponding to the provided video metadata. + """Extracts the raw frames corresponding to the provided video metadata. :param DataFrame data: A pandas DataFrame where each row specifies video acquisition path and frame number. :return: An object to iterate over numpy arrays for each row in the DataFrame, containing the raw video frame data. - ''' + """ capture = None filename = None index = 0 @@ -27,29 +27,29 @@ def frames(data): index = frameidx success, frame = capture.read() if not success: - raise ValueError('Unable to read frame {0} from video path "{1}".'.format(frameidx, path)) + raise ValueError(f'Unable to read frame {frameidx} from video path "{path}".') yield frame index = index + 1 finally: if capture is not None: capture.release() + def export(frames, file, fps, fourcc=None): - ''' - Exports the specified frame sequence to a new video file. + """Exports the specified frame sequence to a new video file. :param iterable frames: An object to iterate over the raw video frame data. :param str file: The path to the exported video file. :param fps: The frame rate of the exported video. :param optional fourcc: Specifies the four character code of the codec used to compress the frames. 
- ''' + """ writer = None try: for frame in frames: if writer is None: if fourcc is None: - fourcc = cv2.VideoWriter_fourcc('m','p','4','v') + fourcc = cv2.VideoWriter_fourcc("m", "p", "4", "v") writer = cv2.VideoWriter(file, fourcc, fps, (frame.shape[1], frame.shape[0])) writer.write(frame) finally: diff --git a/aeon/qc/video.py b/aeon/qc/video.py index 1773a3b8..1e090dc9 100644 --- a/aeon/qc/video.py +++ b/aeon/qc/video.py @@ -1,38 +1,48 @@ import os -import aeon.io.api as aeon from pathlib import Path -root = '/ceph/aeon/test2/experiment0.1' -qcroot = '/ceph/aeon/aeon/qc/experiment0.1' -devicenames = ['FrameEast','FrameGate','FrameNorth','FramePatch1','FramePatch2','FrameSouth','FrameTop','FrameWest'] +import aeon.io.api as aeon + +root = "/ceph/aeon/test2/experiment0.1" +qcroot = "/ceph/aeon/aeon/qc/experiment0.1" +devicenames = [ + "FrameEast", + "FrameGate", + "FrameNorth", + "FramePatch1", + "FramePatch2", + "FrameSouth", + "FrameTop", + "FrameWest", +] for device in devicenames: videochunks = aeon.chunkdata(root, device) - videochunks['epoch'] = videochunks.path.str.rsplit('/', n=3, expand=True)[1] + videochunks["epoch"] = videochunks.path.str.rsplit("/", n=3, expand=True)[1] stats = [] frameshifts = [] - for (key, period) in videochunks.groupby(by='epoch'): + for key, period in videochunks.groupby(by="epoch"): frame_offset = 0 path = Path(os.path.join(qcroot, key, device)) path.mkdir(parents=True, exist_ok=True) for chunk in period.itertuples(): - outpath = Path(chunk.path.replace(root, qcroot)).with_suffix('.parquet') - print('[{1}] Analysing {0} {2}... '.format(device, key,chunk.Index),end="") + outpath = Path(chunk.path.replace(root, qcroot)).with_suffix(".parquet") + print(f"[{key}] Analysing {device} {chunk.Index}... ", end="") data = aeon.videoreader(chunk.path).reset_index() deltas = data[data.columns[0:4]].diff() - deltas.columns = [ 'time_delta', 'frame_delta', 'hw_counter_delta', 'hw_timestamp_delta'] - deltas['frame_offset'] = (deltas.hw_counter_delta - 1).cumsum() + frame_offset + deltas.columns = ["time_delta", "frame_delta", "hw_counter_delta", "hw_timestamp_delta"] + deltas["frame_offset"] = (deltas.hw_counter_delta - 1).cumsum() + frame_offset drop_count = deltas.frame_offset.iloc[-1] max_harp_delta = deltas.time_delta.max().total_seconds() - max_camera_delta = deltas.hw_timestamp_delta.max() / 1e9 # convert nanoseconds to seconds - print('drops: {0} frameOffset: {1} maxHarpDelta: {2} s maxCameraDelta: {3} s'.format( - drop_count - frame_offset, - drop_count, - max_harp_delta, - max_camera_delta)) + max_camera_delta = deltas.hw_timestamp_delta.max() / 1e9 # convert nanoseconds to seconds + print( + "drops: {} frameOffset: {} maxHarpDelta: {} s maxCameraDelta: {} s".format( + drop_count - frame_offset, drop_count, max_harp_delta, max_camera_delta + ) + ) stats.append((drop_count, max_harp_delta, max_camera_delta, chunk.path)) deltas.set_index(data.time, inplace=True) deltas.to_parquet(outpath) frameshifts.append(deltas) - frame_offset = drop_count \ No newline at end of file + frame_offset = drop_count diff --git a/aeon/schema/core.py b/aeon/schema/core.py index d12beb98..8181c710 100644 --- a/aeon/schema/core.py +++ b/aeon/schema/core.py @@ -1,38 +1,47 @@ -import aeon.io.reader as _reader import aeon.io.device as _device +import aeon.io.reader as _reader + def heartbeat(pattern): """Heartbeat event for Harp devices.""" - return { "Heartbeat": _reader.Heartbeat(f"{pattern}_8_*") } + return {"Heartbeat": _reader.Heartbeat(f"{pattern}_8_*")} + def video(pattern): 
"""Video frame metadata.""" - return { "Video": _reader.Video(f"{pattern}_*") } + return {"Video": _reader.Video(f"{pattern}_*")} + def position(pattern): """Position tracking data for the specified camera.""" - return { "Position": _reader.Position(f"{pattern}_200_*") } + return {"Position": _reader.Position(f"{pattern}_200_*")} + def encoder(pattern): """Wheel magnetic encoder data.""" - return { "Encoder": _reader.Encoder(f"{pattern}_90_*") } + return {"Encoder": _reader.Encoder(f"{pattern}_90_*")} + def environment(pattern): """Metadata for environment mode and subjects.""" return _device.compositeStream(pattern, environment_state, subject_state) + def environment_state(pattern): """Environment state log.""" - return { "EnvironmentState": _reader.Csv(f"{pattern}_EnvironmentState_*", ['state']) } + return {"EnvironmentState": _reader.Csv(f"{pattern}_EnvironmentState_*", ["state"])} + def subject_state(pattern): """Subject state log.""" - return { "SubjectState": _reader.Subject(f"{pattern}_SubjectState_*") } + return {"SubjectState": _reader.Subject(f"{pattern}_SubjectState_*")} + def messageLog(pattern): """Message log data.""" - return { "MessageLog": _reader.Log(f"{pattern}_MessageLog_*") } + return {"MessageLog": _reader.Log(f"{pattern}_MessageLog_*")} + def metadata(pattern): """Metadata for acquisition epochs.""" - return { pattern: _reader.Metadata(pattern) } + return {pattern: _reader.Metadata(pattern)} diff --git a/aeon/schema/dataset.py b/aeon/schema/dataset.py index 52cec571..b9586de4 100644 --- a/aeon/schema/dataset.py +++ b/aeon/schema/dataset.py @@ -1,53 +1,59 @@ from dotmap import DotMap + import aeon.schema.core as stream -import aeon.schema.foraging as foraging -import aeon.schema.octagon as octagon from aeon.io.device import Device +from aeon.schema import foraging, octagon -exp02 = DotMap([ - Device("Metadata", stream.metadata), - Device("ExperimentalMetadata", stream.environment, stream.messageLog), - Device("CameraTop", stream.video, stream.position, foraging.region), - Device("CameraEast", stream.video), - Device("CameraNest", stream.video), - Device("CameraNorth", stream.video), - Device("CameraPatch1", stream.video), - Device("CameraPatch2", stream.video), - Device("CameraSouth", stream.video), - Device("CameraWest", stream.video), - Device("Nest", foraging.weight), - Device("Patch1", foraging.patch), - Device("Patch2", foraging.patch) -]) +exp02 = DotMap( + [ + Device("Metadata", stream.metadata), + Device("ExperimentalMetadata", stream.environment, stream.messageLog), + Device("CameraTop", stream.video, stream.position, foraging.region), + Device("CameraEast", stream.video), + Device("CameraNest", stream.video), + Device("CameraNorth", stream.video), + Device("CameraPatch1", stream.video), + Device("CameraPatch2", stream.video), + Device("CameraSouth", stream.video), + Device("CameraWest", stream.video), + Device("Nest", foraging.weight), + Device("Patch1", foraging.patch), + Device("Patch2", foraging.patch), + ] +) -exp01 = DotMap([ - Device("SessionData", foraging.session), - Device("FrameTop", stream.video, stream.position), - Device("FrameEast", stream.video), - Device("FrameGate", stream.video), - Device("FrameNorth", stream.video), - Device("FramePatch1", stream.video), - Device("FramePatch2", stream.video), - Device("FrameSouth", stream.video), - Device("FrameWest", stream.video), - Device("Patch1", foraging.depletionFunction, stream.encoder, foraging.feeder), - Device("Patch2", foraging.depletionFunction, stream.encoder, foraging.feeder) -]) +exp01 = 
DotMap( + [ + Device("SessionData", foraging.session), + Device("FrameTop", stream.video, stream.position), + Device("FrameEast", stream.video), + Device("FrameGate", stream.video), + Device("FrameNorth", stream.video), + Device("FramePatch1", stream.video), + Device("FramePatch2", stream.video), + Device("FrameSouth", stream.video), + Device("FrameWest", stream.video), + Device("Patch1", foraging.depletionFunction, stream.encoder, foraging.feeder), + Device("Patch2", foraging.depletionFunction, stream.encoder, foraging.feeder), + ] +) -octagon01 = DotMap([ - Device("Metadata", stream.metadata), - Device("CameraTop", stream.video, stream.position), - Device("CameraColorTop", stream.video), - Device("ExperimentalMetadata", stream.subject_state), - Device("Photodiode", octagon.photodiode), - Device("OSC", octagon.OSC), - Device("TaskLogic", octagon.TaskLogic), - Device("Wall1", octagon.Wall), - Device("Wall2", octagon.Wall), - Device("Wall3", octagon.Wall), - Device("Wall4", octagon.Wall), - Device("Wall5", octagon.Wall), - Device("Wall6", octagon.Wall), - Device("Wall7", octagon.Wall), - Device("Wall8", octagon.Wall) -]) \ No newline at end of file +octagon01 = DotMap( + [ + Device("Metadata", stream.metadata), + Device("CameraTop", stream.video, stream.position), + Device("CameraColorTop", stream.video), + Device("ExperimentalMetadata", stream.subject_state), + Device("Photodiode", octagon.photodiode), + Device("OSC", octagon.OSC), + Device("TaskLogic", octagon.TaskLogic), + Device("Wall1", octagon.Wall), + Device("Wall2", octagon.Wall), + Device("Wall3", octagon.Wall), + Device("Wall4", octagon.Wall), + Device("Wall5", octagon.Wall), + Device("Wall6", octagon.Wall), + Device("Wall7", octagon.Wall), + Device("Wall8", octagon.Wall), + ] +) diff --git a/aeon/schema/foraging.py b/aeon/schema/foraging.py index 5580bf80..ffd8fdd9 100644 --- a/aeon/schema/foraging.py +++ b/aeon/schema/foraging.py @@ -1,8 +1,11 @@ -import pandas as _pd -import aeon.io.reader as _reader +from enum import Enum as _Enum + +import pandas as pd + import aeon.io.device as _device +import aeon.io.reader as _reader import aeon.schema.core as _stream -from enum import Enum as _Enum + class Area(_Enum): Null = 0 @@ -12,80 +15,94 @@ class Area(_Enum): Patch1 = 4 Patch2 = 5 + class _RegionReader(_reader.Harp): def __init__(self, pattern): - super().__init__(pattern, columns=['region']) + super().__init__(pattern, columns=["region"]) def read(self, file): data = super().read(file) - categorical = _pd.Categorical(data.region, categories=range(len(Area._member_names_))) - data['region'] = categorical.rename_categories(Area._member_names_) + categorical = pd.Categorical(data.region, categories=range(len(Area._member_names_))) + data["region"] = categorical.rename_categories(Area._member_names_) return data + class _PatchState(_reader.Csv): - """ - Extracts patch state data for linear depletion foraging patches. - + """Extracts patch state data for linear depletion foraging patches. + Columns: threshold (float): Distance to travel before the next pellet is delivered. d1 (float): y-intercept of the line specifying the depletion function. delta (float): Slope of the linear depletion function. """ + def __init__(self, pattern): - super().__init__(pattern, columns=['threshold', 'd1', 'delta']) + super().__init__(pattern, columns=["threshold", "d1", "delta"]) + class _Weight(_reader.Harp): - """ - Extract weight measurements from an electronic weighing device. - + """Extract weight measurements from an electronic weighing device. 
+ Columns: value (float): Absolute weight reading, in grams. stable (float): Normalized value in the range [0, 1] indicating how much the reading is stable. """ + def __init__(self, pattern): - super().__init__(pattern, columns=['value', 'stable']) + super().__init__(pattern, columns=["value", "stable"]) + def region(pattern): """Region tracking data for the specified camera.""" - return { "Region": _RegionReader(f"{pattern}_201_*") } + return {"Region": _RegionReader(f"{pattern}_201_*")} + def depletionFunction(pattern): """State of the linear depletion function for foraging patches.""" - return { "DepletionState": _PatchState(f"{pattern}_State_*") } + return {"DepletionState": _PatchState(f"{pattern}_State_*")} + def feeder(pattern): """Feeder commands and events.""" return _device.compositeStream(pattern, beam_break, deliver_pellet) + def beam_break(pattern): """Beam break events for pellet detection.""" - return { "BeamBreak": _reader.BitmaskEvent(f"{pattern}_32_*", 0x22, 'PelletDetected') } + return {"BeamBreak": _reader.BitmaskEvent(f"{pattern}_32_*", 0x22, "PelletDetected")} + def deliver_pellet(pattern): """Pellet delivery commands.""" - return { "DeliverPellet": _reader.BitmaskEvent(f"{pattern}_35_*", 0x80, 'TriggerPellet') } + return {"DeliverPellet": _reader.BitmaskEvent(f"{pattern}_35_*", 0x80, "TriggerPellet")} + def patch(pattern): """Data streams for a patch.""" return _device.compositeStream(pattern, depletionFunction, _stream.encoder, feeder) + def weight(pattern): """Weight measurement data streams for a specific nest.""" return _device.compositeStream(pattern, weight_raw, weight_filtered, weight_subject) + def weight_raw(pattern): """Raw weight measurement for a specific nest.""" - return { "WeightRaw": _Weight(f"{pattern}_200_*") } + return {"WeightRaw": _Weight(f"{pattern}_200_*")} + def weight_filtered(pattern): """Filtered weight measurement for a specific nest.""" - return { "WeightFiltered": _Weight(f"{pattern}_202_*") } + return {"WeightFiltered": _Weight(f"{pattern}_202_*")} + def weight_subject(pattern): """Subject weight measurement for a specific nest.""" - return { "WeightSubject": _Weight(f"{pattern}_204_*") } + return {"WeightSubject": _Weight(f"{pattern}_204_*")} + def session(pattern): """Session metadata for Experiment 0.1.""" - return { pattern: _reader.Csv(f"{pattern}_2*", columns=['id','weight','event']) } + return {pattern: _reader.Csv(f"{pattern}_2*", columns=["id", "weight", "event"])} diff --git a/aeon/schema/octagon.py b/aeon/schema/octagon.py index b283c905..a792fac4 100644 --- a/aeon/schema/octagon.py +++ b/aeon/schema/octagon.py @@ -1,162 +1,190 @@ import aeon.io.reader as _reader -import aeon.io.device as _device -import aeon.schema.core as _stream + def photodiode(pattern): - return { "Photodiode": _reader.Harp(f"{pattern}_44_*", columns=['adc', 'encoder']) } + return {"Photodiode": _reader.Harp(f"{pattern}_44_*", columns=["adc", "encoder"])} + class OSC: @staticmethod def background_color(pattern): - return { "BackgroundColor": _reader.Csv(f"{pattern}_backgroundcolor_*", columns=['typetag', 'r', 'g', 'b', 'a']) } + return { + "BackgroundColor": _reader.Csv( + f"{pattern}_backgroundcolor_*", columns=["typetag", "r", "g", "b", "a"] + ) + } @staticmethod def change_subject_state(pattern): - return { "ChangeSubjectState": _reader.Csv(f"{pattern}_changesubjectstate_*", columns=['typetag', 'id', 'weight', 'event']) } + return { + "ChangeSubjectState": _reader.Csv( + f"{pattern}_changesubjectstate_*", columns=["typetag", "id", "weight", "event"] + 
) + } @staticmethod def end_trial(pattern): - return { "EndTrial": _reader.Csv(f"{pattern}_endtrial_*", columns=['typetag', 'value']) } + return {"EndTrial": _reader.Csv(f"{pattern}_endtrial_*", columns=["typetag", "value"])} @staticmethod def slice(pattern): - return { "Slice": _reader.Csv(f"{pattern}_octagonslice_*", columns=[ - 'typetag', - 'wall_id', - 'r', 'g', 'b', 'a', - 'delay']) } + return { + "Slice": _reader.Csv( + f"{pattern}_octagonslice_*", columns=["typetag", "wall_id", "r", "g", "b", "a", "delay"] + ) + } @staticmethod def gratings_slice(pattern): - return { "GratingsSlice": _reader.Csv(f"{pattern}_octagongratingsslice_*", columns=[ - 'typetag', - 'wall_id', - 'contrast', - 'opacity', - 'spatial_frequency', - 'temporal_frequency', - 'angle', - 'delay']) } + return { + "GratingsSlice": _reader.Csv( + f"{pattern}_octagongratingsslice_*", + columns=[ + "typetag", + "wall_id", + "contrast", + "opacity", + "spatial_frequency", + "temporal_frequency", + "angle", + "delay", + ], + ) + } @staticmethod def poke(pattern): - return { "Poke": _reader.Csv(f"{pattern}_poke_*", columns=[ - 'typetag', - 'wall_id', - 'poke_id', - 'reward', - 'reward_interval', - 'delay', - 'led_delay']) } + return { + "Poke": _reader.Csv( + f"{pattern}_poke_*", + columns=[ + "typetag", + "wall_id", + "poke_id", + "reward", + "reward_interval", + "delay", + "led_delay", + ], + ) + } @staticmethod def response(pattern): - return { "Response": _reader.Csv(f"{pattern}_response_*", columns=[ - 'typetag', - 'wall_id', - 'poke_id', - 'response_time' ]) } + return { + "Response": _reader.Csv( + f"{pattern}_response_*", columns=["typetag", "wall_id", "poke_id", "response_time"] + ) + } @staticmethod def run_pre_trial_no_poke(pattern): - return { "RunPreTrialNoPoke": _reader.Csv(f"{pattern}_run_pre_no_poke_*", columns=[ - 'typetag', - 'wait_for_poke', - 'reward_iti', - 'timeout_iti', - 'pre_trial_duration', - 'activity_reset_flag' ]) } + return { + "RunPreTrialNoPoke": _reader.Csv( + f"{pattern}_run_pre_no_poke_*", + columns=[ + "typetag", + "wait_for_poke", + "reward_iti", + "timeout_iti", + "pre_trial_duration", + "activity_reset_flag", + ], + ) + } @staticmethod def start_new_session(pattern): - return { "StartNewSession": _reader.Csv(f"{pattern}_startnewsession_*", columns=['typetag', 'path' ]) } + return {"StartNewSession": _reader.Csv(f"{pattern}_startnewsession_*", columns=["typetag", "path"])} + class TaskLogic: @staticmethod def trial_initiation(pattern): - return { "TrialInitiation": _reader.Harp(f"{pattern}_1_*", columns=['trial_type']) } + return {"TrialInitiation": _reader.Harp(f"{pattern}_1_*", columns=["trial_type"])} @staticmethod def response(pattern): - return { "Response": _reader.Harp(f"{pattern}_2_*", columns=['wall_id', 'poke_id']) } + return {"Response": _reader.Harp(f"{pattern}_2_*", columns=["wall_id", "poke_id"])} @staticmethod def pre_trial(pattern): - return { "PreTrialState": _reader.Harp(f"{pattern}_3_*", columns=['state']) } + return {"PreTrialState": _reader.Harp(f"{pattern}_3_*", columns=["state"])} @staticmethod def inter_trial_interval(pattern): - return { "InterTrialInterval": _reader.Harp(f"{pattern}_4_*", columns=['state']) } + return {"InterTrialInterval": _reader.Harp(f"{pattern}_4_*", columns=["state"])} @staticmethod def slice_onset(pattern): - return { "SliceOnset": _reader.Harp(f"{pattern}_10_*", columns=['wall_id']) } + return {"SliceOnset": _reader.Harp(f"{pattern}_10_*", columns=["wall_id"])} @staticmethod def draw_background(pattern): - return { "DrawBackground": 
_reader.Harp(f"{pattern}_11_*", columns=['state']) } + return {"DrawBackground": _reader.Harp(f"{pattern}_11_*", columns=["state"])} @staticmethod def gratings_slice_onset(pattern): - return { "GratingsSliceOnset": _reader.Harp(f"{pattern}_12_*", columns=['wall_id']) } + return {"GratingsSliceOnset": _reader.Harp(f"{pattern}_12_*", columns=["wall_id"])} + class Wall: @staticmethod def beam_break0(pattern): - return { "BeamBreak0": _reader.DigitalBitmask(f"{pattern}_32_*", 0x1, columns=['state']) } + return {"BeamBreak0": _reader.DigitalBitmask(f"{pattern}_32_*", 0x1, columns=["state"])} @staticmethod def beam_break1(pattern): - return { "BeamBreak1": _reader.DigitalBitmask(f"{pattern}_32_*", 0x2, columns=['state']) } + return {"BeamBreak1": _reader.DigitalBitmask(f"{pattern}_32_*", 0x2, columns=["state"])} @staticmethod def beam_break2(pattern): - return { "BeamBreak2": _reader.DigitalBitmask(f"{pattern}_32_*", 0x4, columns=['state']) } + return {"BeamBreak2": _reader.DigitalBitmask(f"{pattern}_32_*", 0x4, columns=["state"])} @staticmethod def set_led0(pattern): - return { "SetLed0": _reader.BitmaskEvent(f"{pattern}_34_*", 0x1, 'Set') } + return {"SetLed0": _reader.BitmaskEvent(f"{pattern}_34_*", 0x1, "Set")} @staticmethod def set_led1(pattern): - return { "SetLed1": _reader.BitmaskEvent(f"{pattern}_34_*", 0x2, 'Set') } + return {"SetLed1": _reader.BitmaskEvent(f"{pattern}_34_*", 0x2, "Set")} @staticmethod def set_led2(pattern): - return { "SetLed2": _reader.BitmaskEvent(f"{pattern}_34_*", 0x4, 'Set') } + return {"SetLed2": _reader.BitmaskEvent(f"{pattern}_34_*", 0x4, "Set")} @staticmethod def set_valve0(pattern): - return { "SetValve0": _reader.BitmaskEvent(f"{pattern}_34_*", 0x8, 'Set') } + return {"SetValve0": _reader.BitmaskEvent(f"{pattern}_34_*", 0x8, "Set")} @staticmethod def set_valve1(pattern): - return { "SetValve1": _reader.BitmaskEvent(f"{pattern}_34_*", 0x10, 'Set') } + return {"SetValve1": _reader.BitmaskEvent(f"{pattern}_34_*", 0x10, "Set")} @staticmethod def set_valve2(pattern): - return { "SetValve2": _reader.BitmaskEvent(f"{pattern}_34_*", 0x20, 'Set') } + return {"SetValve2": _reader.BitmaskEvent(f"{pattern}_34_*", 0x20, "Set")} @staticmethod def clear_led0(pattern): - return { "ClearLed0": _reader.BitmaskEvent(f"{pattern}_35_*", 0x1, 'Clear') } + return {"ClearLed0": _reader.BitmaskEvent(f"{pattern}_35_*", 0x1, "Clear")} @staticmethod def clear_led1(pattern): - return { "ClearLed1": _reader.BitmaskEvent(f"{pattern}_35_*", 0x2, 'Clear') } + return {"ClearLed1": _reader.BitmaskEvent(f"{pattern}_35_*", 0x2, "Clear")} @staticmethod def clear_led2(pattern): - return { "ClearLed2": _reader.BitmaskEvent(f"{pattern}_35_*", 0x4, 'Clear') } + return {"ClearLed2": _reader.BitmaskEvent(f"{pattern}_35_*", 0x4, "Clear")} @staticmethod def clear_valve0(pattern): - return { "ClearValve0": _reader.BitmaskEvent(f"{pattern}_35_*", 0x8, 'Clear') } + return {"ClearValve0": _reader.BitmaskEvent(f"{pattern}_35_*", 0x8, "Clear")} @staticmethod def clear_valve1(pattern): - return { "ClearValve1": _reader.BitmaskEvent(f"{pattern}_35_*", 0x10, 'Clear') } + return {"ClearValve1": _reader.BitmaskEvent(f"{pattern}_35_*", 0x10, "Clear")} @staticmethod def clear_valve2(pattern): - return { "ClearValve2": _reader.BitmaskEvent(f"{pattern}_35_*", 0x20, 'Clear') } \ No newline at end of file + return {"ClearValve2": _reader.BitmaskEvent(f"{pattern}_35_*", 0x20, "Clear")} diff --git a/aeon/schema/social.py b/aeon/schema/social.py index e5d6efec..bdacfc8b 100644 --- a/aeon/schema/social.py +++ 
b/aeon/schema/social.py
@@ -21,7 +21,9 @@ class (int): Int ID of a subject in the environment.
         x (float): X-coordinate of the bodypart.
         y (float): Y-coordinate of the bodypart.
     """
-    def __init__(self, pattern: str, extension: str="bin"):
+
+    def __init__(self, pattern: str, extension: str = "bin"):
+        """Pose reader constructor."""
         # `pattern` for this reader should typically be '_*'
         super().__init__(pattern, columns=None, extension=extension)
 
@@ -31,9 +33,15 @@ def read(
         """Reads data from the Harp-binarized tracking file."""
         # Get config file from `file`, then bodyparts from config file.
         model_dir = Path(file.stem.replace("_", "/")).parent
+<<<<<<< HEAD
         config_file_dir = ceph_proc_dir / model_dir
         if not config_file_dir.exists():
             raise FileNotFoundError(f"Cannot find model dir {config_file_dir}")
+=======
+        # `ceph_proc_dir` typically
+        config_file_dir = Path(ceph_proc_dir) / model_dir
+        assert config_file_dir.exists(), f"Cannot find model dir {config_file_dir}"
+>>>>>>> b9a1e3f... Blackened and ruffed
         config_file = get_config_file(config_file_dir)
         parts = self.get_bodyparts(config_file)
 
@@ -44,6 +52,7 @@ def read(
         self.columns = columns
         data = super().read(file)
 
+<<<<<<< HEAD
         # Drop any repeat parts.
         unique_parts, unique_idxs = np.unique(parts, return_index=True)
         repeat_idxs = np.setdiff1d(np.arange(len(parts)), unique_idxs)
@@ -54,6 +63,8 @@ def read(
         data = data.iloc[:, keep_part_col_idxs]
         parts = unique_parts
 
+=======
+>>>>>>> b9a1e3f... Blackened and ruffed
         # Set new columns, and reformat `data`.
         n_parts = len(parts)
         part_data_list = [pd.DataFrame()] * n_parts
@@ -91,13 +102,21 @@ def get_config_file(
     """Returns the config file from a model's config directory."""
     if config_file_names is None:
         config_file_names = ["confmap_config.json"]  # SLEAP (add for other trackers to this list)
+<<<<<<< HEAD
     config_file = None
+=======
+    config_file = Path()
+>>>>>>> b9a1e3f... Blackened and ruffed
     for f in config_file_names:
         if (config_file_dir / f).exists():
             config_file = config_file_dir / f
             break
+<<<<<<< HEAD
     if config_file is None:
         raise FileNotFoundError(f"Cannot find config file in {config_file_dir}")
+=======
+    assert config_file.is_file(), f"Cannot find config file in {config_file_dir}"
+>>>>>>> b9a1e3f... Blackened and ruffed
     return config_file
 
 
diff --git a/aeon/util.py b/aeon/util.py
index e5b2baa0..aed34803 100644
--- a/aeon/util.py
+++ b/aeon/util.py
@@ -1,9 +1,9 @@
 """Utility functions."""
 
-from typing import Union, Any
+from typing import Any
 
 
-def find_nested_key(obj: Union[dict, list], key: str) -> Any:
+def find_nested_key(obj: dict | list, key: str) -> Any:
     """Returns the value of the first found nested key."""
     if isinstance(obj, dict):
         if v := obj.get(key):  # found it!
diff --git a/pyproject.toml b/pyproject.toml index 075e1ef4..90df02f3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -102,8 +102,8 @@ exclude = ''' select = ["E", "W", "F", "I", "D", "UP", "S", "B", "A", "C4", "ICN", "PIE", "PT", "SIM", "PL"] line-length = 108 ignore = [ - "E201", "E202", "E203", "E231", "E731", "E702", - "S101", + "E201", "E202", "E203", "E231", "E731", "E702", + "S101", "PT013", "PLR0912", "PLR0913", "PLR0915" ] @@ -113,6 +113,7 @@ extend-exclude = [".git", ".github", ".idea", ".vscode"] convention = "google" [tool.pyright] +reportMissingImports = "none" reportImportCycles = "error" reportUnusedImport = "error" reportUnusedClass = "error" diff --git a/tests/io/test_api.py b/tests/io/test_api.py index 48986830..c09ab322 100644 --- a/tests/io/test_api.py +++ b/tests/io/test_api.py @@ -1,9 +1,13 @@ +<<<<<<< HEAD from pathlib import Path +======= +>>>>>>> b9a1e3f... Blackened and ruffed import pandas as pd import pytest from pytest import mark +<<<<<<< HEAD import aeon from aeon.schema.dataset import exp02 @@ -14,34 +18,63 @@ @mark.api def test_load_start_only(): data = aeon.load(nonmonotonic_path, exp02.Patch2.Encoder, start=pd.Timestamp("2022-06-06T13:00:49")) +======= +import aeon.io.api as aeon +from aeon.schema.dataset import exp02 + + +@mark.api +def test_load_start_only(): + data = aeon.load( + "./tests/data/nonmonotonic", exp02.Patch2.Encoder, start=pd.Timestamp("2022-06-06T13:00:49") + ) +>>>>>>> b9a1e3f... Blackened and ruffed assert len(data) > 0 @mark.api def test_load_end_only(): data = aeon.load( +<<<<<<< HEAD nonmonotonic_path, exp02.Patch2.Encoder, end=pd.Timestamp("2022-06-06T13:00:49") +======= + "./tests/data/nonmonotonic", exp02.Patch2.Encoder, end=pd.Timestamp("2022-06-06T13:00:49") +>>>>>>> b9a1e3f... Blackened and ruffed ) assert len(data) > 0 @mark.api def test_load_filter_nonchunked(): +<<<<<<< HEAD data = aeon.load( nonmonotonic_path, exp02.Metadata, start=pd.Timestamp("2022-06-06T09:00:00") ) +======= + data = aeon.load("./tests/data/nonmonotonic", exp02.Metadata, start=pd.Timestamp("2022-06-06T09:00:00")) +>>>>>>> b9a1e3f... Blackened and ruffed assert len(data) > 0 @mark.api def test_load_monotonic(): +<<<<<<< HEAD data = aeon.load(monotonic_path, exp02.Patch2.Encoder) assert len(data) > 0 and data.index.is_monotonic_increasing +======= + data = aeon.load("./tests/data/monotonic", exp02.Patch2.Encoder) + assert data.index.is_monotonic_increasing +>>>>>>> b9a1e3f... Blackened and ruffed + @mark.api def test_load_nonmonotonic(): +<<<<<<< HEAD data = aeon.load(nonmonotonic_path, exp02.Patch2.Encoder) +======= + data = aeon.load("./tests/data/nonmonotonic", exp02.Patch2.Encoder) +>>>>>>> b9a1e3f... Blackened and ruffed assert not data.index.is_monotonic_increasing