Skip to content

Commit

Permalink
Merge branch 'main' of github.com:MouseLand/Kilosort
Browse files Browse the repository at this point in the history
  • Loading branch information
jacobpennington committed Dec 17, 2024
2 parents b73f0fa + a5b43f0 commit 00f0ca8
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 13 deletions.
11 changes: 9 additions & 2 deletions docs/tutorials/plotting_example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,13 @@
"# Example plots using kilosort.data_tools"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"##### Note that `kilosort.data_tools` was added in `v4.0.21`, so you will need to update Kilosort4 to at least that version to use these examples. This can be done using `pip install kilosort --upgrade`."
]
},
{
"cell_type": "code",
"execution_count": 5,
Expand Down Expand Up @@ -42,7 +49,7 @@
"from kilosort.io import load_ops\n",
"from kilosort.data_tools import (\n",
" mean_waveform, cluster_templates, get_good_cluster, get_cluster_spikes,\n",
" get_spike_waveforms, get_best_channel\n",
" get_spike_waveforms, get_best_channels\n",
" )\n",
"\n",
"\n",
Expand Down Expand Up @@ -101,7 +108,7 @@
"# Time in s for spike time axis\n",
"t2 = spike_times / ops['fs']\n",
"# Get single-channel waveform for each spike\n",
"chan = get_best_channel(cluster_id, results_dir)\n",
"chan = get_best_channels(results_dir)[cluster_id]\n",
"waves = get_spike_waveforms(spike_times, results_dir, chan=chan)\n",
"\n",
"# Plot each waveform, using spike time as 3rd dimension\n",
Expand Down
12 changes: 7 additions & 5 deletions kilosort/data_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def mean_waveform(cluster_id, results_dir, n_spikes=np.inf, bfile=None, best=Tru
"""
results_dir = Path(results_dir)
if best:
chan = get_best_channel(cluster_id, results_dir)
chan = get_best_channels(results_dir)[cluster_id]
else:
chan = None

Expand All @@ -45,12 +45,14 @@ def mean_waveform(cluster_id, results_dir, n_spikes=np.inf, bfile=None, best=Tru
return mean_wave


def get_best_channel(cluster_id, results_dir):
"""Get channel number with largest template norm for this cluster."""
def get_best_channels(results_dir):
"""Get channel numbers with largest template norm for each cluster."""
templates = np.load(results_dir / 'templates.npy')
chan = (templates**2).sum(axis=1).argmax(axis=-1)[cluster_id]
return chan
best_chans = (templates**2).sum(axis=1).argmax(axis=-1)
return best_chans

def get_best_channel(results_dir, cluster_id):
return get_best_channels(results_dir)[cluster_id]

def get_cluster_spikes(cluster_id, results_dir, n_spikes=np.inf):
"""Get `n_spikes` random spike times assigned to `cluster_id`."""
Expand Down
6 changes: 4 additions & 2 deletions kilosort/gui/probe_view_box.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def update_spots_variables(self, probe, template_args):
for ind, (xc, yc) in enumerate(zip(self.xc, self.yc)):
self.channel_map_dict[(xc, yc)] = ind

def get_template_spots(self, nC, dmin, dminx, max_dist, x_centers, device):
def get_template_spots(self, nC, dmin, dminx, max_dist, x_centers):
ops = {
'yc': self.yc, 'xc': self.xc, 'max_channel_distance': max_dist,
'x_centers': x_centers, 'settings': {'dmin': dmin, 'dminx': dminx},
Expand All @@ -131,7 +131,9 @@ def get_template_spots(self, nC, dmin, dminx, max_dist, x_centers, device):
ops = template_centers(ops)
[ys, xs] = np.meshgrid(ops['yup'], ops['xup'])
ys, xs = ys.flatten(), xs.flatten()
iC, ds = nearest_chans(ys, self.yc, xs, self.xc, nC, device=device)
iC, ds = nearest_chans(
ys, self.yc, xs, self.xc, nC, device=self.gui.device
)

igood = ds[0,:] <= max_dist**2
iC = iC[:,igood]
Expand Down
12 changes: 8 additions & 4 deletions kilosort/gui/settings_box.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@

_DEFAULT_DTYPE = 'int16'
_ALLOWED_FILE_TYPES = ['.bin', '.dat', '.bat', '.raw'] # For binary data
_PROBE_SETTINGS = [
'nearest_chans', 'dmin', 'dminx', 'max_channel_distance', 'x_centers'
]

class SettingsBox(QtWidgets.QGroupBox):
settingsUpdated = QtCore.Signal()
Expand Down Expand Up @@ -247,6 +250,8 @@ def setup(self):
)
inp = getattr(self, f'{k}_input')
inp.editingFinished.connect(self.update_parameter)
if k in _PROBE_SETTINGS:
inp.editingFinished.connect(self.show_probe_layout())

row_count += rspan
layout.addWidget(
Expand Down Expand Up @@ -550,10 +555,7 @@ def update_settings(self):

def get_probe_template_args(self):
epw = self.extra_parameters_window
template_args = [
epw.nearest_chans, epw.dmin, epw.dminx,
epw.max_channel_distance, epw.x_centers, self.gui.device
]
template_args = [getattr(epw, k) for k in _PROBE_SETTINGS]
return template_args

@QtCore.Slot()
Expand Down Expand Up @@ -862,6 +864,8 @@ def __init__(self, parent):
layout.addWidget(getattr(self, f'{k}_input'), row_count, col+3, 1, 2)
inp = getattr(self, f'{k}_input')
inp.editingFinished.connect(self.update_parameter)
if k in _PROBE_SETTINGS:
inp.editingFinished.connect(self.main_settings.show_probe_layout)

self.setLayout(layout)

Expand Down

0 comments on commit 00f0ca8

Please sign in to comment.