Table to compute/store InROI times AND subject-experiment mapping from Pyrat #451

85 changes: 59 additions & 26 deletions aeon/dj_pipeline/analysis/block_analysis.py
@@ -42,10 +42,18 @@ class Block(dj.Manual):

@schema
class BlockDetection(dj.Computed):
definition = """
definition = """ # Detecting new block(s) for each new Chunk
-> acquisition.Environment
---
execution_time=null: datetime
"""

class IdentifiedBlock(dj.Part):
definition = """ # the block(s) identified in this BlockDetection
-> master
-> Block
"""

key_source = acquisition.Environment - {"experiment_name": "social0.1-aeon3"}

def make(self, key):
@@ -70,12 +78,9 @@ def make(self, key):
block_state_query = acquisition.Environment.BlockState & exp_key & chunk_restriction
block_state_df = fetch_stream(block_state_query)
if block_state_df.empty:
self.insert1(key)
# self.insert1(key)
return

block_state_df.index = block_state_df.index.round(
"us"
) # timestamp precision in DJ is only at microseconds
block_state_df = block_state_df.loc[
(block_state_df.index > chunk_start) & (block_state_df.index <= chunk_end)
]
@@ -103,7 +108,10 @@ def make(self, key):
)

Block.insert(block_entries, skip_duplicates=True)
self.insert1(key)
# self.insert1({**key, "execution_time": datetime.now(UTC)})
self.IdentifiedBlock.insert(
{**key, "block_start": entry["block_start"]} for entry in block_entries
)


# ---- Block Analysis and Visualization ----
@@ -316,6 +324,15 @@ def make(self, key):
if _df.type.iloc[-1] != "Exit":
subject_names.append(subject_name)

# Check for ExperimentTimeline to validate subjects in this block
timeline_query = (acquisition.ExperimentTimeline
& acquisition.ExperimentTimeline.Subject
& key
& f"start <= '{block_start}' AND end >= '{block_end}'")
timeline_subjects = (acquisition.ExperimentTimeline.Subject & timeline_query).fetch("subject")
if len(timeline_subjects):
subject_names = [s for s in subject_names if s in timeline_subjects]

if use_blob_position and len(subject_names) > 1:
raise ValueError(
f"Without SLEAPTracking, BlobPosition can only handle a single-subject block. "
@@ -1331,12 +1348,21 @@ class BlockSubjectPositionPlots(dj.Computed):
definition = """
-> BlockSubjectAnalysis
---
ethogram_data: longblob # ethogram data in record array format
position_plot: longblob # position plot (plotly)
position_heatmap_plot: longblob # position heatmap plot (plotly)
position_ethogram_plot: longblob # position ethogram plot (plotly)
"""

class InROI(dj.Part):
definition = """ # time spent in each ROI for each subject
-> master
subject_name: varchar(32)
roi_name: varchar(32)
---
in_roi_time: float # total seconds spent in this ROI for this block
in_roi_timestamps: longblob # timestamps when a subject is at a specific ROI
"""

def make(self, key):
"""Compute and plot various block-level statistics and visualizations."""
# Get some block info
@@ -1353,32 +1379,26 @@ def make(self, key):
# Figure 1 - Position (centroid) over time
# ---
# Get animal position data
pose_query = (
streams.SpinnakerVideoSource
* tracking.SLEAPTracking.PoseIdentity.proj(
"identity_name", "identity_likelihood", part_name="anchor_part"
)
* tracking.SLEAPTracking.Part
& {"spinnaker_video_source_name": "CameraTop"}
& key
& chunk_restriction
)
centroid_df = fetch_stream(pose_query)[block_start:block_end]
pos_cols = {"x": "position_x", "y": "position_y", "time": "position_timestamps"}
centroid_df = (BlockAnalysis.Subject.proj(**pos_cols)
& key).fetch(format="frame").reset_index()
centroid_df.drop(columns=["experiment_name", "block_start"], inplace=True)
centroid_df = centroid_df.explode(column=list(pos_cols))
centroid_df.set_index("time", inplace=True)
centroid_df = (
centroid_df.groupby("identity_name")
centroid_df.groupby("subject_name")
.resample("100ms")
.first()
.droplevel("identity_name")
.droplevel("subject_name")
.dropna()
.sort_index()
)
centroid_df.drop(columns=["spinnaker_video_source_name"], inplace=True)
centroid_df["x"] = centroid_df["x"].astype(np.int32)
centroid_df["y"] = centroid_df["y"].astype(np.int32)

# Plot it
position_fig = go.Figure()
for id_i, (id_val, id_grp) in enumerate(centroid_df.groupby("identity_name")):
for id_i, (id_val, id_grp) in enumerate(centroid_df.groupby("subject_name")):
norm_time = (
(id_grp.index - id_grp.index[0]) / (id_grp.index[-1] - id_grp.index[0])
).values.round(3)
@@ -1407,7 +1427,7 @@ def make(self, key):
# Calculate heatmaps
max_x, max_y = int(centroid_df["x"].max()), int(centroid_df["y"].max())
heatmaps = []
for id_val, id_grp in centroid_df.groupby("identity_name"):
for id_val, id_grp in centroid_df.groupby("subject_name"):
# Add counts of x,y points to a grid that will be used for heatmap
img_grid = np.zeros((max_x + 1, max_y + 1))
points, counts = np.unique(id_grp[["x", "y"]].values, return_counts=True, axis=0)
@@ -1483,7 +1503,7 @@ def make(self, key):
pos_eth_df = pd.DataFrame(
columns=(["Subject"] + rois), index=centroid_df.index
) # df to create eth fig
pos_eth_df["Subject"] = centroid_df["identity_name"]
pos_eth_df["Subject"] = centroid_df["subject_name"]

# For each ROI, compute if within ROI
for roi in rois:
@@ -1558,10 +1578,23 @@ def make(self, key):
):
entry[fig_name] = json.loads(fig.to_json())

melted_df.drop(columns=["Val"], inplace=True)
entry["ethogram_data"] = melted_df.to_records(index=False)
# insert into InROI
in_roi_entries = []
for subject_name, roi_name in itertools.product(set(melted_df.Subject), rois):
df_ = melted_df[(melted_df["Subject"] == subject_name) & (melted_df["Loc"] == roi_name)]

roi_timestamps = df_["time"].values
roi_time = len(roi_timestamps) * 0.1 # 100ms per timestamp

in_roi_entries.append(
{**key, "subject_name": subject_name,
"roi_name": roi_name,
"in_roi_time": roi_time,
"in_roi_timestamps": roi_timestamps}
)

self.insert1(entry)
self.InROI.insert(in_roi_entries)


# ---- Foraging Bout Analysis ----
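For reference, the new InROI part computes in_roi_time from the 100 ms resampling applied earlier in make(): each retained timestamp contributes 0.1 s. Below is a minimal, self-contained pandas sketch of that idea, using toy positions and a hypothetical rectangular ROI (the pipeline's real ROIs come from the experiment's region definitions, and the names here are illustrative only).

import numpy as np
import pandas as pd

# Toy centroid track for one subject, sampled every 20 ms (hypothetical data).
idx = pd.date_range("2024-01-01 00:00:00", periods=2000, freq="20ms")
centroid = pd.DataFrame(
    {"x": np.random.uniform(0, 100, len(idx)), "y": np.random.uniform(0, 100, len(idx))},
    index=idx,
)

# Resample to 100 ms bins, keeping the first sample per bin, as in BlockSubjectPositionPlots.
centroid = centroid.resample("100ms").first().dropna()

# Hypothetical rectangular ROI spanning x, y in [20, 60].
in_roi = centroid["x"].between(20, 60) & centroid["y"].between(20, 60)

roi_timestamps = centroid.index[in_roi].values
roi_time_s = len(roi_timestamps) * 0.1  # each 100 ms sample contributes 0.1 s
print(f"time in ROI: {roi_time_s:.1f} s out of {len(centroid) * 0.1:.1f} s")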
64 changes: 64 additions & 0 deletions aeon/dj_pipeline/scripts/fix_anchor_part_fullpose.py
Member:

@ttngu207 sorry if I'm missing something obvious, but why is fix_anchor_part_fullpose necessary? I thought the anchor part for full pose would be fixed the same way as was done for identity?

Contributor Author:

Some of the SLEAP data were ingested before this commit, so those entries may have the anchor_part incorrectly extracted/stored in the database.

Only a very small number are affected, but I don't know which ones, so this script is here to correct for that. I haven't run it yet; it's included for the record. We may not need to run it at all if we end up deleting and re-ingesting the SLEAP data from the composite camera anyway.
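As a side note, before running the correction it may help to see how many PoseIdentity entries carry each anchor part. The dry-run sketch below is only an illustration; it assumes PoseIdentity exposes experiment_name and anchor_part, as used in the script that follows.

import datajoint as dj

from aeon.dj_pipeline import tracking

# Count PoseIdentity entries per (experiment_name, anchor_part); an incorrectly
# stored anchor part would show up as a small, unexpected group.
anchor_counts = dj.U("experiment_name", "anchor_part").aggr(
    tracking.SLEAPTracking.PoseIdentity, n="count(*)"
)
print(anchor_counts)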

@@ -0,0 +1,64 @@
"""
Script to fix the anchor part of the fullpose SLEAP entries.
See this commit: https://github.com/SainsburyWellcomeCentre/aeon_mecha/commit/8358ce4b6923918920efb77d09adc769721dbb9b

Last run: ---
"""
import pandas as pd
from tqdm import tqdm
from aeon.dj_pipeline import acquisition, tracking, streams

aeon_schemas = acquisition.aeon_schemas
logger = acquisition.logger
io_api = acquisition.io_api


def update_anchor_part(key):
chunk_start, chunk_end = (acquisition.Chunk & key).fetch1("chunk_start", "chunk_end")

data_dirs = acquisition.Experiment.get_data_directories(key)

device_name = (streams.SpinnakerVideoSource & key).fetch1("spinnaker_video_source_name")

devices_schema = getattr(
aeon_schemas,
(acquisition.Experiment.DevicesSchema & {"experiment_name": key["experiment_name"]}).fetch1(
"devices_schema_name"
),
)

stream_reader = getattr(getattr(devices_schema, device_name), "Pose")

# special ingestion case for social0.2 full-pose data (using Pose reader from social03)
# fullpose for social0.2 has a different "pattern" for non-fullpose, hence the Pose03 reader
if key["experiment_name"].startswith("social0.2"):
from aeon.io import reader as io_reader
stream_reader = getattr(getattr(devices_schema, device_name), "Pose03")
assert isinstance(stream_reader, io_reader.Pose), "Pose03 is not a Pose reader"
data_dirs = [acquisition.Experiment.get_data_directory(key, "processed")]

pose_data = io_api.load(
root=data_dirs,
reader=stream_reader,
start=pd.Timestamp(chunk_start),
end=pd.Timestamp(chunk_end),
)

if not len(pose_data):
raise ValueError(f"No SLEAP data found for {key['experiment_name']} - {device_name}")

# get anchor part
anchor_part = next(v.replace("_x", "") for v in stream_reader.columns if v.endswith("_x"))

# update anchor part for the PoseIdentity entries of this chunk only
for entry in (tracking.SLEAPTracking.PoseIdentity & key).fetch("KEY"):
entry["anchor_part"] = anchor_part
tracking.SLEAPTracking.PoseIdentity.update1(entry)

logger.info(f"Anchor part updated to {anchor_part} for {key}")


def main():
keys = tracking.SLEAPTracking.fetch("KEY")
for key in tqdm(keys):
update_anchor_part(key)
28 changes: 28 additions & 0 deletions aeon/dj_pipeline/subject.py
@@ -372,6 +372,9 @@ def make(self, key):
"lab_id": animal_resp["labid"],
}
)

associate_subject_and_experiment(eartag_or_id)

completion_time = datetime.now(UTC)
self.insert1(
{
@@ -499,3 +502,28 @@ def get_pyrat_data(endpoint: str, params: dict = None, **kwargs):
)

return response.json()


def associate_subject_and_experiment(subject_name):
"""
Check SubjectComment for experiment name for which the animal is participating in.
The expected comment format is "experiment: <experiment_name>".
E.g. "experiment: social0.3-aeon3"
Note: this function many need to run repeatedly to catch all experiments/animals.
- an experiment is not yet added when the animal comment is added
- the animal comment is not yet added when the experiment is created
"""
from aeon.dj_pipeline import acquisition

new_entries = []
for entry in (SubjectComment.proj("content")
& {"subject": subject_name}
& "content LIKE 'experiment:%'").fetch(as_dict=True):
entry.pop("comment_id")
entry["experiment_name"] = entry.pop("content").replace("experiment:", "").strip()
if acquisition.Experiment.proj() & entry:
if not acquisition.Experiment.Subject & entry:
new_entries.append(entry)
logger.info(f"\tNew experiment subject: {entry}")

acquisition.Experiment.Subject.insert(new_entries)
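The comment parsing in associate_subject_and_experiment reduces to stripping the "experiment:" prefix; a small stand-alone illustration with hypothetical comment strings:

def parse_experiment_comment(content: str):
    """Return the experiment name from an "experiment: <name>" comment, else None (sketch)."""
    if content.startswith("experiment:"):
        return content.replace("experiment:", "", 1).strip()
    return None

assert parse_experiment_comment("experiment: social0.3-aeon3") == "social0.3-aeon3"
assert parse_experiment_comment("weight check") is None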
10 changes: 8 additions & 2 deletions aeon/dj_pipeline/webapps/sciviz/specsheet.yaml
@@ -669,8 +669,14 @@ SciViz:
return dict(**kwargs)
dj_query: >
def dj_query(aeon_block_analysis):
aeon_analysis = aeon_block_analysis
query = aeon_analysis.Block * aeon_analysis.BlockAnalysis
block_analysis = aeon_block_analysis
query = block_analysis.Block.proj() * block_analysis.BlockAnalysis
query *= block_analysis.BlockAnalysis.aggr(
block_analysis.BlockAnalysis.Subject, subjects="GROUP_CONCAT(subject_name)", keep_all_rows=True)
query *= block_analysis.BlockAnalysis.aggr(
block_analysis.BlockAnalysis.Patch.proj(patch_rate="CONCAT(patch_name, ':', patch_rate, '(', patch_offset, ')')"), patch_rates="GROUP_CONCAT(patch_rate)", keep_all_rows=True)
query *= block_analysis.BlockAnalysis.aggr(
block_analysis.BlockAnalysis.Patch.proj(patch_pellet="CONCAT(patch_name, ':', pellet_count)"), patch_pellets="GROUP_CONCAT(patch_pellet)", keep_all_rows=True)
return dict(query=query, fetch_args=[])
comp2:
route: /per_block_patch_stats_plot
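Outside SciViz, the same per-block summary can be previewed interactively. The sketch below mirrors the dj_query above and assumes the analysis module is importable as aeon.dj_pipeline.analysis.block_analysis (only the subjects aggregation is shown).

from aeon.dj_pipeline.analysis import block_analysis

query = block_analysis.Block.proj() * block_analysis.BlockAnalysis
query *= block_analysis.BlockAnalysis.aggr(
    block_analysis.BlockAnalysis.Subject,
    subjects="GROUP_CONCAT(subject_name)",
    keep_all_rows=True,
)
df = query.fetch(format="frame")  # one row per block, with subject names concatenated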