diff --git a/aeon/dj_pipeline/analysis/block_analysis.py b/aeon/dj_pipeline/analysis/block_analysis.py index 2a233d8b..db3934a8 100644 --- a/aeon/dj_pipeline/analysis/block_analysis.py +++ b/aeon/dj_pipeline/analysis/block_analysis.py @@ -42,10 +42,18 @@ class Block(dj.Manual): @schema class BlockDetection(dj.Computed): - definition = """ + definition = """ # Detecting new block(s) for each new Chunk -> acquisition.Environment + --- + execution_time=null: datetime """ + class IdentifiedBlock(dj.Part): + definition = """ # the block(s) identified in this BlockDetection + -> master + -> Block + """ + key_source = acquisition.Environment - {"experiment_name": "social0.1-aeon3"} def make(self, key): @@ -70,12 +78,9 @@ def make(self, key): block_state_query = acquisition.Environment.BlockState & exp_key & chunk_restriction block_state_df = fetch_stream(block_state_query) if block_state_df.empty: - self.insert1(key) + # self.insert1(key) return - block_state_df.index = block_state_df.index.round( - "us" - ) # timestamp precision in DJ is only at microseconds block_state_df = block_state_df.loc[ (block_state_df.index > chunk_start) & (block_state_df.index <= chunk_end) ] @@ -103,7 +108,10 @@ def make(self, key): ) Block.insert(block_entries, skip_duplicates=True) - self.insert1(key) + # self.insert1({**key, "execution_time": datetime.now(UTC)}) + self.IdentifiedBlock.insert( + {**key, "block_start": entry["block_start"]} for entry in block_entries + ) # ---- Block Analysis and Visualization ---- @@ -316,6 +324,15 @@ def make(self, key): if _df.type.iloc[-1] != "Exit": subject_names.append(subject_name) + # Check for ExperimentTimeline to validate subjects in this block + timeline_query = (acquisition.ExperimentTimeline + & acquisition.ExperimentTimeline.Subject + & key + & f"start <= '{block_start}' AND end >= '{block_end}'") + timeline_subjects = (acquisition.ExperimentTimeline.Subject & timeline_query).fetch("subject") + if len(timeline_subjects): + subject_names = [s for s in subject_names if s in timeline_subjects] + if use_blob_position and len(subject_names) > 1: raise ValueError( f"Without SLEAPTracking, BlobPosition can only handle a single-subject block. " @@ -1331,12 +1348,21 @@ class BlockSubjectPositionPlots(dj.Computed): definition = """ -> BlockSubjectAnalysis --- - ethogram_data: longblob # ethogram data in record array format position_plot: longblob # position plot (plotly) position_heatmap_plot: longblob # position heatmap plot (plotly) position_ethogram_plot: longblob # position ethogram plot (plotly) """ + class InROI(dj.Part): + definition = """ # time spent in each ROI for each subject + -> master + subject_name: varchar(32) + roi_name: varchar(32) + --- + in_roi_time: float # total seconds spent in this ROI for this block + in_roi_timestamps: longblob # timestamps when a subject is at a specific ROI + """ + def make(self, key): """Compute and plot various block-level statistics and visualizations.""" # Get some block info @@ -1353,32 +1379,26 @@ def make(self, key): # Figure 1 - Position (centroid) over time # --- # Get animal position data - pose_query = ( - streams.SpinnakerVideoSource - * tracking.SLEAPTracking.PoseIdentity.proj( - "identity_name", "identity_likelihood", part_name="anchor_part" - ) - * tracking.SLEAPTracking.Part - & {"spinnaker_video_source_name": "CameraTop"} - & key - & chunk_restriction - ) - centroid_df = fetch_stream(pose_query)[block_start:block_end] + pos_cols = {"x": "position_x", "y": "position_y", "time": "position_timestamps"} + centroid_df = (BlockAnalysis.Subject.proj(**pos_cols) + & key).fetch(format="frame").reset_index() + centroid_df.drop(columns=["experiment_name", "block_start"], inplace=True) + centroid_df = centroid_df.explode(column=list(pos_cols)) + centroid_df.set_index("time", inplace=True) centroid_df = ( - centroid_df.groupby("identity_name") + centroid_df.groupby("subject_name") .resample("100ms") .first() - .droplevel("identity_name") + .droplevel("subject_name") .dropna() .sort_index() ) - centroid_df.drop(columns=["spinnaker_video_source_name"], inplace=True) centroid_df["x"] = centroid_df["x"].astype(np.int32) centroid_df["y"] = centroid_df["y"].astype(np.int32) # Plot it position_fig = go.Figure() - for id_i, (id_val, id_grp) in enumerate(centroid_df.groupby("identity_name")): + for id_i, (id_val, id_grp) in enumerate(centroid_df.groupby("subject_name")): norm_time = ( (id_grp.index - id_grp.index[0]) / (id_grp.index[-1] - id_grp.index[0]) ).values.round(3) @@ -1407,7 +1427,7 @@ def make(self, key): # Calculate heatmaps max_x, max_y = int(centroid_df["x"].max()), int(centroid_df["y"].max()) heatmaps = [] - for id_val, id_grp in centroid_df.groupby("identity_name"): + for id_val, id_grp in centroid_df.groupby("subject_name"): # Add counts of x,y points to a grid that will be used for heatmap img_grid = np.zeros((max_x + 1, max_y + 1)) points, counts = np.unique(id_grp[["x", "y"]].values, return_counts=True, axis=0) @@ -1483,7 +1503,7 @@ def make(self, key): pos_eth_df = pd.DataFrame( columns=(["Subject"] + rois), index=centroid_df.index ) # df to create eth fig - pos_eth_df["Subject"] = centroid_df["identity_name"] + pos_eth_df["Subject"] = centroid_df["subject_name"] # For each ROI, compute if within ROI for roi in rois: @@ -1558,10 +1578,23 @@ def make(self, key): ): entry[fig_name] = json.loads(fig.to_json()) - melted_df.drop(columns=["Val"], inplace=True) - entry["ethogram_data"] = melted_df.to_records(index=False) + # insert into InROI + in_roi_entries = [] + for subject_name, roi_name in itertools.product(set(melted_df.Subject), rois): + df_ = melted_df[(melted_df["Subject"] == subject_name) & (melted_df["Loc"] == roi_name)] + + roi_timestamps = df_["time"].values + roi_time = len(roi_timestamps) * 0.1 # 100ms per timestamp + + in_roi_entries.append( + {**key, "subject_name": subject_name, + "roi_name": roi_name, + "in_roi_time": roi_time, + "in_roi_timestamps": roi_timestamps} + ) self.insert1(entry) + self.InROI.insert(in_roi_entries) # ---- Foraging Bout Analysis ---- diff --git a/aeon/dj_pipeline/scripts/fix_anchor_part_fullpose.py b/aeon/dj_pipeline/scripts/fix_anchor_part_fullpose.py new file mode 100644 index 00000000..d7429ba1 --- /dev/null +++ b/aeon/dj_pipeline/scripts/fix_anchor_part_fullpose.py @@ -0,0 +1,64 @@ +""" +Script to fix the anchor part of the fullpose SLEAP entries. +See this commit: https://github.com/SainsburyWellcomeCentre/aeon_mecha/commit/8358ce4b6923918920efb77d09adc769721dbb9b + +Last run: --- +""" +import pandas as pd +from tqdm import tqdm +from aeon.dj_pipeline import acquisition, tracking, streams + +aeon_schemas = acquisition.aeon_schemas +logger = acquisition.logger +io_api = acquisition.io_api + + +def update_anchor_part(key): + chunk_start, chunk_end = (acquisition.Chunk & key).fetch1("chunk_start", "chunk_end") + + data_dirs = acquisition.Experiment.get_data_directories(key) + + device_name = (streams.SpinnakerVideoSource & key).fetch1("spinnaker_video_source_name") + + devices_schema = getattr( + aeon_schemas, + (acquisition.Experiment.DevicesSchema & {"experiment_name": key["experiment_name"]}).fetch1( + "devices_schema_name" + ), + ) + + stream_reader = getattr(getattr(devices_schema, device_name), "Pose") + + # special ingestion case for social0.2 full-pose data (using Pose reader from social03) + # fullpose for social0.2 has a different "pattern" for non-fullpose, hence the Pose03 reader + if key["experiment_name"].startswith("social0.2"): + from aeon.io import reader as io_reader + stream_reader = getattr(getattr(devices_schema, device_name), "Pose03") + assert isinstance(stream_reader, io_reader.Pose), "Pose03 is not a Pose reader" + data_dirs = [acquisition.Experiment.get_data_directory(key, "processed")] + + pose_data = io_api.load( + root=data_dirs, + reader=stream_reader, + start=pd.Timestamp(chunk_start), + end=pd.Timestamp(chunk_end), + ) + + if not len(pose_data): + raise ValueError(f"No SLEAP data found for {key['experiment_name']} - {device_name}") + + # get anchor part + anchor_part = next(v.replace("_x", "") for v in stream_reader.columns if v.endswith("_x")) + + # update anchor part + for entry in tracking.SLEAPTracking.PoseIdentity.fetch("KEY"): + entry["anchor_part"] = anchor_part + tracking.SLEAPTracking.PoseIdentity.update1(entry) + + logger.info(f"Anchor part updated to {anchor_part} for {key}") + + +def main(): + keys = tracking.SLEAPTracking.fetch("KEY") + for key in tqdm(keys): + update_anchor_part(key) diff --git a/aeon/dj_pipeline/subject.py b/aeon/dj_pipeline/subject.py index 3ff95770..52a524a4 100644 --- a/aeon/dj_pipeline/subject.py +++ b/aeon/dj_pipeline/subject.py @@ -372,6 +372,9 @@ def make(self, key): "lab_id": animal_resp["labid"], } ) + + associate_subject_and_experiment(eartag_or_id) + completion_time = datetime.now(UTC) self.insert1( { @@ -499,3 +502,28 @@ def get_pyrat_data(endpoint: str, params: dict = None, **kwargs): ) return response.json() + + +def associate_subject_and_experiment(subject_name): + """ + Check SubjectComment for experiment name for which the animal is participating in. + The expected comment format is "experiment: ". + E.g. "experiment: social0.3-aeon3" + Note: this function many need to run repeatedly to catch all experiments/animals. + - an experiment is not yet added when the animal comment is added + - the animal comment is not yet added when the experiment is created + """ + from aeon.dj_pipeline import acquisition + + new_entries = [] + for entry in (SubjectComment.proj("content") + & {"subject": subject_name} + & "content LIKE 'experiment:%'").fetch(as_dict=True): + entry.pop("comment_id") + entry["experiment_name"] = entry.pop("content").replace("experiment:", "").strip() + if acquisition.Experiment.proj() & entry: + if not acquisition.Experiment.Subject & entry: + new_entries.append(entry) + logger.info(f"\tNew experiment subject: {entry}") + + acquisition.Experiment.Subject.insert(new_entries) diff --git a/aeon/dj_pipeline/webapps/sciviz/specsheet.yaml b/aeon/dj_pipeline/webapps/sciviz/specsheet.yaml index 7c8793ec..20052ca9 100644 --- a/aeon/dj_pipeline/webapps/sciviz/specsheet.yaml +++ b/aeon/dj_pipeline/webapps/sciviz/specsheet.yaml @@ -669,8 +669,14 @@ SciViz: return dict(**kwargs) dj_query: > def dj_query(aeon_block_analysis): - aeon_analysis = aeon_block_analysis - query = aeon_analysis.Block * aeon_analysis.BlockAnalysis + block_analysis = aeon_block_analysis + query = block_analysis.Block.proj() * block_analysis.BlockAnalysis + query *= block_analysis.BlockAnalysis.aggr( + block_analysis.BlockAnalysis.Subject, subjects="GROUP_CONCAT(subject_name)", keep_all_rows=True) + query *= block_analysis.BlockAnalysis.aggr( + block_analysis.BlockAnalysis.Patch.proj(patch_rate="CONCAT(patch_name, ':', patch_rate, '(', patch_offset, ')')"), patch_rates="GROUP_CONCAT(patch_rate)", keep_all_rows=True) + query *= block_analysis.BlockAnalysis.aggr( + block_analysis.BlockAnalysis.Patch.proj(patch_pellet="CONCAT(patch_name, ':', pellet_count)"), patch_pellets="GROUP_CONCAT(patch_pellet)", keep_all_rows=True) return dict(query=query, fetch_args=[]) comp2: route: /per_block_patch_stats_plot