diff --git a/src/pynwb/ndx_binned_spikes/__init__.py b/src/pynwb/ndx_binned_spikes/__init__.py index bf83b3d..0f169fa 100644 --- a/src/pynwb/ndx_binned_spikes/__init__.py +++ b/src/pynwb/ndx_binned_spikes/__init__.py @@ -156,7 +156,7 @@ def get_data_for_condition(self, condition_index): if not self.has_multiple_conditions: return self.data - mask = self.condition_indices == condition_index + mask = self.condition_indices[:] == condition_index binned_spikes_for_unit = self.data[:, mask, :] return binned_spikes_for_unit diff --git a/src/pynwb/tests/test_binned_aligned_spikes.py b/src/pynwb/tests/test_binned_aligned_spikes.py index 9b7126b..6c75c61 100644 --- a/src/pynwb/tests/test_binned_aligned_spikes.py +++ b/src/pynwb/tests/test_binned_aligned_spikes.py @@ -295,6 +295,11 @@ def test_roundtrip_acquisition(self): assert read_binned_aligned_spikes.number_of_bins == number_of_bins assert read_binned_aligned_spikes.number_of_events == number_of_events assert read_binned_aligned_spikes.number_of_conditions == number_of_conditions + + expected_data_condition1 = self.binned_aligned_spikes.get_data_for_condition(condition_index=2) + data_condition1 = read_binned_aligned_spikes.get_data_for_condition(condition_index=2) + + np.testing.assert_equal(data_condition1, expected_data_condition1) def test_roundtrip_processing_module(self): self.binned_aligned_spikes = mock_BinnedAlignedSpikes()