Skip to content

Commit

Permalink
Remove references to the motion correction drift table (#195)
Browse files Browse the repository at this point in the history
  • Loading branch information
stephen-riggs authored Oct 20, 2023
1 parent 179ba33 commit 1ea53c3
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 56 deletions.
3 changes: 1 addition & 2 deletions src/relion/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,6 @@ def __init__(
raise ValueError(f"path {self.basepath} is not a directory")
self._data_pipeline = Graph("DataPipeline", [])
self._db_model = DBModel(database)
self._drift_cache = {}
if run_options is None:
self.run_options = RelionItOptions()
else:
Expand Down Expand Up @@ -159,7 +158,7 @@ def motioncorrection(self):
"""access the motion correction stage of the project.
Returns a dictionary-like object with job names as keys,
and lists of MCMicrograph namedtuples as values."""
return MotionCorr(self.basepath / "MotionCorr", self._drift_cache)
return MotionCorr(self.basepath / "MotionCorr")

@property
@functools.lru_cache(maxsize=1)
Expand Down
74 changes: 26 additions & 48 deletions src/relion/_parser/motioncorrection.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@

import logging
from collections import namedtuple
from pathlib import Path

import numpy as np
import plotly.express as px

from relion._parser.jobtype import JobType

Expand All @@ -16,8 +20,9 @@
"total_motion",
"early_motion",
"late_motion",
"average_motion_per_frame",
"micrograph_timestamp",
"drift_data",
"drift_plot_full_path",
],
)

Expand Down Expand Up @@ -57,12 +62,8 @@


class MotionCorr(JobType):
def __init__(self, path, drift_cache=None):
def __init__(self, path):
super().__init__(path)
if drift_cache is None:
self._drift_cache = {}
else:
self._drift_cache = drift_cache

def __eq__(self, other):
if isinstance(other, MotionCorr): # check this
Expand Down Expand Up @@ -102,7 +103,11 @@ def _load_job_directory(self, jobdir):

micrograph_list = []
for j in range(len(micrograph_name)):
drift_data, movie_name = self.collect_drift_data(micrograph_name[j], jobdir)
(
number_of_frames,
drift_plot_full_path,
movie_name,
) = self.collect_drift_data(micrograph_name[j], jobdir)
if movie_name:
try:
movie_creation_time = (
Expand All @@ -125,55 +130,31 @@ def _load_job_directory(self, jobdir):
accum_motion_total[j],
accum_motion_early[j],
accum_motion_late[j],
float(accum_motion_total[j]) / number_of_frames,
movie_creation_time,
drift_data,
drift_plot_full_path,
)
)
return micrograph_list

def collect_drift_data(self, mic_name, jobdir):
drift_data = []
drift_star_file_path = mic_name.split(jobdir + "/")[-1].replace("mrc", "star")
if self._drift_cache.get(jobdir):
if self._drift_cache[jobdir].get(mic_name):
try:
if (
self._drift_cache[jobdir][mic_name].file_size
== (self._basepath / jobdir / drift_star_file_path)
.stat()
.st_size
):
return (
self._drift_cache[jobdir][mic_name].data,
self._drift_cache[jobdir][mic_name].movie_name,
)
except FileNotFoundError:
logger.debug(
"Could not find expected file containing drift data",
exc_info=True,
)
return [], ""
else:
self._drift_cache[jobdir] = {}
try:
drift_star_file = self._read_star_file(jobdir, drift_star_file_path)
except (FileNotFoundError, OSError, RuntimeError, ValueError):
return drift_data, ""
return 1, "", ""
try:
info_table = self._find_table_from_column_name(
"_rlnMicrographFrameNumber", drift_star_file
)
movie_table = 0
except (FileNotFoundError, OSError, RuntimeError, ValueError):
return drift_data, ""
return 1, "", ""
if info_table is None:
logger.debug(
f"_rlnMicrographFrameNumber or _rlnMicrographMovieName not found in file {drift_star_file}"
)
return drift_data, ""
frame_numbers = self.parse_star_file(
"_rlnMicrographFrameNumber", drift_star_file, info_table
)
return 1, "", ""
deltaxs = self.parse_star_file(
"_rlnMicrographShiftX", drift_star_file, info_table
)
Expand All @@ -183,17 +164,16 @@ def collect_drift_data(self, mic_name, jobdir):
movie_name = self.parse_star_file_pair(
"_rlnMicrographMovieName", drift_star_file, movie_table
)
for f, dx, dy in zip(frame_numbers, deltaxs, deltays):
drift_data.append(MCMicrographDrift(int(f), float(dx), float(dy)))
try:
self._drift_cache[jobdir][mic_name] = MCDriftCacheRecord(
drift_data,
(self._basepath / jobdir / drift_star_file_path).stat().st_size,
movie_name,
fig = px.scatter(
x=np.array(deltaxs, dtype=float), y=np.array(deltays, dtype=float)
)
drift_plot_name = Path(mic_name).stem + "_drift_plot.json"
drift_plot_full_path = Path(mic_name).parent / drift_plot_name
fig.write_json(drift_plot_full_path)
except FileNotFoundError:
return [], ""
return drift_data, movie_name
return 1, "", ""
return len(deltaxs), drift_plot_full_path, movie_name

@staticmethod
def for_cache(mcmicrograph):
Expand All @@ -220,12 +200,10 @@ def db_unpack(micrograph_list):
"total_motion": micrograph.total_motion,
"early_motion": micrograph.early_motion,
"late_motion": micrograph.late_motion,
"average_motion_per_frame": (
float(micrograph.total_motion) / len(micrograph.drift_data)
),
"average_motion_per_frame": micrograph.average_motion_per_frame,
"image_number": micrograph.micrograph_number,
"micrograph_snapshot_full_path": micrograph.micrograph_snapshot_full_path,
"drift_data": micrograph.drift_data,
"drift_plot_full_path": micrograph.drift_plot_full_path,
"created_time_stamp": micrograph.micrograph_timestamp,
}
for micrograph in micrograph_list
Expand Down
1 change: 0 additions & 1 deletion src/relion/dbmodel/modeltables.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,6 @@ class MotionCorrectionTable(Table):
def __init__(self):
columns, prim_key = parse_sqlalchemy_table(sqlalchemy.MotionCorrection)
columns.append("created_time_stamp")
columns.append("drift_data")
# columns.append("job_string")
super().__init__(
columns,
Expand Down
6 changes: 1 addition & 5 deletions src/relion/zocalo/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -579,18 +579,15 @@ def _(
unsent_appended: Optional[dict] = None,
):
row = table.get_row_by_primary_key(primary_key)
drift_data = row["drift_data"]
buffered = ["motion_correction_id", "drift_data"]
buffered = ["motion_correction_id"]
buffer_store = row["motion_correction_id"]
drift_frames = [(frame.frame, frame.deltaX, frame.deltaY) for frame in drift_data]
if resend:
results = {
"ispyb_command": "buffer",
"buffer_lookup": {"motion_correction_id": buffer_store},
"buffer_command": {
"ispyb_command": "insert_motion_correction",
**{k: v for k, v in row.items() if k not in buffered},
"drift_frames": drift_frames,
},
}
else:
Expand All @@ -599,7 +596,6 @@ def _(
"buffer_command": {
"ispyb_command": "insert_motion_correction",
**{k: v for k, v in row.items() if k not in buffered},
"drift_frames": drift_frames,
},
"buffer_store": buffer_store,
}
Expand Down

0 comments on commit 1ea53c3

Please sign in to comment.