diff --git a/mokapot/dataset.py b/mokapot/dataset.py
index a6fb242..9f44006 100644
--- a/mokapot/dataset.py
+++ b/mokapot/dataset.py
@@ -421,8 +421,8 @@ def _calibrate_scores(self, scores, eval_fdr, desc=True):
         )
 
 
+@typechecked
 class OnDiskPsmDataset(PsmDataset):
-    @typechecked
     def __init__(
         self,
         filename_or_reader: Path | TabularDataReader,
@@ -604,6 +604,32 @@ def update_labels(self, scores, target_column, eval_fdr=0.01, desc=True):
             desc=desc,
         )
 
+    @staticmethod
+    def _hash_row(x: np.ndarray) -> int:
+        """
+        Hash one row of values for splitting of test/training sets.
+
+        Parameters
+        ----------
+        x : np.ndarray
+            Row to hash; elements exposing ``.item()`` are unwrapped first.
+
+        Returns
+        -------
+        int
+            ``crc32`` of the encoded ``str()`` of the row's values as a tuple.
+        """
+
+        def to_base_val(v):
+            """Return ``v.item()`` if ``v`` has it (numpy scalars), else ``v``."""
+            try:
+                return v.item()
+            except AttributeError:
+                return v
+
+        tup = tuple(to_base_val(x) for x in x)  # inner x shadows the row
+        return crc32(str(tup).encode())
+
     def _split(self, folds, rng):
         """
         Get the indices for random, even splits of the dataset.
@@ -626,13 +652,7 @@ def _split(self, folds, rng):
         """
         spectra = self.spectra_dataframe[self.spectrum_columns].values
         del self.spectra_dataframe
-        spectra = np.apply_along_axis(
-            # Need to cast to float, so that numpy 1.x and 2.x return the same
-            # string representation
-            lambda x: crc32(str(tuple(map(float, x))).encode()),
-            1,
-            spectra,
-        )
+        spectra = np.apply_along_axis(OnDiskPsmDataset._hash_row, 1, spectra)
 
         # sort values to get start position of unique hashes
         spectra_idx = np.argsort(spectra)
diff --git a/tests/unit_tests/test_dataset.py b/tests/unit_tests/test_dataset.py
index 1caec56..f25648c 100644
--- a/tests/unit_tests/test_dataset.py
+++ b/tests/unit_tests/test_dataset.py
@@ -4,7 +4,8 @@
 
 import numpy as np
 import pandas as pd
-from mokapot import LinearPsmDataset
+
+from mokapot import LinearPsmDataset, OnDiskPsmDataset
 
 
 def test_linear_init(psm_df_6):
@@ -56,3 +57,19 @@ def test_update_labels(psm_df_6):
     real_labs = np.array([1, 1, 0, -1, -1, -1])
     new_labs = dset._update_labels(scores, eval_fdr=0.5)
     assert np.array_equal(real_labs, new_labs)
+
+
+def test_hash_row():
+    x = np.array(["test.mzML", 870, 5902.639978936955, 890.522815122875], dtype=object)
+    assert OnDiskPsmDataset._hash_row(x) == 4196757312  # plain Python values
+
+    x = np.array(
+        [
+            "test.mzML",
+            np.int64(870),
+            np.float64(5902.639978936955),
+            np.float64(890.522815122875),
+        ],
+        dtype=object,
+    )
+    assert OnDiskPsmDataset._hash_row(x) == 4196757312  # numpy scalars: same