Merge pull request #451 from ttngu207/dev_queriable_in_roi_time
Table to compute/store InROI times AND subject-experiment mapping from Pyrat
jkbhagatio authored Nov 21, 2024
2 parents 9be1f8e + fa53e23 commit c6dd2ba
Showing 4 changed files with 159 additions and 28 deletions.
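The headline addition is the BlockSubjectPositionPlots.InROI part table in block_analysis.py below, which makes per-subject, per-ROI residence time directly queryable. A minimal usage sketch (the block key here is hypothetical, and it assumes the table has already been populated for that block):

from aeon.dj_pipeline.analysis.block_analysis import BlockSubjectPositionPlots

# hypothetical block key -- substitute a real experiment_name / block_start
block_key = {"experiment_name": "social0.3-aeon3", "block_start": "2024-02-05 15:00:00"}

# one row per (subject_name, roi_name), with in_roi_time in seconds and the raw in_roi_timestamps
in_roi_df = (BlockSubjectPositionPlots.InROI & block_key).fetch(format="frame")

Note that in_roi_time is derived from the 100 ms resampled centroid trace, so it equals 0.1 s × the number of stored in_roi_timestamps.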
85 changes: 59 additions & 26 deletions aeon/dj_pipeline/analysis/block_analysis.py
@@ -42,10 +42,18 @@ class Block(dj.Manual):

@schema
class BlockDetection(dj.Computed):
definition = """
definition = """ # Detecting new block(s) for each new Chunk
-> acquisition.Environment
---
execution_time=null: datetime
"""

class IdentifiedBlock(dj.Part):
definition = """ # the block(s) identified in this BlockDetection
-> master
-> Block
"""

key_source = acquisition.Environment - {"experiment_name": "social0.1-aeon3"}

def make(self, key):
@@ -70,12 +78,9 @@ def make(self, key):
block_state_query = acquisition.Environment.BlockState & exp_key & chunk_restriction
block_state_df = fetch_stream(block_state_query)
if block_state_df.empty:
self.insert1(key)
# self.insert1(key)
return

block_state_df.index = block_state_df.index.round(
"us"
) # timestamp precision in DJ is only at microseconds
block_state_df = block_state_df.loc[
(block_state_df.index > chunk_start) & (block_state_df.index <= chunk_end)
]
@@ -103,7 +108,10 @@ def make(self, key):
)

Block.insert(block_entries, skip_duplicates=True)
self.insert1(key)
# self.insert1({**key, "execution_time": datetime.now(UTC)})
self.IdentifiedBlock.insert(
{**key, "block_start": entry["block_start"]} for entry in block_entries
)


# ---- Block Analysis and Visualization ----
@@ -316,6 +324,15 @@ def make(self, key):
if _df.type.iloc[-1] != "Exit":
subject_names.append(subject_name)

# Check for ExperimentTimeline to validate subjects in this block
timeline_query = (acquisition.ExperimentTimeline
& acquisition.ExperimentTimeline.Subject
& key
& f"start <= '{block_start}' AND end >= '{block_end}'")
timeline_subjects = (acquisition.ExperimentTimeline.Subject & timeline_query).fetch("subject")
if len(timeline_subjects):
subject_names = [s for s in subject_names if s in timeline_subjects]

if use_blob_position and len(subject_names) > 1:
raise ValueError(
f"Without SLEAPTracking, BlobPosition can only handle a single-subject block. "
@@ -1331,12 +1348,21 @@ class BlockSubjectPositionPlots(dj.Computed):
definition = """
-> BlockSubjectAnalysis
---
ethogram_data: longblob # ethogram data in record array format
position_plot: longblob # position plot (plotly)
position_heatmap_plot: longblob # position heatmap plot (plotly)
position_ethogram_plot: longblob # position ethogram plot (plotly)
"""

class InROI(dj.Part):
definition = """ # time spent in each ROI for each subject
-> master
subject_name: varchar(32)
roi_name: varchar(32)
---
in_roi_time: float # total seconds spent in this ROI for this block
in_roi_timestamps: longblob # timestamps when a subject is at a specific ROI
"""

def make(self, key):
"""Compute and plot various block-level statistics and visualizations."""
# Get some block info
@@ -1353,32 +1379,26 @@ def make(self, key):
# Figure 1 - Position (centroid) over time
# ---
# Get animal position data
pose_query = (
streams.SpinnakerVideoSource
* tracking.SLEAPTracking.PoseIdentity.proj(
"identity_name", "identity_likelihood", part_name="anchor_part"
)
* tracking.SLEAPTracking.Part
& {"spinnaker_video_source_name": "CameraTop"}
& key
& chunk_restriction
)
centroid_df = fetch_stream(pose_query)[block_start:block_end]
pos_cols = {"x": "position_x", "y": "position_y", "time": "position_timestamps"}
centroid_df = (BlockAnalysis.Subject.proj(**pos_cols)
& key).fetch(format="frame").reset_index()
centroid_df.drop(columns=["experiment_name", "block_start"], inplace=True)
centroid_df = centroid_df.explode(column=list(pos_cols))
centroid_df.set_index("time", inplace=True)
centroid_df = (
centroid_df.groupby("identity_name")
centroid_df.groupby("subject_name")
.resample("100ms")
.first()
.droplevel("identity_name")
.droplevel("subject_name")
.dropna()
.sort_index()
)
centroid_df.drop(columns=["spinnaker_video_source_name"], inplace=True)
centroid_df["x"] = centroid_df["x"].astype(np.int32)
centroid_df["y"] = centroid_df["y"].astype(np.int32)

# Plot it
position_fig = go.Figure()
for id_i, (id_val, id_grp) in enumerate(centroid_df.groupby("identity_name")):
for id_i, (id_val, id_grp) in enumerate(centroid_df.groupby("subject_name")):
norm_time = (
(id_grp.index - id_grp.index[0]) / (id_grp.index[-1] - id_grp.index[0])
).values.round(3)
@@ -1407,7 +1427,7 @@ def make(self, key):
# Calculate heatmaps
max_x, max_y = int(centroid_df["x"].max()), int(centroid_df["y"].max())
heatmaps = []
for id_val, id_grp in centroid_df.groupby("identity_name"):
for id_val, id_grp in centroid_df.groupby("subject_name"):
# Add counts of x,y points to a grid that will be used for heatmap
img_grid = np.zeros((max_x + 1, max_y + 1))
points, counts = np.unique(id_grp[["x", "y"]].values, return_counts=True, axis=0)
@@ -1483,7 +1503,7 @@ def make(self, key):
pos_eth_df = pd.DataFrame(
columns=(["Subject"] + rois), index=centroid_df.index
) # df to create eth fig
pos_eth_df["Subject"] = centroid_df["identity_name"]
pos_eth_df["Subject"] = centroid_df["subject_name"]

# For each ROI, compute if within ROI
for roi in rois:
@@ -1558,10 +1578,23 @@ def make(self, key):
):
entry[fig_name] = json.loads(fig.to_json())

melted_df.drop(columns=["Val"], inplace=True)
entry["ethogram_data"] = melted_df.to_records(index=False)
# insert into InROI
in_roi_entries = []
for subject_name, roi_name in itertools.product(set(melted_df.Subject), rois):
df_ = melted_df[(melted_df["Subject"] == subject_name) & (melted_df["Loc"] == roi_name)]

roi_timestamps = df_["time"].values
roi_time = len(roi_timestamps) * 0.1 # 100ms per timestamp

in_roi_entries.append(
{**key, "subject_name": subject_name,
"roi_name": roi_name,
"in_roi_time": roi_time,
"in_roi_timestamps": roi_timestamps}
)

self.insert1(entry)
self.InROI.insert(in_roi_entries)


# ---- Foraging Bout Analysis ----
64 changes: 64 additions & 0 deletions aeon/dj_pipeline/scripts/fix_anchor_part_fullpose.py
@@ -0,0 +1,64 @@
"""
Script to fix the anchor part of the fullpose SLEAP entries.
See this commit: https://github.com/SainsburyWellcomeCentre/aeon_mecha/commit/8358ce4b6923918920efb77d09adc769721dbb9b
Last run: ---
"""
import pandas as pd
from tqdm import tqdm
from aeon.dj_pipeline import acquisition, tracking, streams

aeon_schemas = acquisition.aeon_schemas
logger = acquisition.logger
io_api = acquisition.io_api


def update_anchor_part(key):
chunk_start, chunk_end = (acquisition.Chunk & key).fetch1("chunk_start", "chunk_end")

data_dirs = acquisition.Experiment.get_data_directories(key)

device_name = (streams.SpinnakerVideoSource & key).fetch1("spinnaker_video_source_name")

devices_schema = getattr(
aeon_schemas,
(acquisition.Experiment.DevicesSchema & {"experiment_name": key["experiment_name"]}).fetch1(
"devices_schema_name"
),
)

stream_reader = getattr(getattr(devices_schema, device_name), "Pose")

# special ingestion case for social0.2 full-pose data (using Pose reader from social03)
# fullpose for social0.2 has a different "pattern" for non-fullpose, hence the Pose03 reader
if key["experiment_name"].startswith("social0.2"):
from aeon.io import reader as io_reader
stream_reader = getattr(getattr(devices_schema, device_name), "Pose03")
assert isinstance(stream_reader, io_reader.Pose), "Pose03 is not a Pose reader"
data_dirs = [acquisition.Experiment.get_data_directory(key, "processed")]

pose_data = io_api.load(
root=data_dirs,
reader=stream_reader,
start=pd.Timestamp(chunk_start),
end=pd.Timestamp(chunk_end),
)

if not len(pose_data):
raise ValueError(f"No SLEAP data found for {key['experiment_name']} - {device_name}")

# get anchor part
anchor_part = next(v.replace("_x", "") for v in stream_reader.columns if v.endswith("_x"))

# update anchor part
for entry in tracking.SLEAPTracking.PoseIdentity.fetch("KEY"):
entry["anchor_part"] = anchor_part
tracking.SLEAPTracking.PoseIdentity.update1(entry)

logger.info(f"Anchor part updated to {anchor_part} for {key}")


def main():
keys = tracking.SLEAPTracking.fetch("KEY")
for key in tqdm(keys):
update_anchor_part(key)
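The script above defines main() but, as committed, never calls it (there is no __main__ guard); a hedged way to run the fix would be to import and invoke it manually:

# hypothetical invocation of the committed script
from aeon.dj_pipeline.scripts.fix_anchor_part_fullpose import main

main()  # iterates over all tracking.SLEAPTracking keys and updates the anchor part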
28 changes: 28 additions & 0 deletions aeon/dj_pipeline/subject.py
@@ -372,6 +372,9 @@ def make(self, key):
"lab_id": animal_resp["labid"],
}
)

associate_subject_and_experiment(eartag_or_id)

completion_time = datetime.now(UTC)
self.insert1(
{
@@ -499,3 +502,28 @@ def get_pyrat_data(endpoint: str, params: dict = None, **kwargs):
)

return response.json()


def associate_subject_and_experiment(subject_name):
"""
Check SubjectComment for the name of the experiment the animal is participating in.
The expected comment format is "experiment: <experiment_name>".
E.g. "experiment: social0.3-aeon3"
Note: this function may need to run repeatedly to catch all experiments/animals:
- an experiment is not yet added when the animal comment is added
- the animal comment is not yet added when the experiment is created
"""
from aeon.dj_pipeline import acquisition

new_entries = []
for entry in (SubjectComment.proj("content")
& {"subject": subject_name}
& "content LIKE 'experiment:%'").fetch(as_dict=True):
entry.pop("comment_id")
entry["experiment_name"] = entry.pop("content").replace("experiment:", "").strip()
if acquisition.Experiment.proj() & entry:
if not acquisition.Experiment.Subject & entry:
new_entries.append(entry)
logger.info(f"\tNew experiment subject: {entry}")

acquisition.Experiment.Subject.insert(new_entries)
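For reference, the comment-parsing convention relied on by associate_subject_and_experiment above reduces to a simple string strip; a small illustration with a hypothetical SubjectComment entry:

content = "experiment: social0.3-aeon3"  # hypothetical comment content
experiment_name = content.replace("experiment:", "").strip()
assert experiment_name == "social0.3-aeon3"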
10 changes: 8 additions & 2 deletions aeon/dj_pipeline/webapps/sciviz/specsheet.yaml
@@ -669,8 +669,14 @@ SciViz:
return dict(**kwargs)
dj_query: >
def dj_query(aeon_block_analysis):
aeon_analysis = aeon_block_analysis
query = aeon_analysis.Block * aeon_analysis.BlockAnalysis
block_analysis = aeon_block_analysis
query = block_analysis.Block.proj() * block_analysis.BlockAnalysis
query *= block_analysis.BlockAnalysis.aggr(
block_analysis.BlockAnalysis.Subject, subjects="GROUP_CONCAT(subject_name)", keep_all_rows=True)
query *= block_analysis.BlockAnalysis.aggr(
block_analysis.BlockAnalysis.Patch.proj(patch_rate="CONCAT(patch_name, ':', patch_rate, '(', patch_offset, ')')"), patch_rates="GROUP_CONCAT(patch_rate)", keep_all_rows=True)
query *= block_analysis.BlockAnalysis.aggr(
block_analysis.BlockAnalysis.Patch.proj(patch_pellet="CONCAT(patch_name, ':', pellet_count)"), patch_pellets="GROUP_CONCAT(patch_pellet)", keep_all_rows=True)
return dict(query=query, fetch_args=[])
comp2:
route: /per_block_patch_stats_plot
