Skip to content

Commit

Permalink
Blackened and ruffed
Browse files Browse the repository at this point in the history
  • Loading branch information
jkbhagatio committed Sep 2, 2023
1 parent f1ab685 commit 7445a0a
Show file tree
Hide file tree
Showing 35 changed files with 854 additions and 1,001 deletions.
55 changes: 29 additions & 26 deletions aeon/analysis/movies.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import cv2
import math

import cv2
import numpy as np
import pandas as pd
import aeon.io.video as video

from aeon.io import video


def gridframes(frames, width, height, shape=None):
'''
Arranges a set of frames into a grid layout with the specified
"""Arranges a set of frames into a grid layout with the specified
pixel dimensions and shape.
:param list frames: A list of frames to include in the grid layout.
Expand All @@ -16,15 +18,15 @@ def gridframes(frames, width, height, shape=None):
Either the number of frames to include, or the number of rows and columns
in the output grid image layout.
:return: A new image containing the arrangement of the frames in a grid.
'''
"""
if shape is None:
shape = len(frames)
if type(shape) not in [list,tuple]:
if type(shape) not in [list, tuple]:
shape = math.ceil(math.sqrt(shape))
shape = (shape, shape)

dsize = (height, width, 3)
cellsize = (height // shape[0], width // shape[1],3)
cellsize = (height // shape[0], width // shape[1], 3)
grid = np.zeros(dsize, dtype=np.uint8)
for i in range(shape[0]):
for j in range(shape[1]):
Expand All @@ -39,19 +41,20 @@ def gridframes(frames, width, height, shape=None):
grid[i0:i1, j0:j1] = cv2.resize(frame, (cellsize[1], cellsize[0]))
return grid


def averageframes(frames):
    """Returns the average of the specified collection of frames."""
    # Weight every frame equally, accumulate, then clip/convert back to uint8.
    weight = 1 / len(frames)
    mean_frame = sum(np.multiply(weight, frames))
    return cv2.convertScaleAbs(mean_frame)


def groupframes(frames, n, fun):
'''
Applies the specified function to each group of n-frames.
"""Applies the specified function to each group of n-frames.
:param iterable frames: A sequence of frames to process.
:param int n: The number of frames in each group.
:param callable fun: The function used to process each group of frames.
:return: An iterable returning the results of applying the function to each group.
'''
"""
i = 0
group = []
for frame in frames:
Expand All @@ -61,9 +64,9 @@ def groupframes(frames, n, fun):
group.clear()
i = i + 1


def triggerclip(data, events, before=pd.Timedelta(0), after=pd.Timedelta(0)):
'''
Split video data around the specified sequence of event timestamps.
"""Split video data around the specified sequence of event timestamps.
:param DataFrame data:
A pandas DataFrame where each row specifies video acquisition path and frame number.
Expand All @@ -72,7 +75,7 @@ def triggerclip(data, events, before=pd.Timedelta(0), after=pd.Timedelta(0)):
:param Timedelta after: The right offset from each timestamp used to clip the data.
:return:
A pandas DataFrame containing the frames, clip and sequence numbers for each event timestamp.
'''
"""
if before is not pd.Timedelta:
before = pd.Timedelta(before)
if after is not pd.Timedelta:
Expand All @@ -81,30 +84,30 @@ def triggerclip(data, events, before=pd.Timedelta(0), after=pd.Timedelta(0)):
events = events.index

clips = []
for i,index in enumerate(events):
clip = data.loc[(index-before):(index+after)].copy()
clip['frame_sequence'] = list(range(len(clip)))
clip['clip_sequence'] = i
for i, index in enumerate(events):
clip = data.loc[(index - before) : (index + after)].copy()
clip["frame_sequence"] = list(range(len(clip)))
clip["clip_sequence"] = i
clips.append(clip)
return pd.concat(clips)


def collatemovie(clipdata, fun):
    """Collates a set of video clips into a single movie using the specified aggregation function.

    :param DataFrame clipdata:
        A pandas DataFrame where each row specifies video path, frame number, clip and sequence number.
        This DataFrame can be obtained from the output of the triggerclip function.
    :param callable fun: The aggregation function used to process the frames in each clip.
    :return: The sequence of processed frames representing the collated movie.
    """
    # One group per frame_sequence position; the group size is the number of clips.
    per_clip_counts = clipdata.groupby("clip_sequence").frame_sequence.count()
    ordered = clipdata.sort_values(by=["frame_sequence", "clip_sequence"])
    return groupframes(video.frames(ordered), len(per_clip_counts), fun)


def gridmovie(clipdata, width, height, shape=None):
'''
Collates a set of video clips into a grid movie with the specified pixel dimensions
"""Collates a set of video clips into a grid movie with the specified pixel dimensions
and grid layout.
:param DataFrame clipdata:
Expand All @@ -116,5 +119,5 @@ def gridmovie(clipdata, width, height, shape=None):
Either the number of frames to include, or the number of rows and columns
in the output grid movie layout.
:return: The sequence of processed frames representing the collated grid movie.
'''
return collatemovie(clipdata, lambda g:gridframes(g, width, height, shape))
"""
return collatemovie(clipdata, lambda g: gridframes(g, width, height, shape))
25 changes: 8 additions & 17 deletions aeon/analysis/plotting.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@
import math

import matplotlib.colors as colors
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib import colors
from matplotlib.collections import LineCollection

from aeon.analysis.utils import *


def heatmap(position, frequency, ax=None, **kwargs):
"""
Draw a heatmap of time spent in each location from specified position data and sampling frequency.
"""Draw a heatmap of time spent in each location from specified position data and sampling frequency.
:param Series position: A series of position data containing x and y coordinates.
:param number frequency: The sampling frequency for the position data.
Expand All @@ -33,8 +32,7 @@ def heatmap(position, frequency, ax=None, **kwargs):


def circle(x, y, radius, fmt=None, ax=None, **kwargs):
"""
Plot a circle centered at the given x, y position with the specified radius.
"""Plot a circle centered at the given x, y position with the specified radius.
:param number x: The x-component of the circle center.
:param number y: The y-component of the circle center.
Expand Down Expand Up @@ -62,8 +60,7 @@ def rateplot(
ax=None,
**kwargs,
):
"""
Plot the continuous event rate and raster of a discrete event sequence, given the specified
"""Plot the continuous event rate and raster of a discrete event sequence, given the specified
window size and sampling frequency.
:param Series events: The discrete sequence of events.
Expand All @@ -79,9 +76,7 @@ def rateplot(
:param Axes, optional ax: The Axes on which to draw the rate plot and raster.
"""
label = kwargs.pop("label", None)
eventrate = rate(
events, window, frequency, weight, start, end, smooth=smooth, center=center
)
eventrate = rate(events, window, frequency, weight, start, end, smooth=smooth, center=center)
if ax is None:
ax = plt.gca()
ax.plot(
Expand All @@ -90,14 +85,11 @@ def rateplot(
label=label,
**kwargs,
)
ax.vlines(
sessiontime(events.index, eventrate.index[0]), -0.2, -0.1, linewidth=1, **kwargs
)
ax.vlines(sessiontime(events.index, eventrate.index[0]), -0.2, -0.1, linewidth=1, **kwargs)


def set_ymargin(ax, bottom, top):
"""
Set the vertical margins of the specified Axes.
"""Set the vertical margins of the specified Axes.
:param Axes ax: The Axes for which to specify the vertical margin.
:param number bottom: The size of the bottom margin.
Expand All @@ -121,8 +113,7 @@ def colorline(
ax=None,
**kwargs,
):
"""
Plot a dynamically colored line on the specified Axes.
"""Plot a dynamically colored line on the specified Axes.
:param array-like x, y: The horizontal / vertical coordinates of the data points.
:param array-like, optional z:
Expand Down
89 changes: 49 additions & 40 deletions aeon/analysis/utils.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
import numpy as np
import pandas as pd


def distancetravelled(angle, radius=4.0):
'''
Calculates the total distance travelled on the wheel, by taking into account
"""Calculates the total distance travelled on the wheel, by taking into account
its radius and the total number of turns in both directions across time.
:param Series angle: A series of magnetic encoder measurements.
:param float radius: The radius of the wheel, in metric units.
:return: The total distance travelled on the wheel, in metric units.
'''
"""
maxvalue = int(np.iinfo(np.uint16).max >> 2)
jumpthreshold = maxvalue // 2
turns = angle.astype(int).diff()
Expand All @@ -20,9 +20,9 @@ def distancetravelled(angle, radius=4.0):
distance = distance - distance[0]
return distance

def visits(data, onset='Enter', offset='Exit'):
'''
Computes duration, onset and offset times from paired events. Allows for missing data

def visits(data, onset="Enter", offset="Exit"):
"""Computes duration, onset and offset times from paired events. Allows for missing data
by trying to match event onset times with subsequent offset times. If the match fails,
event offset metadata is filled with NaN. Any additional metadata columns in the data
frame will be paired and included in the output.
Expand All @@ -31,45 +31,45 @@ def visits(data, onset='Enter', offset='Exit'):
:param str, optional onset: The label used to identify event onsets.
:param str, optional offset: The label used to identify event offsets.
:return: A pandas data frame containing duration and metadata for each visit.
'''
"""
lonset = onset.lower()
loffset = offset.lower()
lsuffix = '_{0}'.format(lonset)
rsuffix = '_{0}'.format(loffset)
id_onset = 'id' + lsuffix
event_onset = 'event' + lsuffix
event_offset = 'event' + rsuffix
time_onset = 'time' + lsuffix
time_offset = 'time' + rsuffix
lsuffix = f"_{lonset}"
rsuffix = f"_{loffset}"
id_onset = "id" + lsuffix
event_onset = "event" + lsuffix
event_offset = "event" + rsuffix
time_onset = "time" + lsuffix
time_offset = "time" + rsuffix

# find all possible onset / offset pairs
data = data.reset_index()
data_onset = data[data.event == onset]
data_offset = data[data.event == offset]
data = pd.merge(data_onset, data_offset, on='id', how='left', suffixes=[lsuffix, rsuffix])
data = pd.merge(data_onset, data_offset, on="id", how="left", suffixes=[lsuffix, rsuffix])

# valid pairings have the smallest positive duration
data['duration'] = data[time_offset] - data[time_onset]
data["duration"] = data[time_offset] - data[time_onset]
valid_visits = data[data.duration >= pd.Timedelta(0)]
data = data.iloc[valid_visits.groupby([time_onset, 'id']).duration.idxmin()]
data = data.iloc[valid_visits.groupby([time_onset, "id"]).duration.idxmin()]
data = data[data.duration > pd.Timedelta(0)]

# duplicate offsets indicate missing data from previous pairing
missing_data = data.duplicated(subset=time_offset, keep='last')
missing_data = data.duplicated(subset=time_offset, keep="last")
if missing_data.any():
data.loc[missing_data, ['duration'] + [name for name in data.columns if rsuffix in name]] = pd.NA
data.loc[missing_data, ["duration"] + [name for name in data.columns if rsuffix in name]] = pd.NA

# rename columns and sort data
data.rename({ time_onset:lonset, id_onset:'id', time_offset:loffset}, axis=1, inplace=True)
data = data[['id'] + [name for name in data.columns if '_' in name] + [lonset, loffset, 'duration']]
data.rename({time_onset: lonset, id_onset: "id", time_offset: loffset}, axis=1, inplace=True)
data = data[["id"] + [name for name in data.columns if "_" in name] + [lonset, loffset, "duration"]]
data.drop([event_onset, event_offset], axis=1, inplace=True)
data.sort_index(inplace=True)
data.reset_index(drop=True, inplace=True)
return data


def rate(events, window, frequency, weight=1, start=None, end=None, smooth=None, center=False):
'''
Computes the continuous event rate from a discrete event sequence, given the specified
"""Computes the continuous event rate from a discrete event sequence, given the specified
window size and sampling frequency.
:param Series events: The discrete sequence of events.
Expand All @@ -83,52 +83,61 @@ def rate(events, window, frequency, weight=1, start=None, end=None, smooth=None,
The size of the smoothing kernel applied to the continuous rate output.
:param bool, optional center: Specifies whether to center the convolution kernels.
:return: A Series containing the continuous event rate over time.
'''
"""
counts = pd.Series(weight, events.index)
if start is not None and start < events.index[0]:
counts.loc[start] = 0
if end is not None and end > events.index[-1]:
counts.loc[end] = 0
counts.sort_index(inplace=True)
counts = counts.resample(pd.Timedelta(1 / frequency, 's')).sum()
rate = counts.rolling(window,center=center).sum()
return rate.rolling(window if smooth is None else smooth,center=center).mean()
counts = counts.resample(pd.Timedelta(1 / frequency, "s")).sum()
rate = counts.rolling(window, center=center).sum()
return rate.rolling(window if smooth is None else smooth, center=center).mean()


def get_events_rates(
    events, window_len_sec, frequency, unit_len_sec=60, start=None, end=None, smooth=None, center=False
):
    """Estimates a continuous event rate from a discrete sequence of event times.

    :param Series events: A series indexed by event time (one entry per event occurrence).
    :param int window_len_sec: The size, in seconds, of the window over which the rate is estimated.
    :param frequency: The resampling frequency used to bin the event counts.
    :param number, optional unit_len_sec: The length of one sample point, used to scale counts to a rate.
    :param datetime, optional start: An optional earlier timestamp used to extend the rate estimate.
    :param datetime, optional end: An optional later timestamp used to extend the rate estimate.
    :param optional smooth: The size of the smoothing window applied to the rate output.
    :param bool, optional center: Specifies whether to center the rolling windows.
    :return: A Series containing the continuous event rate over time.
    """
    window_str = f"{window_len_sec:d}S"
    counts = pd.Series(1.0, events.index)
    # Pad the count series so the rate estimate spans the requested interval.
    if start is not None and start < events.index[0]:
        counts.loc[start] = 0
    if end is not None and end > events.index[-1]:
        counts.loc[end] = 0
    counts.sort_index(inplace=True)
    resampled = counts.resample(frequency).sum()
    rolled = resampled.rolling(window_str, center=center).sum() * unit_len_sec / window_len_sec
    smooth_window = window_str if smooth is None else smooth
    return rolled.rolling(smooth_window, center=center).mean()


def sessiontime(index, start=None):
    """Converts absolute to relative time, in minutes, with optional reference starting time."""
    reference = index[0] if start is None else start
    return (index - reference).total_seconds() / 60


def distance(position, target):
    """Computes the euclidean distance from each x/y position to a specified target."""
    offsets = position[["x", "y"]] - target
    return np.sqrt(np.square(offsets).sum(axis=1))


def activepatch(wheel, in_patch):
    """Computes a decision boundary for when a patch is active based on wheel movement.

    :param Series wheel: A pandas Series containing the cumulative distance travelled on the wheel.
    :param Series in_patch: A Series of type bool containing whether the specified patch may be active.
    :return: A pandas Series specifying for each timepoint whether the patch is active.
    """
    # A negative diff on the bool series marks the sample where the patch was exited.
    exit_patch = in_patch.astype(np.int8).diff() < 0
    # Wheel counts as moving when more than 1 unit was travelled in the last second.
    wheel_moving = wheel.diff().rolling("1s").sum() > 1
    wheel_moving = wheel_moving.reindex(in_patch.index, method="pad")
    # Each patch exit starts a new epoch; within an epoch the patch is active
    # from the first detected wheel movement onwards.
    epochs = exit_patch.cumsum()
    return wheel_moving.groupby(epochs).apply(lambda chunk: chunk.cumsum()) > 0
Loading

0 comments on commit 7445a0a

Please sign in to comment.