New ingestion_schemas + MANY minor fixes and improvements #438

Merged
Changes from all commits

Commits (49):
833f9e9  feat: add one-off logic to ingest fullpose data for social02 (ttngu207, Aug 8, 2024)
0558cbb  Merge branch 'datajoint_pipeline' into dev_fullpose_for_social02 (ttngu207, Aug 20, 2024)
e9c7fa2  Create reingest_fullpose_sleap_data.py (ttngu207, Aug 21, 2024)
38b44e0  Allow reading model metadata from local folder (glopesdev, Sep 26, 2024)
028ffc5  Avoid iterating over None (glopesdev, Sep 26, 2024)
25b7195  Avoid iterating over the config file twice (glopesdev, Sep 26, 2024)
f77ac1d  Avoid mixing dtypes with conditional assignment (glopesdev, Sep 26, 2024)
ac2aa13  Remove whitespace on blank line (glopesdev, Sep 26, 2024)
caf3ce1  Use replace function instead of explicit loop (glopesdev, Sep 27, 2024)
93428c8  Improve error logic when model metadata not found (glopesdev, Sep 27, 2024)
00c1cca  Test loading poses with local model metadata (glopesdev, Sep 27, 2024)
6b32583  Use all components other than time and device name (glopesdev, Sep 27, 2024)
010fdb9  Add regression test for poses with register prefix (glopesdev, Oct 2, 2024)
0a88b79  Infer base prefix from stream search pattern (glopesdev, Oct 2, 2024)
f925d75  Use full identity likelihood vectors in test data (glopesdev, Oct 2, 2024)
83cd905  Merge pull request #421 from SainsburyWellcomeCentre/gl-issue-418 (glopesdev, Oct 3, 2024)
36ee97a  Update worker.py (ttngu207, Oct 10, 2024)
b54e1c3  new readers and schemas for reduced data storage in db (jkbhagatio, Oct 10, 2024)
d6cf52f  updated tests (jkbhagatio, Oct 10, 2024)
f12e359  cleaned up linting for ruff (jkbhagatio, Oct 10, 2024)
daf6224  updated pandas and changed S to s lmao (jkbhagatio, Oct 10, 2024)
6d798b8  chore: code cleanup (ttngu207, Oct 15, 2024)
ea3c2ef  Merge remote-tracking branch 'upstream/ingestion_readers_schemas' int… (ttngu207, Oct 15, 2024)
697c0a8  chore: delete the obsolete `dataset` (replaced by `schemas`) (ttngu207, Oct 15, 2024)
2ef32c3  chore: clean up `load_metadata` (ttngu207, Oct 16, 2024)
d5bd0fe  feat(ingestion): use new `ingestion_schemas` (ttngu207, Oct 16, 2024)
8725e8f  feat(streams): update streams with new ingestion_schemas (ttngu207, Oct 16, 2024)
0f210e1  fix(ingestion_schemas): downsampling Encoder (ttngu207, Oct 16, 2024)
d365bcd  fix(ingestion_schemas): minor fix in `_Encoder`, calling `super()` init (ttngu207, Oct 18, 2024)
cb90843  fix(harp reader): remove rows where the index is zero (ttngu207, Oct 18, 2024)
9c7e9d9  fix(BlockForaging): bugfix in col rename (ttngu207, Oct 18, 2024)
0a9c1e1  fix(block_analysis): bugfix in extracting `subject_in_patch` time (ttngu207, Oct 21, 2024)
4020900  feat(fetch_stream): flag to round to microseconds (ttngu207, Oct 21, 2024)
566c3ed  fix(block_analysis): bugfix `in_patch_timestamps` (ttngu207, Oct 21, 2024)
28e39c1  Update reingest_fullpose_sleap_data.py (ttngu207, Oct 21, 2024)
41a248d  Create reingest_fullpose_sleap_data.py (ttngu207, Aug 21, 2024)
cda41fb  Update reingest_fullpose_sleap_data.py (ttngu207, Oct 21, 2024)
6f7f541  Merge branch 'dev_fullpose_for_social02' into datajoint_pipeline (ttngu207, Oct 21, 2024)
f783067  fix: `social_02.Pose03` in `ingestion_schemas` only (ttngu207, Oct 21, 2024)
3e59db8  Update reingest_fullpose_sleap_data.py (ttngu207, Oct 21, 2024)
64900ad  Update reingest_fullpose_sleap_data.py (ttngu207, Oct 21, 2024)
538e4e5  feat(tracking): add `BlobPositionTracking` (ttngu207, Oct 22, 2024)
290fe4e  fix(block_analysis): various fixes and code improvements (ttngu207, Oct 22, 2024)
fb18016  fix: improve logic to search for chunks in a given block (ttngu207, Oct 22, 2024)
8f2fffc  feat(script): add script `sync_ingested_and_raw_epochs` (ttngu207, Oct 23, 2024)
8762fcf  fix(sync_ingested_and_raw_epochs): minor code cleanup (ttngu207, Oct 24, 2024)
9078085  fix(BlockSubjectAnalysis): handle edge case where the encoder data ar… (ttngu207, Oct 24, 2024)
ebecb00  chore: minor fixes to address PR review comments (ttngu207, Oct 30, 2024)
b0952eb  fix: address PR comments (ttngu207, Nov 5, 2024)
13 changes: 12 additions & 1 deletion aeon/dj_pipeline/__init__.py
@@ -1,4 +1,5 @@
import hashlib
import logging
import os
import uuid

@@ -30,11 +31,17 @@ def dict_to_uuid(key) -> uuid.UUID:
return uuid.UUID(hex=hashed.hexdigest())


def fetch_stream(query, drop_pk=True):
def fetch_stream(query, drop_pk=True, round_microseconds=True):
"""Fetches data from a Stream table based on a query and returns it as a DataFrame.
Provided a query containing data from a Stream table,
fetch and aggregate the data into one DataFrame indexed by "time"
Args:
query (datajoint.Query): A query object containing data from a Stream table
drop_pk (bool, optional): Drop primary key columns. Defaults to True.
round_microseconds (bool, optional): Round timestamps to microseconds. Defaults to True.
(important because timestamps in MySQL are only accurate to microseconds)
"""
df = (query & "sample_count > 0").fetch(format="frame").reset_index()
cols2explode = [
@@ -47,6 +54,10 @@ def fetch_stream(query, drop_pk=True):
df.set_index("time", inplace=True)
df.sort_index(inplace=True)
df = df.convert_dtypes(convert_string=False, convert_integer=False, convert_boolean=False, convert_floating=False)
if not df.empty and round_microseconds:
logging.warning("Rounding timestamps to microseconds is now enabled by default."
" To disable, set round_microseconds=False.")
df.index = df.index.round("us")
return df


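For illustration, a minimal sketch (not part of the diff) of why the new `round_microseconds` default matters: pandas keeps nanosecond precision, while MySQL `datetime(6)` columns only store microseconds, so fetched stream timestamps are rounded to microseconds before comparison or re-insertion.

```python
import pandas as pd

# nanosecond-precision timestamp, as pandas stores it (value invented for the example)
idx = pd.to_datetime(["2024-10-21 00:00:00.123456789"])

# round to microseconds, matching MySQL datetime(6) precision
rounded = idx.round("us")
print(rounded[0])  # 2024-10-21 00:00:00.123457
```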
10 changes: 7 additions & 3 deletions aeon/dj_pipeline/acquisition.py
@@ -12,7 +12,7 @@
from aeon.dj_pipeline.utils import paths
from aeon.io import api as io_api
from aeon.io import reader as io_reader
from aeon.schema import schemas as aeon_schemas
from aeon.schema import ingestion_schemas as aeon_schemas

logger = dj.logger
schema = dj.schema(get_schema_name("acquisition"))
@@ -646,10 +646,14 @@ def _match_experiment_directory(experiment_name, path, directories):

def create_chunk_restriction(experiment_name, start_time, end_time):
"""Create a time restriction string for the chunks between the specified "start" and "end" times."""
exp_key = {"experiment_name": experiment_name}
start_restriction = f'"{start_time}" BETWEEN chunk_start AND chunk_end'
end_restriction = f'"{end_time}" BETWEEN chunk_start AND chunk_end'
start_query = Chunk & {"experiment_name": experiment_name} & start_restriction
end_query = Chunk & {"experiment_name": experiment_name} & end_restriction
start_query = Chunk & exp_key & start_restriction
end_query = Chunk & exp_key & end_restriction
if not end_query:
# No chunk contains the end time, so we need to find the last chunk that starts before the end time
end_query = Chunk & exp_key & f'chunk_end BETWEEN "{start_time}" AND "{end_time}"'
if not (start_query and end_query):
raise ValueError(f"No Chunk found between {start_time} and {end_time}")
time_restriction = (
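A hedged sketch of the new fallback in `create_chunk_restriction`, using plain datetimes instead of DataJoint queries (the chunk intervals and times are invented): when no chunk spans the block's end time, the code falls back to chunks whose end falls inside the [start, end] window.

```python
from datetime import datetime

# hypothetical one-hour chunks (chunk_start, chunk_end); the requested end falls after the last chunk
chunks = [
    (datetime(2024, 2, 9, 10), datetime(2024, 2, 9, 11)),
    (datetime(2024, 2, 9, 11), datetime(2024, 2, 9, 12)),
]
start_time, end_time = datetime(2024, 2, 9, 10, 30), datetime(2024, 2, 9, 13, 0)

start_hits = [c for c in chunks if c[0] <= start_time <= c[1]]
end_hits = [c for c in chunks if c[0] <= end_time <= c[1]]
if not end_hits:
    # no chunk contains end_time, so use chunks whose end lies inside the window instead
    end_hits = [c for c in chunks if start_time <= c[1] <= end_time]
if not (start_hits and end_hits):
    raise ValueError(f"No Chunk found between {start_time} and {end_time}")
print(end_hits[-1])  # the last chunk ending before the block end
```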
158 changes: 102 additions & 56 deletions aeon/dj_pipeline/analysis/block_analysis.py
@@ -53,6 +53,8 @@ class BlockDetection(dj.Computed):
-> acquisition.Environment
"""

key_source = acquisition.Environment - {"experiment_name": "social0.1-aeon3"}

def make(self, key):
"""On a per-chunk basis, check for the presence of new block, insert into Block table.
@@ -88,8 +90,7 @@ def make(self, key):
blocks_df = block_state_df[block_state_df.pellet_ct == 0]
# account for the double 0s - find any 0s that are within 1 second of each other, remove the 2nd one
double_0s = blocks_df.index.to_series().diff().dt.total_seconds() < 1
# find the indices of the 2nd 0s and remove
double_0s = double_0s.shift(-1).fillna(False)
# keep the first 0s
blocks_df = blocks_df[~double_0s]

block_entries = []
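A hedged sketch of the duplicate-zero filtering above, on synthetic timestamps (values invented): when two pellet-count resets arrive within one second of each other, only the first one is kept as a block boundary.

```python
import pandas as pd

idx = pd.to_datetime([
    "2024-10-21 10:00:00.000",  # block boundary
    "2024-10-21 10:00:00.400",  # duplicate 0 within 1 s - should be dropped
    "2024-10-21 12:00:00.000",  # next block boundary
])
blocks_df = pd.DataFrame({"pellet_ct": [0, 0, 0]}, index=idx)

# True wherever a 0 follows another 0 by less than one second
double_0s = blocks_df.index.to_series().diff().dt.total_seconds() < 1
blocks_df = blocks_df[~double_0s]
print(blocks_df.index.tolist())  # keeps 10:00:00.000 and 12:00:00.000
```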
@@ -144,8 +145,8 @@ class Patch(dj.Part):
wheel_timestamps: longblob
patch_threshold: longblob
patch_threshold_timestamps: longblob
patch_rate: float
patch_offset: float
patch_rate=null: float
patch_offset=null: float
"""

class Subject(dj.Part):
@@ -181,17 +182,27 @@ def make(self, key):
streams.UndergroundFeederDepletionState,
streams.UndergroundFeederDeliverPellet,
streams.UndergroundFeederEncoder,
tracking.SLEAPTracking,
)
for streams_table in streams_tables:
if len(streams_table & chunk_keys) < len(streams_table.key_source & chunk_keys):
raise ValueError(
f"BlockAnalysis Not Ready - {streams_table.__name__} not yet fully ingested for block: {key}. Skipping (to retry later)..."
)

# Check if SLEAPTracking is ready, if not, see if BlobPosition can be used instead
use_blob_position = False
if len(tracking.SLEAPTracking & chunk_keys) < len(tracking.SLEAPTracking.key_source & chunk_keys):
if len(tracking.BlobPosition & chunk_keys) < len(tracking.BlobPosition.key_source & chunk_keys):
raise ValueError(
f"BlockAnalysis Not Ready - SLEAPTracking (and BlobPosition) not yet fully ingested for block: {key}. Skipping (to retry later)..."
)
else:
use_blob_position = True

# Patch data - TriggerPellet, DepletionState, Encoder (distancetravelled)
# For wheel data, downsample to 10Hz
final_encoder_fs = 10
# For wheel data, downsample to 50Hz
final_encoder_hz = 50
freq = 1 / final_encoder_hz * 1e3 # in ms

maintenance_period = get_maintenance_periods(key["experiment_name"], block_start, block_end)

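A minimal sketch of the 50 Hz wheel downsampling described above, on synthetic encoder samples (the 2 ms raw sample spacing and the `angle` column are assumptions for illustration; in the pipeline `encoder_df` comes from the Encoder stream).

```python
import numpy as np
import pandas as pd

final_encoder_hz = 50
freq = 1 / final_encoder_hz * 1e3  # bin width in ms (20 ms)

t = pd.date_range("2024-10-21", periods=5000, freq="2ms")  # synthetic raw encoder at ~500 Hz
encoder_df = pd.DataFrame({"angle": np.cumsum(np.random.randn(len(t)))}, index=t)

# keep the first sample in each 20 ms bin, as in BlockAnalysis.make
downsampled = encoder_df.resample(f"{freq}ms").first()
print(len(encoder_df), "->", len(downsampled))  # 5000 -> 500
```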
@@ -233,51 +244,52 @@ def make(self, key):
encoder_df, maintenance_period, block_end, dropna=True
)

if depletion_state_df.empty:
raise ValueError(f"No depletion state data found for block {key} - patch: {patch_name}")

encoder_df["distance_travelled"] = -1 * analysis_utils.distancetravelled(encoder_df.angle)

if len(depletion_state_df.rate.unique()) > 1:
# multiple patch rates per block is unexpected, log a note and pick the first rate to move forward
AnalysisNote.insert1(
{
"note_timestamp": datetime.utcnow(),
"note_type": "Multiple patch rates",
"note": f"Found multiple patch rates for block {key} - patch: {patch_name} - rates: {depletion_state_df.rate.unique()}",
}
)
# if all dataframes are empty, skip
if pellet_ts_threshold_df.empty and depletion_state_df.empty and encoder_df.empty:
continue

patch_rate = depletion_state_df.rate.iloc[0]
patch_offset = depletion_state_df.offset.iloc[0]
# handles patch rate value being INF
patch_rate = 999999999 if np.isinf(patch_rate) else patch_rate
if encoder_df.empty:
encoder_df["distance_travelled"] = 0
else:
# -1 is for placement of magnetic encoder, where wheel movement actually decreases encoder
encoder_df["distance_travelled"] = -1 * analysis_utils.distancetravelled(encoder_df.angle)
[Review comment - Member] maybe add a comment saying something like "-1 is for placement of magnetic encoder, where wheel movement actually decreases encoder value"?
[Reply - Contributor Author] done
encoder_df = encoder_df.resample(f"{freq}ms").first()

if not depletion_state_df.empty:
if len(depletion_state_df.rate.unique()) > 1:
# multiple patch rates per block is unexpected, log a note and pick the first rate to move forward
AnalysisNote.insert1(
{
"note_timestamp": datetime.utcnow(),
"note_type": "Multiple patch rates",
"note": f"Found multiple patch rates for block {key} - patch: {patch_name} - rates: {depletion_state_df.rate.unique()}",
}
)

encoder_fs = (
1 / encoder_df.index.to_series().diff().dt.total_seconds().median()
) # mean or median?
wheel_downsampling_factor = int(encoder_fs / final_encoder_fs)
patch_rate = depletion_state_df.rate.iloc[0]
patch_offset = depletion_state_df.offset.iloc[0]
# handles patch rate value being INF
patch_rate = 999999999 if np.isinf(patch_rate) else patch_rate
[Review comment - Member] is it actually an issue if patch rate is inf? Does it cause some downstream issue? We do this as the default when no env is loaded.
[Reply - Contributor Author] Yes, it's because the MySQL float type doesn't handle INF well - we could convert it to NaN, but that would lose the intended meaning of INF here.
else:
logger.warning(f"No depletion state data found for block {key} - patch: {patch_name}")
patch_rate = None
patch_offset = None

block_patch_entries.append(
{
**key,
"patch_name": patch_name,
"pellet_count": len(pellet_ts_threshold_df),
"pellet_timestamps": pellet_ts_threshold_df.pellet_timestamp.values,
"wheel_cumsum_distance_travelled": encoder_df.distance_travelled.values[
::wheel_downsampling_factor
],
"wheel_timestamps": encoder_df.index.values[::wheel_downsampling_factor],
"wheel_cumsum_distance_travelled": encoder_df.distance_travelled.values,
"wheel_timestamps": encoder_df.index.values,
"patch_threshold": pellet_ts_threshold_df.threshold.values,
"patch_threshold_timestamps": pellet_ts_threshold_df.index.values,
"patch_rate": patch_rate,
"patch_offset": patch_offset,
}
)

# update block_end if last timestamp of encoder_df is before the current block_end
block_end = min(encoder_df.index[-1], block_end)

# Subject data
# Get all unique subjects that visited the environment over the entire exp;
# For each subject, see 'type' of visit most recent to start of block
@@ -288,27 +300,50 @@ def make(self, key):
& f'chunk_start <= "{chunk_keys[-1]["chunk_start"]}"'
)[:block_start]
subject_visits_df = subject_visits_df[subject_visits_df.region == "Environment"]
subject_visits_df = subject_visits_df[~subject_visits_df.id.str.contains("Test", case=False)]
[Review comment - Member] sometimes we use other, non-"Test" subjects as test subjects. Maybe the check should be: does the subject name begin with 'baa' (str.lower can be used to check regardless of case)?
[Reply - Contributor Author] I see. Is startswith("baa") reliable and future-proof? Perhaps better logic is to also cross-check against the Subjects manually specified by users for a particular experiment.
[Reply - Member] Yes, your suggestion would be a better check.
subject_names = []
for subject_name in set(subject_visits_df.id):
_df = subject_visits_df[subject_visits_df.id == subject_name]
if _df.type.iloc[-1] != "Exit":
subject_names.append(subject_name)

if use_blob_position and len(subject_names) > 1:
raise ValueError(
f"Without SLEAPTracking, BlobPosition can only handle single-subject block. Found {len(subject_names)} subjects."
)

block_subject_entries = []
for subject_name in subject_names:
# positions - query for CameraTop, identity_name matches subject_name,
pos_query = (
streams.SpinnakerVideoSource
* tracking.SLEAPTracking.PoseIdentity.proj("identity_name", part_name="anchor_part")
* tracking.SLEAPTracking.Part
& key
& {
"spinnaker_video_source_name": "CameraTop",
"identity_name": subject_name,
}
& chunk_restriction
)
pos_df = fetch_stream(pos_query)[block_start:block_end]
if use_blob_position:
pos_query = (
streams.SpinnakerVideoSource
* tracking.BlobPosition.Object
& key
& chunk_restriction
& {
"spinnaker_video_source_name": "CameraTop",
"identity_name": subject_name
}
)
pos_df = fetch_stream(pos_query)[block_start:block_end]
pos_df["likelihood"] = np.nan
# keep only rows with area between 0 and 1000 - likely artifacts otherwise
pos_df = pos_df[(pos_df.area > 0) & (pos_df.area < 1000)]
else:
pos_query = (
streams.SpinnakerVideoSource
* tracking.SLEAPTracking.PoseIdentity.proj("identity_name", part_name="anchor_part")
* tracking.SLEAPTracking.Part
& key
& {
"spinnaker_video_source_name": "CameraTop",
"identity_name": subject_name,
}
& chunk_restriction
)
pos_df = fetch_stream(pos_query)[block_start:block_end]

pos_df = filter_out_maintenance_periods(pos_df, maintenance_period, block_end)

if pos_df.empty:
@@ -345,8 +380,8 @@ def make(self, key):
{
**key,
"block_duration": (block_end - block_start).total_seconds() / 3600,
"patch_count": len(patch_keys),
"subject_count": len(subject_names),
"patch_count": len(block_patch_entries),
"subject_count": len(block_subject_entries),
}
)
self.Patch.insert(block_patch_entries)
@@ -423,6 +458,17 @@ def make(self, key):
)
subjects_positions_df.set_index("position_timestamps", inplace=True)

# Ensure wheel_timestamps are of the same length across all patches
wheel_lens = [len(p["wheel_timestamps"]) for p in block_patches]
if len(set(wheel_lens)) > 1:
max_diff = max(wheel_lens) - min(wheel_lens)
if max_diff > 10:
# a difference of more than 10 samples is unexpected (possibly a patch crash) - raise an error
raise ValueError(f"Wheel data lengths are not consistent across patches ({max_diff} samples diff)")
for p in block_patches:
p["wheel_timestamps"] = p["wheel_timestamps"][: min(wheel_lens)]
p["wheel_cumsum_distance_travelled"] = p["wheel_cumsum_distance_travelled"][: min(wheel_lens)]

self.insert1(key)

in_patch_radius = 130 # pixels
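A hedged sketch of the wheel-length harmonisation above, with synthetic arrays (lengths invented): all patches are truncated to the shortest wheel series, and a mismatch of more than 10 samples is treated as an error.

```python
import numpy as np

block_patches = [  # stand-ins for the Patch rows fetched in BlockSubjectAnalysis.make
    {"wheel_timestamps": np.arange(1000), "wheel_cumsum_distance_travelled": np.random.rand(1000)},
    {"wheel_timestamps": np.arange(997), "wheel_cumsum_distance_travelled": np.random.rand(997)},
]

wheel_lens = [len(p["wheel_timestamps"]) for p in block_patches]
if len(set(wheel_lens)) > 1:
    if max(wheel_lens) - min(wheel_lens) > 10:
        raise ValueError("Wheel data lengths are not consistent across patches")
    for p in block_patches:
        p["wheel_timestamps"] = p["wheel_timestamps"][: min(wheel_lens)]
        p["wheel_cumsum_distance_travelled"] = p["wheel_cumsum_distance_travelled"][: min(wheel_lens)]

print([len(p["wheel_timestamps"]) for p in block_patches])  # [997, 997]
```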
@@ -541,7 +587,7 @@ def make(self, key):
| {
"patch_name": patch["patch_name"],
"subject_name": subject_name,
"in_patch_timestamps": subject_in_patch.index.values,
"in_patch_timestamps": subject_in_patch[in_patch[subject_name]].index.values,
"in_patch_time": subject_in_patch_cum_time[-1],
"pellet_count": len(subj_pellets),
"pellet_timestamps": subj_pellets.index.values,
@@ -1521,10 +1567,10 @@ def make(self, key):
foraging_bout_df = get_foraging_bouts(key)
foraging_bout_df.rename(
columns={
"subject_name": "subject",
"bout_start": "start",
"bout_end": "end",
"pellet_count": "n_pellets",
"subject": "subject_name",
"start": "bout_start",
"end": "bout_end",
"n_pellets": "pellet_count",
"cum_wheel_dist": "cum_wheel_dist",
},
inplace=True,
@@ -1540,7 +1586,7 @@
@schema
class AnalysisNote(dj.Manual):
definition = """ # Generic table to catch all notes generated during analysis
note_timestamp: datetime
note_timestamp: datetime(6)
---
note_type='': varchar(64)
note: varchar(3000)
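As a side note on the `BlockForaging` column-rename fix above: `DataFrame.rename(columns=...)` maps existing column names to new ones, so the keys must be the columns produced by `get_foraging_bouts` and the values the table attribute names. A small hedged sketch with an invented row:

```python
import pandas as pd

foraging_bout_df = pd.DataFrame(  # one made-up foraging bout
    {"subject": ["subject-01"], "start": [pd.Timestamp("2024-10-21 10:00")],
     "end": [pd.Timestamp("2024-10-21 10:05")], "n_pellets": [3], "cum_wheel_dist": [12.5]}
)
foraging_bout_df.rename(
    columns={"subject": "subject_name", "start": "bout_start", "end": "bout_end", "n_pellets": "pellet_count"},
    inplace=True,
)
print(list(foraging_bout_df.columns))
# ['subject_name', 'bout_start', 'bout_end', 'pellet_count', 'cum_wheel_dist']
```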
24 changes: 22 additions & 2 deletions aeon/dj_pipeline/populate/worker.py
@@ -99,12 +99,32 @@ def ingest_epochs_chunks():
)

analysis_worker(block_analysis.BlockAnalysis, max_calls=6)
analysis_worker(block_analysis.BlockPlots, max_calls=6)
analysis_worker(block_analysis.BlockSubjectAnalysis, max_calls=6)
analysis_worker(block_analysis.BlockSubjectPlots, max_calls=6)
analysis_worker(block_analysis.BlockForaging, max_calls=6)
analysis_worker(block_analysis.BlockPatchPlots, max_calls=6)
analysis_worker(block_analysis.BlockSubjectPositionPlots, max_calls=6)


def get_workflow_operation_overview():
from datajoint_utilities.dj_worker.utils import get_workflow_operation_overview

return get_workflow_operation_overview(worker_schema_name=worker_schema_name, db_prefixes=[db_prefix])


def retrieve_schemas_sizes(schema_only=False, all_schemas=False):
schema_names = [n for n in dj.list_schemas() if n != "mysql"]
if not all_schemas:
schema_names = [n for n in schema_names
if n.startswith(db_prefix) and not n.startswith(f"{db_prefix}archived")]

if schema_only:
return {n: dj.Schema(n).size_on_disk / 1e9 for n in schema_names}

schema_sizes = {n: {} for n in schema_names}
for n in schema_names:
vm = dj.VirtualModule(n, n)
schema_sizes[n]["schema_gb"] = vm.schema.size_on_disk / 1e9
schema_sizes[n]["tables_gb"] = {n: t().size_on_disk / 1e9
for n, t in vm.__dict__.items()
if isinstance(t, dj.user_tables.TableMeta)}
return schema_sizes
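A hypothetical usage sketch for the new `retrieve_schemas_sizes` helper (assumes a configured DataJoint connection and that the module path matches the file shown; output values are illustrative only):

```python
from aeon.dj_pipeline.populate import worker

# per-schema sizes only (GB), largest first
sizes = worker.retrieve_schemas_sizes(schema_only=True)
for schema_name, size_gb in sorted(sizes.items(), key=lambda kv: kv[1], reverse=True):
    print(f"{schema_name}: {size_gb:.2f} GB")
```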