From e77da4180c799f57baae148cea7e5722a9277b5f Mon Sep 17 00:00:00 2001 From: dachengx Date: Fri, 15 Nov 2024 12:59:50 -0600 Subject: [PATCH] Rename `straxen_type` to `old_type` and fix test --- straxen/plugins/events/event_basics_som.py | 2 +- .../peaklets/peaklet_classification_som.py | 49 +++++++++---------- straxen/plugins/peaks/peak_basics_som.py | 4 +- straxen/plugins/peaks/peaks_som.py | 6 +-- tests/test_peaks.py | 29 +++++++---- 5 files changed, 48 insertions(+), 42 deletions(-) diff --git a/straxen/plugins/events/event_basics_som.py b/straxen/plugins/events/event_basics_som.py index d8abe9517..ea162f819 100644 --- a/straxen/plugins/events/event_basics_som.py +++ b/straxen/plugins/events/event_basics_som.py @@ -20,7 +20,7 @@ def _set_dtype_requirements(self): self.peak_properties = list(self.peak_properties) self.peak_properties += [ ("som_sub_type", np.int32, "SOM subtype of the peak(let)"), - ("straxen_type", np.int8, "Old straxen type of the peak(let)"), + ("old_type", np.int8, "Old type of the peak(let)"), ("loc_x_som", np.int16, "x location of the peak(let) in the SOM"), ("loc_y_som", np.int16, "y location of the peak(let) in the SOM"), ] diff --git a/straxen/plugins/peaklets/peaklet_classification_som.py b/straxen/plugins/peaklets/peaklet_classification_som.py index dffd5a947..7534ae17d 100644 --- a/straxen/plugins/peaklets/peaklet_classification_som.py +++ b/straxen/plugins/peaklets/peaklet_classification_som.py @@ -31,6 +31,15 @@ class PeakletClassificationSOM(PeakletClassification): __version__ = "0.2.0" child_plugin = True + dtype = strax.peak_interval_dtype + [ + ("type", np.int8, "Classification of the peak(let)"), + ("som_sub_type", np.int32, "SOM subtype of the peak(let)"), + ("old_type", np.int8, "Old type of the peak(let)"), + ("som_type", np.int8, "SOM type of the peak(let)"), + ("loc_x_som", np.int16, "x location of the peak(let) in the SOM"), + ("loc_y_som", np.int16, "y location of the peak(let) in the SOM"), + ] + som_files = straxen.URLConfig( default="resource://xedocs://som_classifiers?attr=value&version=v1&run_id=045000&fmt=npy" ) @@ -44,17 +53,6 @@ class PeakletClassificationSOM(PeakletClassification): ), ) - def infer_dtype(self): - dtype = strax.peak_interval_dtype + [ - ("type", np.int8, "Classification of the peak(let)"), - ("som_sub_type", np.int32, "SOM subtype of the peak(let)"), - ("straxen_type", np.int8, "Old straxen type of the peak(let)"), - ("som_type", np.int8, "SOM type of the peak(let)"), - ("loc_x_som", np.int16, "x location of the peak(let) in the SOM"), - ("loc_y_som", np.int16, "y location of the peak(let) in the SOM"), - ] - return dtype - def setup(self): self.som_weight_cube = self.som_files["weight_cube"] self.som_img = self.som_files["som_img"] @@ -70,7 +68,7 @@ def compute(self, peaklets): peaklet_with_som = np.zeros(len(peaklets_classifcation), dtype=self.dtype) strax.copy_to_buffer(peaklets_classifcation, peaklet_with_som, "_copy_peaklets_information") - peaklet_with_som["straxen_type"] = peaklets_classifcation["type"] + peaklet_with_som["old_type"] = peaklets_classifcation["type"] del peaklets_classifcation # SOM classification @@ -80,22 +78,20 @@ def compute(self, peaklets): peaklets_w_type = peaklets_w_type[_is_s1_or_s2] - som_type, x_som, y_som = recall_populations( + som_sub_type, x_som, y_som = recall_populations( peaklets_w_type, self.som_weight_cube, self.som_img, self.som_norm_factors ) - peaklet_with_som["som_sub_type"][_is_s1_or_s2] = som_type - peaklet_with_som["loc_x_som"][_is_s1_or_s2] = x_som - 
peaklet_with_som["loc_y_som"][_is_s1_or_s2] = y_som - strax_type = som_type_to_type( - som_type, self.som_s1_array, self.som_s2_array, self.som_s3_array, self.som_s0_array + som_sub_type, self.som_s1_array, self.som_s2_array, self.som_s3_array, self.som_s0_array ) - + peaklet_with_som["som_sub_type"][_is_s1_or_s2] = som_sub_type + peaklet_with_som["loc_x_som"][_is_s1_or_s2] = x_som + peaklet_with_som["loc_y_som"][_is_s1_or_s2] = y_som peaklet_with_som["som_type"][_is_s1_or_s2] = strax_type if self.use_som_as_default: peaklet_with_som["type"][_is_s1_or_s2] = strax_type else: - peaklet_with_som["type"] = peaklet_with_som["straxen_type"] + peaklet_with_som["type"] = peaklet_with_som["old_type"] return peaklet_with_som @@ -105,7 +101,7 @@ def recall_populations(dataset, weight_cube, som_cls_img, norm_factors): a dataset and a set of normalization factors. In theory, if these 5 things are provided, this function should output - the original data back with one added field with the name "SOM_type" + the original data back with one added field with the name "som_sub_type" weight_cube: SOM weight cube (3D array) som_cls_img: SOM reference image as a numpy array dataset: Data to preform the recall on (Should be peaklet level data) @@ -130,15 +126,14 @@ def recall_populations(dataset, weight_cube, som_cls_img, norm_factors): # preform a recall of the dataset with the weight cube # assign each population color a number (can do from previous function) ref_map = generate_color_ref_map(som_cls_img, unique_colors, xdim, ydim) - som_cls_array = np.empty(len(dataset["area"])) - som_cls_array[:] = np.nan + som_cls_array = np.full(len(dataset["area"]), np.nan) # Make new numpy structured array to save the SOM cls data - data_with_SOM_cls = rfn.append_fields(dataset, "SOM_type", som_cls_array) - # preforms the recall and assigns SOM_type label + data_with_SOM_cls = rfn.append_fields(dataset, "som_sub_type", som_cls_array) + # preforms the recall and assigns som_sub_type label output_data, x_som, y_som = som_cls_recall( data_with_SOM_cls, decile_transform_check, weight_cube, ref_map ) - return output_data["SOM_type"], x_som, y_som + return output_data["som_sub_type"], x_som, y_som def generate_color_ref_map(color_image, unique_colors, xdim, ydim): @@ -159,7 +154,7 @@ def som_cls_recall(array_to_fill, data_in_som_fmt, weight_cube, reference_map): ) w_neuron = np.argmin(distances, axis=0) x_idx, y_idx = np.unravel_index(w_neuron, (som_xdim, som_ydim)) - array_to_fill["SOM_type"] = reference_map[x_idx, y_idx] + array_to_fill["som_sub_type"] = reference_map[x_idx, y_idx] return array_to_fill, x_idx, y_idx diff --git a/straxen/plugins/peaks/peak_basics_som.py b/straxen/plugins/peaks/peak_basics_som.py index 3b612ef93..e000c565b 100644 --- a/straxen/plugins/peaks/peak_basics_som.py +++ b/straxen/plugins/peaks/peak_basics_som.py @@ -16,7 +16,7 @@ def infer_dtype(self): dtype = super().infer_dtype() additional_fields = [ ("som_sub_type", np.int32, "SOM subtype of the peak(let)"), - ("straxen_type", np.int8, "Old straxen type of the peak(let)"), + ("old_type", np.int8, "Old type of the peak(let)"), ("loc_x_som", np.int16, "x location of the peak(let) in the SOM"), ("loc_y_som", np.int16, "y location of the peak(let) in the SOM"), ] @@ -25,6 +25,6 @@ def infer_dtype(self): def compute(self, peaks): peak_basics = super().compute(peaks) - fields_to_copy = ("som_sub_type", "straxen_type", "loc_x_som", "loc_y_som") + fields_to_copy = ("som_sub_type", "old_type", "loc_x_som", "loc_y_som") strax.copy_to_buffer(peaks, 
peak_basics, "_copy_som_information", fields_to_copy) return peak_basics diff --git a/straxen/plugins/peaks/peaks_som.py b/straxen/plugins/peaks/peaks_som.py index 522ff7270..3a28649b3 100644 --- a/straxen/plugins/peaks/peaks_som.py +++ b/straxen/plugins/peaks/peaks_som.py @@ -24,8 +24,8 @@ def infer_dtype(self): # The merged dtype is argument position dependent! # It must be first classification then peaklet # Otherwise strax will raise an error when checking for the returned dtype! - merged_s2s_dtype = strax.merged_dtype((peaklet_classification_dtype, peaklets_dtype)) - return merged_s2s_dtype + merged_dtype = strax.merged_dtype((peaklet_classification_dtype, peaklets_dtype)) + return merged_dtype def compute(self, peaklets, merged_s2s): result = super().compute(peaklets, merged_s2s) @@ -34,7 +34,7 @@ def compute(self, peaklets, merged_s2s): _is_merged_s2 = np.isin(result["time"], merged_s2s["time"]) & np.isin( strax.endtime(result), strax.endtime(merged_s2s) ) - result["straxen_type"][_is_merged_s2] = -1 + result["old_type"][_is_merged_s2] = -1 result["som_sub_type"][_is_merged_s2] = -1 return result diff --git a/tests/test_peaks.py b/tests/test_peaks.py index f8bd73025..73e0e56a6 100644 --- a/tests/test_peaks.py +++ b/tests/test_peaks.py @@ -21,8 +21,20 @@ def setUp(self, context=straxen.contexts.demo): self.n_top = self.st.config.get("n_top_pmts", 2) # Make sure that the check is on. Otherwise we cannot test it. - self.st.set_config({"check_peak_sum_area_rtol": R_TOL_DEFAULT}) - self.peaks_basics_compute = self.st.get_single_plugin(run_id, "peak_basics").compute + self.st.set_config( + { + "n_top_pmts": self.n_top, + "check_peak_sum_area_rtol": R_TOL_DEFAULT, + } + ) + self.peaks_basics = self.st.get_single_plugin(run_id, "peak_basics") + self.peaklet_classification = self.st.get_single_plugin(run_id, "peaklet_classification") + self.dtype = strax.merged_dtype( + ( + np.dtype(strax.peak_dtype(n_channels=self.n_top + 1, n_sum_wv_samples=10)), + self.peaklet_classification.dtype_for("peaklet_classification"), + ) + ) @settings(deadline=None) @given( @@ -30,10 +42,10 @@ def setUp(self, context=straxen.contexts.demo): ) def test_aft_equals1(self, test_peak_idx): """Fill top array with area 1.""" - test_data = self.get_test_peaks(self.n_top) + test_data = self.get_test_peaks() test_data[test_peak_idx]["area_per_channel"][: self.n_top] = 1 test_data[test_peak_idx]["area"] = np.sum(test_data[test_peak_idx]["area_per_channel"]) - peaks = self.peaks_basics_compute(test_data) + peaks = self.peaks_basics.compute(test_data) assert peaks[test_peak_idx]["area_fraction_top"] == 1 @settings(deadline=None) @@ -50,21 +62,20 @@ def test_aft_equals1(self, test_peak_idx): def test_bad_peak(self, off_by_factor, test_peak_idx): """Lets deliberately make some data that is not self-consistent to run into the error in the test.""" - test_data = self.get_test_peaks(self.n_top) + test_data = self.get_test_peaks() test_data[test_peak_idx]["area_per_channel"][: self.n_top] = 1 area = np.sum(test_data[test_peak_idx]["area_per_channel"]) # Set the area field to a wrong value area *= off_by_factor test_data[test_peak_idx]["area"] = area - self.assertRaises(ValueError, self.peaks_basics_compute, test_data) + self.assertRaises(ValueError, self.peaks_basics.compute, test_data) - @staticmethod - def get_test_peaks(n_top, length=2, sum_wf_samples=10): + def get_test_peaks(self, length=2): """Generate some dummy peaks.""" test_data = np.zeros( TEST_DATA_LENGTH, - dtype=strax.dtypes.peak_dtype(n_channels=n_top + 1, 
n_sum_wv_samples=sum_wf_samples), + dtype=self.dtype, ) test_data["time"] = range(TEST_DATA_LENGTH) test_data["time"] *= length * 2
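
For reference, a minimal downstream sketch of what this rename means for analysis code. Everything below is illustrative and not part of the patch: the toy structured array is a stand-in for real peak(let) data carrying the fields defined above, and only the field names, the `old_type` semantics, and the -1 sentinel written in PeaksSOM.compute are taken from the changes themselves.

    import numpy as np

    # Stand-in for peak(let) data with only the fields relevant to the rename.
    toy_dtype = [
        ("type", np.int8),           # final classification (SOM-based if use_som_as_default)
        ("old_type", np.int8),       # was `straxen_type` before this patch
        ("som_sub_type", np.int32),  # SOM subtype; -1 for merged S2s (see PeaksSOM.compute)
        ("som_type", np.int8),
    ]
    peaks = np.zeros(4, dtype=toy_dtype)
    peaks["type"] = [1, 2, 2, 2]
    peaks["old_type"] = [1, 1, 2, -1]     # -1: merged S2, no SOM information
    peaks["som_sub_type"] = [3, 7, 12, -1]

    # Peak(let)s whose SOM-based type differs from the original straxen classification:
    reclassified = peaks[(peaks["old_type"] >= 0) & (peaks["old_type"] != peaks["type"])]

    # Merged S2s carry the sentinel set in PeaksSOM.compute:
    merged_s2s = peaks[peaks["som_sub_type"] == -1]

    print(len(reclassified), len(merged_s2s))

Selecting on `old_type >= 0` mirrors the sentinel assignment above, so merged S2s (which never receive a SOM subtype) are excluded from the comparison.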