From 2cff62babd7816c14fef9246152b53a2bb59d991 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 27 Nov 2024 14:36:53 +0100 Subject: [PATCH 1/2] plot drift with the scatter plot --- .../benchmark/benchmark_motion_estimation.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/benchmark/benchmark_motion_estimation.py b/src/spikeinterface/benchmark/benchmark_motion_estimation.py index abb2a51bae..3a7d11fc35 100644 --- a/src/spikeinterface/benchmark/benchmark_motion_estimation.py +++ b/src/spikeinterface/benchmark/benchmark_motion_estimation.py @@ -109,6 +109,9 @@ def run(self, **job_kwargs): estimate_motion=t4 - t3, ) + + self.result["peaks"] = peaks + self.result["peak_locations"] = peak_locations self.result["step_run_times"] = step_run_times self.result["raw_motion"] = motion @@ -131,6 +134,8 @@ def compute_result(self, **result_params): self.result["motion"] = motion _run_key_saved = [ + ("peaks", "npy"), + ("peak_locations", "npy"), ("raw_motion", "Motion"), ("step_run_times", "pickle"), ] @@ -161,7 +166,7 @@ def create_benchmark(self, key): def plot_true_drift(self, case_keys=None, scaling_probe=1.5, figsize=(8, 6)): self.plot_drift(case_keys=case_keys, tested_drift=False, scaling_probe=scaling_probe, figsize=figsize) - def plot_drift(self, case_keys=None, gt_drift=True, tested_drift=True, scaling_probe=1.0, figsize=(8, 6)): + def plot_drift(self, case_keys=None, gt_drift=True, tested_drift=True, raster=False, scaling_probe=1.0, figsize=(8, 6)): import matplotlib.pyplot as plt if case_keys is None: @@ -195,6 +200,13 @@ def plot_drift(self, case_keys=None, gt_drift=True, tested_drift=True, scaling_p # for i in range(self.gt_unit_positions.shape[1]): # ax.plot(temporal_bins_s, self.gt_unit_positions[:, i], alpha=0.5, ls="--", c="0.5") + if raster: + peaks = bench.result["peaks"] + peak_locations = bench.result["peak_locations"] + rec = bench.recording + x = peaks["sample_index"] / rec.sampling_frequency + y = peak_locations[bench.direction] + ax.scatter(x, y, alpha=.2, s=2, c=np.abs(peaks["amplitude"]), cmap="inferno") for i in range(gt_motion.displacement[0].shape[1]): depth = motion.spatial_bins_um[i] From 22882ef66a8389fdfd7aac30ea4633151f1cdd16 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 27 Nov 2024 13:38:09 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../benchmark/benchmark_motion_estimation.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/benchmark/benchmark_motion_estimation.py b/src/spikeinterface/benchmark/benchmark_motion_estimation.py index 3a7d11fc35..5a3c490d38 100644 --- a/src/spikeinterface/benchmark/benchmark_motion_estimation.py +++ b/src/spikeinterface/benchmark/benchmark_motion_estimation.py @@ -109,7 +109,6 @@ def run(self, **job_kwargs): estimate_motion=t4 - t3, ) - self.result["peaks"] = peaks self.result["peak_locations"] = peak_locations self.result["step_run_times"] = step_run_times @@ -166,7 +165,9 @@ def create_benchmark(self, key): def plot_true_drift(self, case_keys=None, scaling_probe=1.5, figsize=(8, 6)): self.plot_drift(case_keys=case_keys, tested_drift=False, scaling_probe=scaling_probe, figsize=figsize) - def plot_drift(self, case_keys=None, gt_drift=True, tested_drift=True, raster=False, scaling_probe=1.0, figsize=(8, 6)): + def plot_drift( + self, case_keys=None, gt_drift=True, tested_drift=True, raster=False, scaling_probe=1.0, figsize=(8, 6) + ): import matplotlib.pyplot as plt if case_keys is None: @@ -206,7 +207,7 @@ def plot_drift(self, case_keys=None, gt_drift=True, tested_drift=True, raster=Fa rec = bench.recording x = peaks["sample_index"] / rec.sampling_frequency y = peak_locations[bench.direction] - ax.scatter(x, y, alpha=.2, s=2, c=np.abs(peaks["amplitude"]), cmap="inferno") + ax.scatter(x, y, alpha=0.2, s=2, c=np.abs(peaks["amplitude"]), cmap="inferno") for i in range(gt_motion.displacement[0].shape[1]): depth = motion.spatial_bins_um[i]