Skip to content

Commit

Permalink
Rename straxen_type to old_type and fix test
Browse files Browse the repository at this point in the history
  • Loading branch information
dachengx committed Nov 15, 2024
1 parent f6ced89 commit e77da41
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 42 deletions.
2 changes: 1 addition & 1 deletion straxen/plugins/events/event_basics_som.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def _set_dtype_requirements(self):
self.peak_properties = list(self.peak_properties)
self.peak_properties += [
("som_sub_type", np.int32, "SOM subtype of the peak(let)"),
("straxen_type", np.int8, "Old straxen type of the peak(let)"),
("old_type", np.int8, "Old type of the peak(let)"),
("loc_x_som", np.int16, "x location of the peak(let) in the SOM"),
("loc_y_som", np.int16, "y location of the peak(let) in the SOM"),
]
Expand Down
49 changes: 22 additions & 27 deletions straxen/plugins/peaklets/peaklet_classification_som.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,15 @@ class PeakletClassificationSOM(PeakletClassification):
__version__ = "0.2.0"
child_plugin = True

dtype = strax.peak_interval_dtype + [
("type", np.int8, "Classification of the peak(let)"),
("som_sub_type", np.int32, "SOM subtype of the peak(let)"),
("old_type", np.int8, "Old type of the peak(let)"),
("som_type", np.int8, "SOM type of the peak(let)"),
("loc_x_som", np.int16, "x location of the peak(let) in the SOM"),
("loc_y_som", np.int16, "y location of the peak(let) in the SOM"),
]

som_files = straxen.URLConfig(
default="resource://xedocs://som_classifiers?attr=value&version=v1&run_id=045000&fmt=npy"
)
Expand All @@ -44,17 +53,6 @@ class PeakletClassificationSOM(PeakletClassification):
),
)

def infer_dtype(self):
dtype = strax.peak_interval_dtype + [
("type", np.int8, "Classification of the peak(let)"),
("som_sub_type", np.int32, "SOM subtype of the peak(let)"),
("straxen_type", np.int8, "Old straxen type of the peak(let)"),
("som_type", np.int8, "SOM type of the peak(let)"),
("loc_x_som", np.int16, "x location of the peak(let) in the SOM"),
("loc_y_som", np.int16, "y location of the peak(let) in the SOM"),
]
return dtype

def setup(self):
self.som_weight_cube = self.som_files["weight_cube"]
self.som_img = self.som_files["som_img"]
Expand All @@ -70,7 +68,7 @@ def compute(self, peaklets):

peaklet_with_som = np.zeros(len(peaklets_classifcation), dtype=self.dtype)
strax.copy_to_buffer(peaklets_classifcation, peaklet_with_som, "_copy_peaklets_information")
peaklet_with_som["straxen_type"] = peaklets_classifcation["type"]
peaklet_with_som["old_type"] = peaklets_classifcation["type"]
del peaklets_classifcation

# SOM classification
Expand All @@ -80,22 +78,20 @@ def compute(self, peaklets):

peaklets_w_type = peaklets_w_type[_is_s1_or_s2]

som_type, x_som, y_som = recall_populations(
som_sub_type, x_som, y_som = recall_populations(
peaklets_w_type, self.som_weight_cube, self.som_img, self.som_norm_factors
)
peaklet_with_som["som_sub_type"][_is_s1_or_s2] = som_type
peaklet_with_som["loc_x_som"][_is_s1_or_s2] = x_som
peaklet_with_som["loc_y_som"][_is_s1_or_s2] = y_som

strax_type = som_type_to_type(
som_type, self.som_s1_array, self.som_s2_array, self.som_s3_array, self.som_s0_array
som_sub_type, self.som_s1_array, self.som_s2_array, self.som_s3_array, self.som_s0_array
)

peaklet_with_som["som_sub_type"][_is_s1_or_s2] = som_sub_type
peaklet_with_som["loc_x_som"][_is_s1_or_s2] = x_som
peaklet_with_som["loc_y_som"][_is_s1_or_s2] = y_som
peaklet_with_som["som_type"][_is_s1_or_s2] = strax_type
if self.use_som_as_default:
peaklet_with_som["type"][_is_s1_or_s2] = strax_type
else:
peaklet_with_som["type"] = peaklet_with_som["straxen_type"]
peaklet_with_som["type"] = peaklet_with_som["old_type"]

return peaklet_with_som

Expand All @@ -105,7 +101,7 @@ def recall_populations(dataset, weight_cube, som_cls_img, norm_factors):
a dataset and a set of normalization factors.
In theory, if these 5 things are provided, this function should output
the original data back with one added field with the name "SOM_type"
the original data back with one added field with the name "som_sub_type"
weight_cube: SOM weight cube (3D array)
som_cls_img: SOM reference image as a numpy array
dataset: Data to preform the recall on (Should be peaklet level data)
Expand All @@ -130,15 +126,14 @@ def recall_populations(dataset, weight_cube, som_cls_img, norm_factors):
# preform a recall of the dataset with the weight cube
# assign each population color a number (can do from previous function)
ref_map = generate_color_ref_map(som_cls_img, unique_colors, xdim, ydim)
som_cls_array = np.empty(len(dataset["area"]))
som_cls_array[:] = np.nan
som_cls_array = np.full(len(dataset["area"]), np.nan)
# Make new numpy structured array to save the SOM cls data
data_with_SOM_cls = rfn.append_fields(dataset, "SOM_type", som_cls_array)
# preforms the recall and assigns SOM_type label
data_with_SOM_cls = rfn.append_fields(dataset, "som_sub_type", som_cls_array)
# preforms the recall and assigns som_sub_type label
output_data, x_som, y_som = som_cls_recall(
data_with_SOM_cls, decile_transform_check, weight_cube, ref_map
)
return output_data["SOM_type"], x_som, y_som
return output_data["som_sub_type"], x_som, y_som


def generate_color_ref_map(color_image, unique_colors, xdim, ydim):
Expand All @@ -159,7 +154,7 @@ def som_cls_recall(array_to_fill, data_in_som_fmt, weight_cube, reference_map):
)
w_neuron = np.argmin(distances, axis=0)
x_idx, y_idx = np.unravel_index(w_neuron, (som_xdim, som_ydim))
array_to_fill["SOM_type"] = reference_map[x_idx, y_idx]
array_to_fill["som_sub_type"] = reference_map[x_idx, y_idx]
return array_to_fill, x_idx, y_idx


Expand Down
4 changes: 2 additions & 2 deletions straxen/plugins/peaks/peak_basics_som.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def infer_dtype(self):
dtype = super().infer_dtype()
additional_fields = [
("som_sub_type", np.int32, "SOM subtype of the peak(let)"),
("straxen_type", np.int8, "Old straxen type of the peak(let)"),
("old_type", np.int8, "Old type of the peak(let)"),
("loc_x_som", np.int16, "x location of the peak(let) in the SOM"),
("loc_y_som", np.int16, "y location of the peak(let) in the SOM"),
]
Expand All @@ -25,6 +25,6 @@ def infer_dtype(self):

def compute(self, peaks):
peak_basics = super().compute(peaks)
fields_to_copy = ("som_sub_type", "straxen_type", "loc_x_som", "loc_y_som")
fields_to_copy = ("som_sub_type", "old_type", "loc_x_som", "loc_y_som")
strax.copy_to_buffer(peaks, peak_basics, "_copy_som_information", fields_to_copy)
return peak_basics
6 changes: 3 additions & 3 deletions straxen/plugins/peaks/peaks_som.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ def infer_dtype(self):
# The merged dtype is argument position dependent!
# It must be first classification then peaklet
# Otherwise strax will raise an error when checking for the returned dtype!
merged_s2s_dtype = strax.merged_dtype((peaklet_classification_dtype, peaklets_dtype))
return merged_s2s_dtype
merged_dtype = strax.merged_dtype((peaklet_classification_dtype, peaklets_dtype))
return merged_dtype

def compute(self, peaklets, merged_s2s):
result = super().compute(peaklets, merged_s2s)
Expand All @@ -34,7 +34,7 @@ def compute(self, peaklets, merged_s2s):
_is_merged_s2 = np.isin(result["time"], merged_s2s["time"]) & np.isin(
strax.endtime(result), strax.endtime(merged_s2s)
)
result["straxen_type"][_is_merged_s2] = -1
result["old_type"][_is_merged_s2] = -1
result["som_sub_type"][_is_merged_s2] = -1

return result
29 changes: 20 additions & 9 deletions tests/test_peaks.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,19 +21,31 @@ def setUp(self, context=straxen.contexts.demo):
self.n_top = self.st.config.get("n_top_pmts", 2)

# Make sure that the check is on. Otherwise we cannot test it.
self.st.set_config({"check_peak_sum_area_rtol": R_TOL_DEFAULT})
self.peaks_basics_compute = self.st.get_single_plugin(run_id, "peak_basics").compute
self.st.set_config(
{
"n_top_pmts": self.n_top,
"check_peak_sum_area_rtol": R_TOL_DEFAULT,
}
)
self.peaks_basics = self.st.get_single_plugin(run_id, "peak_basics")
self.peaklet_classification = self.st.get_single_plugin(run_id, "peaklet_classification")
self.dtype = strax.merged_dtype(
(
np.dtype(strax.peak_dtype(n_channels=self.n_top + 1, n_sum_wv_samples=10)),
self.peaklet_classification.dtype_for("peaklet_classification"),
)
)

@settings(deadline=None)
@given(
strategies.integers(min_value=0, max_value=TEST_DATA_LENGTH - 1),
)
def test_aft_equals1(self, test_peak_idx):
"""Fill top array with area 1."""
test_data = self.get_test_peaks(self.n_top)
test_data = self.get_test_peaks()
test_data[test_peak_idx]["area_per_channel"][: self.n_top] = 1
test_data[test_peak_idx]["area"] = np.sum(test_data[test_peak_idx]["area_per_channel"])
peaks = self.peaks_basics_compute(test_data)
peaks = self.peaks_basics.compute(test_data)
assert peaks[test_peak_idx]["area_fraction_top"] == 1

@settings(deadline=None)
Expand All @@ -50,21 +62,20 @@ def test_aft_equals1(self, test_peak_idx):
def test_bad_peak(self, off_by_factor, test_peak_idx):
"""Lets deliberately make some data that is not self-consistent to run into the error in the
test."""
test_data = self.get_test_peaks(self.n_top)
test_data = self.get_test_peaks()
test_data[test_peak_idx]["area_per_channel"][: self.n_top] = 1
area = np.sum(test_data[test_peak_idx]["area_per_channel"])

# Set the area field to a wrong value
area *= off_by_factor
test_data[test_peak_idx]["area"] = area
self.assertRaises(ValueError, self.peaks_basics_compute, test_data)
self.assertRaises(ValueError, self.peaks_basics.compute, test_data)

@staticmethod
def get_test_peaks(n_top, length=2, sum_wf_samples=10):
def get_test_peaks(self, length=2):
"""Generate some dummy peaks."""
test_data = np.zeros(
TEST_DATA_LENGTH,
dtype=strax.dtypes.peak_dtype(n_channels=n_top + 1, n_sum_wv_samples=sum_wf_samples),
dtype=self.dtype,
)
test_data["time"] = range(TEST_DATA_LENGTH)
test_data["time"] *= length * 2
Expand Down

0 comments on commit e77da41

Please sign in to comment.