Use strings as ids in generators #3588

Open · wants to merge 12 commits into main
2 changes: 1 addition & 1 deletion src/spikeinterface/core/basesorting.py
@@ -135,7 +135,7 @@ def get_total_duration(self) -> float:

def get_unit_spike_train(
self,
- unit_id,
+ unit_id: str | int,
segment_index: Union[int, None] = None,
start_frame: Union[int, None] = None,
end_frame: Union[int, None] = None,
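For context, a minimal sketch of the widened annotation in use, assuming a sorting produced by the generators on this branch (where unit ids are strings); `generate_sorting` and `get_unit_spike_train` are existing spikeinterface APIs, the parameter values are illustrative:

```python
from spikeinterface.core import generate_sorting

# On this branch, generated unit ids are strings ("0", "1", ...).
sorting = generate_sorting(num_units=2, durations=[1.0], sampling_frequency=30000.0, seed=0)

# The id is passed as declared in the new annotation: str | int.
spike_train = sorting.get_unit_spike_train(unit_id="0", segment_index=0)
```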
13 changes: 8 additions & 5 deletions src/spikeinterface/core/generate.py
@@ -2,7 +2,7 @@
import math
import warnings
import numpy as np
- from typing import Literal
+ from typing import Literal, Optional
from math import ceil

from .basesorting import SpikeVectorSortingSegment
@@ -134,7 +134,7 @@ def generate_sorting(
seed = _ensure_seed(seed)
rng = np.random.default_rng(seed)
num_segments = len(durations)
- unit_ids = np.arange(num_units)
+ unit_ids = [str(id) for id in np.arange(num_units)]

spikes = []
for segment_index in range(num_segments):
@@ -1111,7 +1111,7 @@ def __init__(

"""

- unit_ids = np.arange(num_units)
+ unit_ids = [str(id) for id in np.arange(num_units)]
super().__init__(sampling_frequency, unit_ids)

self.num_units = num_units
@@ -1138,6 +1138,7 @@ def __init__(
firing_rates=firing_rates,
refractory_period_seconds=self.refractory_period_seconds,
seed=segment_seed,
+ unit_ids=unit_ids,
t_start=None,
)
self.add_sorting_segment(segment)
@@ -1161,6 +1162,7 @@ def __init__(
firing_rates: float | np.ndarray,
refractory_period_seconds: float | np.ndarray,
seed: int,
+ unit_ids: list[str],
t_start: Optional[float] = None,
):
self.num_units = num_units
@@ -1177,7 +1179,8 @@ def __init__(
self.refractory_period_seconds = np.full(num_units, self.refractory_period_seconds, dtype="float64")

self.segment_seed = seed
- self.units_seed = {unit_id: self.segment_seed + hash(unit_id) for unit_id in range(num_units)}
+ self.units_seed = {unit_id: abs(self.segment_seed + hash(unit_id)) for unit_id in unit_ids}
+
self.num_samples = math.ceil(sampling_frequency * duration)
super().__init__(t_start)

@@ -1280,7 +1283,7 @@ def __init__(
noise_block_size: int = 30000,
):

- channel_ids = np.arange(num_channels)
+ channel_ids = [str(id) for id in np.arange(num_channels)]
dtype = np.dtype(dtype).name # Cast to string for serialization
if dtype not in ("float32", "float64"):
raise ValueError(f"'dtype' must be 'float32' or 'float64' but is {dtype}")
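For reference, a standalone sketch of the per-unit seeding scheme introduced above. Assumptions: `unit_ids` are the new string ids; `abs()` is needed because NumPy's `default_rng` rejects negative seeds. Note that Python's string `hash()` is salted per process unless `PYTHONHASHSEED` is fixed, so the derived seeds are reproducible within a run but not across runs:

```python
import numpy as np

segment_seed = 42
unit_ids = [str(id) for id in np.arange(3)]  # ["0", "1", "2"], as in generate_sorting

# One non-negative seed per unit, derived from the segment seed.
units_seed = {unit_id: abs(segment_seed + hash(unit_id)) for unit_id in unit_ids}

# Each unit can then draw from its own independent generator.
rngs = {unit_id: np.random.default_rng(seed) for unit_id, seed in units_seed.items()}
```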
10 changes: 5 additions & 5 deletions src/spikeinterface/core/tests/test_basesnippets.py
@@ -41,8 +41,8 @@ def test_BaseSnippets(create_cache_folder):
assert snippets.get_num_segments() == len(duration)
assert snippets.get_num_channels() == num_channels

- assert np.all(snippets.ids_to_indices([0, 1, 2]) == [0, 1, 2])
- assert np.all(snippets.ids_to_indices([0, 1, 2], prefer_slice=True) == slice(0, 3, None))
+ assert np.all(snippets.ids_to_indices(["0", "1", "2"]) == [0, 1, 2])
+ assert np.all(snippets.ids_to_indices(["0", "1", "2"], prefer_slice=True) == slice(0, 3, None))

# annotations / properties
snippets.annotate(gre="ta")
@@ -60,7 +60,7 @@ def test_BaseSnippets(create_cache_folder):
)

# missing property
snippets.set_property("string_property", ["ciao", "bello"], ids=[0, 1])
snippets.set_property("string_property", ["ciao", "bello"], ids=["0", "1"])
values = snippets.get_property("string_property")
assert values[2] == ""

@@ -70,14 +70,14 @@ def test_BaseSnippets(create_cache_folder):
snippets.set_property,
key="string_property_nan",
values=["hola", "chabon"],
- ids=[0, 1],
+ ids=["0", "1"],
missing_value=np.nan,
)

# int properties without missing values raise an error
assert_raises(Exception, snippets.set_property, key="int_property", values=[5, 6], ids=[1, 2])

snippets.set_property("int_property", [5, 6], ids=[1, 2], missing_value=200)
snippets.set_property("int_property", [5, 6], ids=["1", "2"], missing_value=200)
values = snippets.get_property("int_property")
assert values.dtype.kind == "i"

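A hedged sketch of the id convention these assertions now exercise: on this branch the generated channel ids are strings, while `ids_to_indices` still maps them to integer positions (`generate_recording` and `ids_to_indices` are existing spikeinterface APIs; values are illustrative):

```python
from spikeinterface.core import generate_recording

recording = generate_recording(num_channels=3, durations=[1.0])
print(recording.get_channel_ids())           # expected on this branch: ['0', '1', '2']
print(recording.ids_to_indices(["0", "2"]))  # integer positions: array([0, 2])
```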
@@ -38,10 +38,12 @@ def test_channelsaggregationrecording():

assert np.allclose(traces1_1, recording_agg.get_traces(channel_ids=[str(channel_ids[1])], segment_index=seg))
assert np.allclose(
- traces2_0, recording_agg.get_traces(channel_ids=[str(num_channels + channel_ids[0])], segment_index=seg)
+ traces2_0,
+ recording_agg.get_traces(channel_ids=[str(num_channels + int(channel_ids[0]))], segment_index=seg),
)
assert np.allclose(
- traces3_2, recording_agg.get_traces(channel_ids=[str(2 * num_channels + channel_ids[2])], segment_index=seg)
+ traces3_2,
+ recording_agg.get_traces(channel_ids=[str(2 * num_channels + int(channel_ids[2]))], segment_index=seg),
)
# all traces
traces1 = recording1.get_traces(segment_index=seg)
8 changes: 8 additions & 0 deletions src/spikeinterface/core/tests/test_sortinganalyzer.py
@@ -31,6 +31,14 @@ def get_dataset():
noise_kwargs=dict(noise_levels=5.0, strategy="tile_pregenerated"),
seed=2205,
)

+ # TODO: the tests or the sorting analyzer make assumptions about the ids being integers
+ # So keeping this the way it was
+ integer_channel_ids = [int(id) for id in recording.get_channel_ids()]
+ integer_unit_ids = [int(id) for id in sorting.get_unit_ids()]
+
+ recording = recording.rename_channels(new_channel_ids=integer_channel_ids)
+ sorting = sorting.rename_units(new_unit_ids=integer_unit_ids)
return recording, sorting


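This rename-back-to-integers workaround recurs in several fixtures below; a self-contained sketch of the pattern, assuming this branch (where generated ids are strings, so `int(id)` is well defined):

```python
from spikeinterface.core import generate_ground_truth_recording

recording, sorting = generate_ground_truth_recording(durations=[1.0], num_units=4, seed=2205)

# Cast the string ids back to integers for tests that still assume integer ids.
integer_channel_ids = [int(id) for id in recording.get_channel_ids()]
integer_unit_ids = [int(id) for id in sorting.get_unit_ids()]

recording = recording.rename_channels(new_channel_ids=integer_channel_ids)
sorting = sorting.rename_units(new_unit_ids=integer_unit_ids)
```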
22 changes: 13 additions & 9 deletions src/spikeinterface/core/tests/test_unitsselectionsorting.py
@@ -10,39 +10,43 @@
def test_basic_functions():
sorting = generate_sorting(num_units=3, durations=[0.100, 0.100], sampling_frequency=30000.0)

- sorting2 = UnitsSelectionSorting(sorting, unit_ids=[0, 2])
- assert np.array_equal(sorting2.unit_ids, [0, 2])
+ sorting2 = UnitsSelectionSorting(sorting, unit_ids=["0", "2"])
+ assert np.array_equal(sorting2.unit_ids, ["0", "2"])
assert sorting2.get_parent() == sorting

- sorting3 = UnitsSelectionSorting(sorting, unit_ids=[0, 2], renamed_unit_ids=["a", "b"])
+ sorting3 = UnitsSelectionSorting(sorting, unit_ids=["0", "2"], renamed_unit_ids=["a", "b"])
assert np.array_equal(sorting3.unit_ids, ["a", "b"])

assert np.array_equal(
- sorting.get_unit_spike_train(0, segment_index=0), sorting2.get_unit_spike_train(0, segment_index=0)
+ sorting.get_unit_spike_train(unit_id="0", segment_index=0),
+ sorting2.get_unit_spike_train(unit_id="0", segment_index=0),
)
assert np.array_equal(
- sorting.get_unit_spike_train(0, segment_index=0), sorting3.get_unit_spike_train("a", segment_index=0)
+ sorting.get_unit_spike_train(unit_id="0", segment_index=0),
+ sorting3.get_unit_spike_train(unit_id="a", segment_index=0),
)

assert np.array_equal(
- sorting.get_unit_spike_train(2, segment_index=0), sorting2.get_unit_spike_train(2, segment_index=0)
+ sorting.get_unit_spike_train(unit_id="2", segment_index=0),
+ sorting2.get_unit_spike_train(unit_id="2", segment_index=0),
)
assert np.array_equal(
- sorting.get_unit_spike_train(2, segment_index=0), sorting3.get_unit_spike_train("b", segment_index=0)
+ sorting.get_unit_spike_train(unit_id="2", segment_index=0),
+ sorting3.get_unit_spike_train(unit_id="b", segment_index=0),
)


def test_failure_with_non_unique_unit_ids():
seed = 10
sorting = generate_sorting(num_units=3, durations=[0.100], sampling_frequency=30000.0, seed=seed)
with pytest.raises(AssertionError):
- sorting2 = UnitsSelectionSorting(sorting, unit_ids=[0, 2], renamed_unit_ids=["a", "a"])
+ sorting2 = UnitsSelectionSorting(sorting, unit_ids=["0", "2"], renamed_unit_ids=["a", "a"])


def test_custom_cache_spike_vector():
sorting = generate_sorting(num_units=3, durations=[0.100, 0.100], sampling_frequency=30000.0)

- sub_sorting = UnitsSelectionSorting(sorting, unit_ids=[2, 0], renamed_unit_ids=["b", "a"])
+ sub_sorting = UnitsSelectionSorting(sorting, unit_ids=["2", "0"], renamed_unit_ids=["b", "a"])
cached_spike_vector = sub_sorting.to_spike_vector(use_cache=True)
computed_spike_vector = sub_sorting.to_spike_vector(use_cache=False)
assert np.all(cached_spike_vector == computed_spike_vector)
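For context, a minimal sketch of selecting and renaming units under the new string ids, using the public `select_units` API (which returns a `UnitsSelectionSorting`, so it is equivalent to the constructor calls above); values are illustrative:

```python
from spikeinterface.core import generate_sorting

sorting = generate_sorting(num_units=3, durations=[0.100], sampling_frequency=30000.0)

# Select units "0" and "2" and rename them "a" and "b".
sub_sorting = sorting.select_units(unit_ids=["0", "2"], renamed_unit_ids=["a", "b"])
print(sub_sorting.unit_ids)  # ['a' 'b']

# Spike trains are unchanged, only the ids differ.
assert (sorting.get_unit_spike_train(unit_id="0", segment_index=0)
        == sub_sorting.get_unit_spike_train(unit_id="a", segment_index=0)).all()
```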
5 changes: 5 additions & 0 deletions src/spikeinterface/curation/tests/common.py
@@ -19,6 +19,11 @@ def make_sorting_analyzer(sparse=True):
seed=2205,
)

+ channel_ids_as_integers = [id for id in range(recording.get_num_channels())]
+ unit_ids_as_integers = [id for id in range(sorting.get_num_units())]
+ recording = recording.rename_channels(new_channel_ids=channel_ids_as_integers)
+ sorting = sorting.rename_units(new_unit_ids=unit_ids_as_integers)

sorting_analyzer = create_sorting_analyzer(sorting=sorting, recording=recording, format="memory", sparse=sparse)
sorting_analyzer.compute("random_spikes")
sorting_analyzer.compute("waveforms", **job_kwargs)
@@ -49,6 +49,9 @@ def test_gh_curation():
Test curation using GitHub URI.
"""
sorting = generate_sorting(num_units=10)
+ unit_ids_as_int = [id for id in range(sorting.get_num_units())]
+ sorting = sorting.rename_units(new_unit_ids=unit_ids_as_int)

# curated link:
# https://figurl.org/f?v=npm://@fi-sci/figurl-sortingview@12/dist&d=sha1://058ab901610aa9d29df565595a3cc2a81a1b08e5
gh_uri = "gh://SpikeInterface/spikeinterface/main/src/spikeinterface/curation/tests/sv-sorting-curation.json"
@@ -76,6 +79,8 @@ def test_sha1_curation():
Test curation using SHA1 URI.
"""
sorting = generate_sorting(num_units=10)
+ unit_ids_as_int = [id for id in range(sorting.get_num_units())]
+ sorting = sorting.rename_units(new_unit_ids=unit_ids_as_int)

# from SHA1
# curated link:
@@ -105,6 +110,8 @@ def test_json_curation():
Test curation using a JSON file.
"""
sorting = generate_sorting(num_units=10)
+ unit_ids_as_int = [id for id in range(sorting.get_num_units())]
+ sorting = sorting.rename_units(new_unit_ids=unit_ids_as_int)

# from curation.json
json_file = parent_folder / "sv-sorting-curation.json"
@@ -248,6 +255,8 @@ def test_json_no_merge_curation():
Test curation with no merges using a JSON file.
"""
sorting = generate_sorting(num_units=10)
+ unit_ids_as_int = [id for id in range(sorting.get_num_units())]
+ sorting = sorting.rename_units(new_unit_ids=unit_ids_as_int)

json_file = parent_folder / "sv-sorting-curation-no-merge.json"
sorting_curated = apply_sortingview_curation(sorting, uri_or_json=json_file)
6 changes: 6 additions & 0 deletions src/spikeinterface/extractors/tests/test_mdaextractors.py
@@ -9,6 +9,12 @@ def test_mda_extractors(create_cache_folder):
cache_folder = create_cache_folder
rec, sort = generate_ground_truth_recording(durations=[10.0], num_units=10)

+ ids_as_integers = [id for id in range(rec.get_num_channels())]
+ rec = rec.rename_channels(new_channel_ids=ids_as_integers)
+
+ ids_as_integers = [id for id in range(sort.get_num_units())]
+ sort = sort.rename_units(new_unit_ids=ids_as_integers)

MdaRecordingExtractor.write_recording(rec, cache_folder / "mdatest")
rec_mda = MdaRecordingExtractor(cache_folder / "mdatest")
probe = rec_mda.get_probe()
@@ -23,6 +23,11 @@ def get_dataset():
seed=2205,
)

+ channel_ids_as_integers = [id for id in range(recording.get_num_channels())]
+ unit_ids_as_integers = [id for id in range(sorting.get_num_units())]
+ recording = recording.rename_channels(new_channel_ids=channel_ids_as_integers)
+ sorting = sorting.rename_units(new_unit_ids=unit_ids_as_integers)

# since templates are going to be averaged and this might be a problem for amplitude scaling
# we select the 3 units with the largest templates to split
analyzer_raw = create_sorting_analyzer(sorting, recording, format="memory", sparse=False)
8 changes: 4 additions & 4 deletions src/spikeinterface/preprocessing/tests/test_clip.py
@@ -14,12 +14,12 @@ def test_clip():
rec1 = clip(rec, a_min=-1.5)
rec1.save(verbose=False)

- traces0 = rec0.get_traces(segment_index=0, channel_ids=[1])
+ traces0 = rec0.get_traces(segment_index=0, channel_ids=["1"])
assert traces0.shape[1] == 1

assert np.all(-2 <= traces0[0] <= 3)

- traces1 = rec1.get_traces(segment_index=0, channel_ids=[0, 1])
+ traces1 = rec1.get_traces(segment_index=0, channel_ids=["0", "1"])
assert traces1.shape[1] == 2

assert np.all(-1.5 <= traces1[1])
@@ -34,11 +34,11 @@ def test_blank_staturation():
rec1 = blank_staturation(rec, quantile_threshold=0.01, direction="both", chunk_size=10000)
rec1.save(verbose=False)

- traces0 = rec0.get_traces(segment_index=0, channel_ids=[1])
+ traces0 = rec0.get_traces(segment_index=0, channel_ids=["1"])
assert traces0.shape[1] == 1
assert np.all(traces0 < 3.0)

- traces1 = rec1.get_traces(segment_index=0, channel_ids=[0])
+ traces1 = rec1.get_traces(segment_index=0, channel_ids=["0"])
assert traces1.shape[1] == 1
# use a smaller value to be sure
a_min = rec1._recording_segments[0].a_min
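A hedged usage sketch matching the updated assertions, assuming this branch's string channel ids (`clip` is the existing preprocessing function; the bounds here are illustrative):

```python
from spikeinterface.core import generate_recording
from spikeinterface.preprocessing import clip

rec = generate_recording(num_channels=2, durations=[1.0])
rec_clipped = clip(rec, a_min=-2, a_max=3)

# Channels are now selected by their string ids.
traces = rec_clipped.get_traces(segment_index=0, channel_ids=["1"])
assert traces.shape[1] == 1
assert (traces >= -2).all() and (traces <= 3).all()
```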
@@ -163,7 +163,9 @@ def test_output_values():
expected_weights = np.r_[np.tile(np.exp(-2), 3), np.exp(-4)]
expected_weights /= np.sum(expected_weights)

- si_interpolated_recording = spre.interpolate_bad_channels(recording, bad_channel_indexes, sigma_um=1, p=1)
+ si_interpolated_recording = spre.interpolate_bad_channels(
+     recording, bad_channel_ids=bad_channel_ids, sigma_um=1, p=1
+ )
si_interpolated = si_interpolated_recording.get_traces()

expected_ts = si_interpolated[:, 1:] @ expected_weights
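The change above swaps positional channel indexes for the `bad_channel_ids` keyword. A sketch of the call under that assumption; the ids and parameter values are illustrative, and `generate_recording` attaches a dummy probe by default, which `interpolate_bad_channels` needs for channel locations:

```python
import spikeinterface.preprocessing as spre
from spikeinterface.core import generate_recording

recording = generate_recording(num_channels=4, durations=[1.0])
bad_channel_ids = ["1"]  # hypothetical: channels flagged as bad, referenced by id

interpolated = spre.interpolate_bad_channels(recording, bad_channel_ids=bad_channel_ids, sigma_um=1, p=1)
traces = interpolated.get_traces(segment_index=0)
```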
@@ -15,7 +15,7 @@ def test_normalize_by_quantile():
rec2 = normalize_by_quantile(rec, mode="by_channel")
rec2.save(verbose=False)

- traces = rec2.get_traces(segment_index=0, channel_ids=[1])
+ traces = rec2.get_traces(segment_index=0, channel_ids=["1"])
assert traces.shape[1] == 1

rec2 = normalize_by_quantile(rec, mode="pool_channel")
2 changes: 1 addition & 1 deletion src/spikeinterface/preprocessing/tests/test_rectify.py
@@ -15,7 +15,7 @@ def test_rectify():
rec2 = rectify(rec)
rec2.save(verbose=False)

- traces = rec2.get_traces(segment_index=0, channel_ids=[1])
+ traces = rec2.get_traces(segment_index=0, channel_ids=["1"])
assert traces.shape[1] == 1

# import matplotlib.pyplot as plt
10 changes: 10 additions & 0 deletions src/spikeinterface/qualitymetrics/tests/conftest.py
@@ -16,6 +16,11 @@ def small_sorting_analyzer():
seed=1205,
)

+ channel_ids_as_integers = [id for id in range(recording.get_num_channels())]
+ unit_ids_as_integers = [id for id in range(sorting.get_num_units())]
+ recording = recording.rename_channels(new_channel_ids=channel_ids_as_integers)
+ sorting = sorting.rename_units(new_unit_ids=unit_ids_as_integers)

sorting = sorting.select_units([2, 7, 0], ["#3", "#9", "#4"])

sorting_analyzer = create_sorting_analyzer(recording=recording, sorting=sorting, format="memory")
@@ -60,6 +65,11 @@ def sorting_analyzer_simple():
seed=1205,
)

+ channel_ids_as_integers = [id for id in range(recording.get_num_channels())]
+ unit_ids_as_integers = [id for id in range(sorting.get_num_units())]
+ recording = recording.rename_channels(new_channel_ids=channel_ids_as_integers)
+ sorting = sorting.rename_units(new_unit_ids=unit_ids_as_integers)

sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory", sparse=True)

sorting_analyzer.compute("random_spikes", max_spikes_per_unit=300, seed=1205)
6 changes: 6 additions & 0 deletions src/spikeinterface/sortingcomponents/tests/common.py
@@ -21,4 +21,10 @@ def make_dataset():
noise_kwargs=dict(noise_levels=5.0, strategy="on_the_fly"),
seed=2205,
)

+ channel_ids_as_integers = [id for id in range(recording.get_num_channels())]
+ unit_ids_as_integers = [id for id in range(sorting.get_num_units())]
+ recording = recording.rename_channels(new_channel_ids=channel_ids_as_integers)
+ sorting = sorting.rename_units(new_unit_ids=unit_ids_as_integers)

return recording, sorting