Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Dec 19, 2024
1 parent 8e5ba68 commit 62407b1
Showing 1 changed file with 19 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -538,7 +538,7 @@ def _get_single_session_activity_histogram(
bin_s=None,
depth_smooth_um=None,
scale_to_hz=False,
weight_with_amplitude=weight_with_amplitude
weight_with_amplitude=weight_with_amplitude,
)

# It is important that the passed histogram is scaled to firing rate in Hz
Expand Down Expand Up @@ -848,7 +848,7 @@ def get_shifts(signal1, signal2, windows):
signa11_blanked[:first_idx] = 0
signal2_blanked[:first_idx] = 0

if (last_idx := windows[-1][-1]) != signal1.size - 1: #double check
if (last_idx := windows[-1][-1]) != signal1.size - 1: # double check
print("last idx", last_idx)
signa11_blanked[last_idx:] = 0
signal2_blanked[last_idx:] = 0
Expand All @@ -870,14 +870,17 @@ def get_shifts(signal1, signal2, windows):
# shift the signal1, or use indexing
signa11_blanked = alignment_utils.shift_array_fill_zeros(signa11_blanked, shift)

# plt.plot(signa11_blanked)
# plt.plot(signal2_blanked)
# plt.show()
# plt.plot(signa11_blanked)
# plt.plot(signal2_blanked)
# plt.show()

window_corrs = np.empty(windows.shape[0])
for i, idx in enumerate(windows):

window_corrs[i] = np.correlate(signa11_blanked[idx] - np.mean(signa11_blanked[idx]), signal2_blanked[idx] - np.mean(signal2_blanked[idx]))
window_corrs[i] = np.correlate(
signa11_blanked[idx] - np.mean(signa11_blanked[idx]),
signal2_blanked[idx] - np.mean(signal2_blanked[idx]),
)

max_window = np.argmax(window_corrs)

Expand Down Expand Up @@ -962,25 +965,26 @@ def _compute_session_alignment(
)
shifted_histograms[ses_idx, :] = shifted_histogram


nonrigid_session_offsets_matrix = np.empty((shifted_histograms.shape[0], shifted_histograms.shape[0]))

windows = []
for i in range(non_rigid_windows.shape[0]):
idxs = np.arange(non_rigid_windows.shape[1])[non_rigid_windows[i, :].astype(bool)]
idxs = np.arange(non_rigid_windows.shape[1])[non_rigid_windows[i, :].astype(bool)]
windows.append(idxs)
# TODO: check assumptions these are always the same size

windows = np.vstack(windows)

# import matplotlib.pyplot as plt
# plt.plot(non_rigid_windows.T)
# plt.show()
# import matplotlib.pyplot as plt
# plt.plot(non_rigid_windows.T)
# plt.show()

windows1 = windows[::2, :]
windows2 = windows[1::2, :]

nonrigid_session_offsets_matrix = np.empty((shifted_histograms.shape[0], shifted_histograms.shape[0], non_rigid_windows.shape[0]))
nonrigid_session_offsets_matrix = np.empty(
(shifted_histograms.shape[0], shifted_histograms.shape[0], non_rigid_windows.shape[0])
)

for i in range(shifted_histograms.shape[0]):
for j in range(shifted_histograms.shape[0]):
Expand All @@ -990,13 +994,13 @@ def _compute_session_alignment(
shifts = np.empty(shifts1.size + shifts2.size)
# breakpoint()
shifts[::2] = shifts1
shifts[1::2] = (shifts1[:-1] + shifts1[1:]) / 2# np.shifts2
# breakpoint()
shifts[1::2] = (shifts1[:-1] + shifts1[1:]) / 2 # np.shifts2
# breakpoint()
nonrigid_session_offsets_matrix[i, j, :] = shifts

# TODO: there are gaps in between rect, rect seems weird, they are non-overlapping :S

# breakpoint()
# breakpoint()
# Then compute the nonrigid shifts
# nonrigid_session_offsets_matrix = alignment_utils.compute_histogram_crosscorrelation(
# shifted_histograms, non_rigid_windows, **compute_alignment_kwargs
Expand Down

0 comments on commit 62407b1

Please sign in to comment.