diff --git a/straxen/corrections_services.py b/straxen/corrections_services.py index ef2eda53b..561176b87 100644 --- a/straxen/corrections_services.py +++ b/straxen/corrections_services.py @@ -21,7 +21,6 @@ "s1_xyz_map_mlp", "fdc_map_mlp", "s1_aft_xyz_map", - "bayes_model", ] single_value_corrections = [ diff --git a/straxen/plugins/events/__init__.py b/straxen/plugins/events/__init__.py index e12992307..0a49f0eb9 100644 --- a/straxen/plugins/events/__init__.py +++ b/straxen/plugins/events/__init__.py @@ -52,9 +52,6 @@ from . import event_top_bottom_params from .event_top_bottom_params import * -from . import event_w_bayes_class -from .event_w_bayes_class import * - from . import s2_recon_pos_diff from .s2_recon_pos_diff import * diff --git a/straxen/plugins/events/event_w_bayes_class.py b/straxen/plugins/events/event_w_bayes_class.py deleted file mode 100644 index 232ebc635..000000000 --- a/straxen/plugins/events/event_w_bayes_class.py +++ /dev/null @@ -1,39 +0,0 @@ -import strax -import numpy as np - - -class EventwBayesClass(strax.Plugin): - """Append at event level the posterior probability for an S1, S2, alt_S1 and alt_S2.""" - - provides = "event_w_bayes_class" - depends_on = ("peak_classification_bayes", "event_basics") - data_kind = "events" - __version__ = "0.0.1" - - def infer_dtype(self): - dtype = [] - dtype += strax.time_fields - for name in ["s1", "s2", "alt_s1", "alt_s2"]: - dtype += [(f"{name}_ln_prob_s1", np.float32, f"Given an {name}, s1 ln probability")] - dtype += [(f"{name}_ln_prob_s2", np.float32, f"Given an {name}, s2 ln probability")] - - return dtype - - def compute(self, peaks, events): - split_peaks = strax.split_by_containment(peaks, events) - - # 1. Initialization - result = np.empty(len(events), dtype=self.dtype) - # 2. Set time and endtime for events - for name in ["s1", "s2", "alt_s1", "alt_s2"]: - result[f"{name}_ln_prob_s1"] = np.nan - result[f"{name}_ln_prob_s2"] = np.nan - result["time"] = events["time"] - result["endtime"] = events["endtime"] - # 3. Assign peaks features to main S1, main S2 in the event - for event_i, (event, sp) in enumerate(zip(events, split_peaks)): - for name in ["s1", "s2", "alt_s1", "alt_s2"]: - if event[f"{name}_index"] >= 0: - result[f"{name}_ln_prob_s1"][event_i] = sp["ln_prob_s1"][event[f"{name}_index"]] - result[f"{name}_ln_prob_s2"][event_i] = sp["ln_prob_s2"][event[f"{name}_index"]] - return result diff --git a/straxen/plugins/peaks/__init__.py b/straxen/plugins/peaks/__init__.py index 996aab675..cd6a094e5 100644 --- a/straxen/plugins/peaks/__init__.py +++ b/straxen/plugins/peaks/__init__.py @@ -45,6 +45,3 @@ from . import peak_corrections from .peak_corrections import * - -from . import peak_classification_bayes -from .peak_classification_bayes import * diff --git a/straxen/plugins/peaks/peak_classification_bayes.py b/straxen/plugins/peaks/peak_classification_bayes.py deleted file mode 100644 index 706d0b443..000000000 --- a/straxen/plugins/peaks/peak_classification_bayes.py +++ /dev/null @@ -1,224 +0,0 @@ -from scipy.special import logsumexp - -import strax -import straxen - -import numpy as np -import numba - - -export, __all__ = strax.exporter() - - -@export -class BayesPeakClassification(strax.Plugin): - """Bayes Peak classification Returns the ln probability of a each event belonging to the S1 and - S2 class. Uses conditional probabilities and data parameterization learned from wfsim data. More - info can be found here xenon:xenonnt:ahiguera:bayespeakclassification. - - :param peaks: peaks - :param waveforms: peaks waveforms in PE/ns - :param quantiles: quantiles in ns, calculate from a cumulative sum over the waveform, from zero - to the total area with normalized cumulative sum to determine the time - :return: the ln probability of a each peak belonging to S1 and S2 class - - """ - - provides = "peak_classification_bayes" - depends_on = "peaks" - __version__ = "0.0.3" - dtype = ( - strax.time_fields - + [("ln_prob_s1", np.float32, "S1 ln probability")] - + [("ln_prob_s2", np.float32, "S2 ln probability")] - ) - - # Descriptor configs - bayes_config_file = straxen.URLConfig( - default="resource://cmt://bayes_model?version=ONLINE&run_id=plugin.run_id&fmt=npy", - help="Bayes model, conditional probabilities tables and Bayes discrete bins", - ) - bayes_n_nodes = straxen.URLConfig( - default=50, help="Number of attributes(features) per waveform and quantile" - ) - n_bayes_classes = straxen.URLConfig(default=2, help="Number of label classes S1(1)/S2(2)") - - def setup(self): - self.class_prior = np.ones(self.n_bayes_classes) / self.n_bayes_classes - self.bins = self.bayes_config_file["bins"] - self.cpt = self.bayes_config_file["cprob"] - - def compute(self, peaks): - result = np.zeros(len(peaks), dtype=self.dtype) - - waveforms, quantiles = compute_wf_and_quantiles(peaks, self.bayes_n_nodes) - - ln_prob_s1, ln_prob_s2 = compute_inference( - self.bins, - self.bayes_n_nodes, - self.cpt, - self.n_bayes_classes, - self.class_prior, - waveforms, - quantiles, - ) - result["time"] = peaks["time"] - result["endtime"] = strax.endtime(peaks) - result["ln_prob_s1"] = ln_prob_s1 - result["ln_prob_s2"] = ln_prob_s2 - return result - - -def compute_wf_and_quantiles(peaks: np.ndarray, bayes_n_nodes: int): - """Compute waveforms and quantiles for a given number of nodes(attributes) - - :param peaks: - :param bayes_n_nodes: number of nodes or attributes - :return: waveforms and quantiles - - """ - data = peaks["data"].copy() - data[data < 0.0] = 0.0 - dt = peaks["dt"] - return _compute_wf_and_quantiles(data, dt, bayes_n_nodes) - - -@numba.njit(cache=True) -def _compute_wf_and_quantiles(data, sample_length, bayes_n_nodes: int): - waveforms = np.zeros((len(data), bayes_n_nodes)) - quantiles = np.zeros((len(data), bayes_n_nodes)) - - num_samples = data.shape[1] - step_size = int(num_samples / bayes_n_nodes) - steps = np.arange(0, num_samples + 1, step_size) - inter_points = np.linspace(0.0, 1.0 - (1.0 / bayes_n_nodes), bayes_n_nodes) - cumsum_steps = np.zeros(bayes_n_nodes + 1, dtype=np.float64) - frac_of_cumsum = np.zeros(num_samples + 1) - sample_number_div_dt = np.arange(0, num_samples + 1, 1) - for i, (samples, dt) in enumerate(zip(data, sample_length)): - # reset buffers - frac_of_cumsum[:] = 0 - cumsum_steps[:] = 0 - - frac_of_cumsum[1:] = np.cumsum(samples) - frac_of_cumsum[1:] = frac_of_cumsum[1:] / frac_of_cumsum[-1] - - cumsum_steps[:-1] = np.interp(inter_points, frac_of_cumsum, sample_number_div_dt * dt) - cumsum_steps[-1] = sample_number_div_dt[-1] * dt - quantiles[i] = cumsum_steps[1:] - cumsum_steps[:-1] - - for j in range(bayes_n_nodes): - waveforms[i][j] = np.sum(samples[steps[j] : steps[j + 1]]) - waveforms[i] /= step_size * dt - - return waveforms, quantiles - - -def compute_inference( - bins: np.ndarray, - bayes_n_nodes: int, - cpt: np.ndarray, - n_bayes_classes: int, - class_prior: np.ndarray, - waveforms: np.ndarray, - quantiles: np.ndarray, -): - """Bin the waveforms and quantiles according to Bayes bins and compute inference. - - :param bins: Bayes bins - :param bayes_n_nodes: number of nodes or attributes - :param cpt: conditional probability tables - :param n_bayes_classes: number of classes - :param class_prior: class_prior - :param waveforms: waveforms - :param quantiles: quantiles - :return: ln probability per class S1/S2 - - """ - # Bin the waveforms and quantiles. - waveform_bin_edges = bins[0, :][bins[0, :] > -1] - waveform_num_bin_edges = len(waveform_bin_edges) - quantile_bin_edges = bins[1, :][bins[1, :] > -1] - quantile_num_bin_edges = len(quantile_bin_edges) - waveform_values = np.digitize(waveforms, bins=waveform_bin_edges) - 1 - waveform_values = np.clip(waveform_values, 0, waveform_num_bin_edges - 2) - - quantile_values = np.digitize(quantiles, bins=quantile_bin_edges) - 1 - quantile_values = np.clip(quantile_values, 0, quantile_num_bin_edges - 2) - - lnposterior = get_log_posterior( - bayes_n_nodes, - waveforms, - cpt, - waveform_num_bin_edges, - quantile_num_bin_edges, - n_bayes_classes, - waveform_values, - quantile_values, - ) - - lnposterior_sumsamples = np.sum(lnposterior, axis=1) - lnposterior_sumsamples = lnposterior_sumsamples + np.log(class_prior) - - # If you need to numbafy, use - # github.com/XENONnT/straxen/blob/c06a611a5a4289c65b9830fc1e68027149b6a5fd/straxen/plugins/bayes_peak_classification.py#L119 # noqa - normalization = logsumexp(lnposterior_sumsamples, axis=1) - ln_prob_s1 = lnposterior_sumsamples[:, 0] - normalization - ln_prob_s2 = lnposterior_sumsamples[:, 1] - normalization - return ln_prob_s1, ln_prob_s2 - - -def get_log_posterior( - bayes_n_nodes: int, - waveforms: np.ndarray, - cpt: np.ndarray, - waveform_num_bin_edges: int, - quantile_num_bin_edges: int, - n_bayes_classes: int, - waveform_values: np.ndarray, - quantile_values: np.ndarray, -) -> np.ndarray: - """# TODO, add a description what we are computing here - - :param bayes_n_nodes: number of nodes or attributes - :param waveforms: waveforms - :param cpt: conditional probability tables - :param waveform_num_bin_edges: number of bins for waveforms - :param quantile_num_bin_edges: number of bins for quantiles - :param n_bayes_classes: number of classes - :param waveform_values: digitized waveforms - :param quantile_values: digitized quantiles - :return: log-posterior for waveforms and quantiles. NB! This is not normalized - - """ - wf_posterior = _get_log_posterior( - nodes=bayes_n_nodes, - n_bins=waveform_num_bin_edges - 1, - n_classes=n_bayes_classes, - cpt=cpt[:bayes_n_nodes], - wf_len=len(waveforms), - wf_values=waveform_values, - ) - quantile_posterior = _get_log_posterior( - nodes=bayes_n_nodes, - n_bins=quantile_num_bin_edges - 1, - n_classes=n_bayes_classes, - cpt=cpt[bayes_n_nodes:], - wf_len=len(waveforms), - wf_values=quantile_values, - ) - - lnposterior = np.zeros((len(waveforms), bayes_n_nodes * 2, n_bayes_classes)) - lnposterior[:, :bayes_n_nodes] = wf_posterior - lnposterior[:, bayes_n_nodes:] = quantile_posterior - return lnposterior - - -@numba.njit -def _get_log_posterior(nodes, n_bins, n_classes, cpt, wf_len, wf_values): - """Wrapper for extracting the ln posterior inside compute_inference.""" - ln_posterior = np.zeros((wf_len, nodes, n_classes)) - for i in range(nodes): - distribution = cpt[i, :n_bins, :] - ln_posterior[:, i, :] = np.log(distribution[wf_values[:, i], :]) - return ln_posterior diff --git a/tests/plugins/_core.py b/tests/plugins/_core.py index 6d22f0ad2..39defd7cc 100644 --- a/tests/plugins/_core.py +++ b/tests/plugins/_core.py @@ -87,7 +87,7 @@ def run_pytest_from_main(): """Build new unit test for provided functions. For example, you might want to run it for a single module, in that case you don't want to run - ALL the tests. So you can do e.g. `python bayes_plugin.py` where we only collect the tests + ALL the tests. So you can do e.g. `python peak_building.py` where we only collect the tests defined in that module. """ diff --git a/tests/plugins/bayes_plugin.py b/tests/plugins/bayes_plugin.py deleted file mode 100644 index 503db9668..000000000 --- a/tests/plugins/bayes_plugin.py +++ /dev/null @@ -1,29 +0,0 @@ -"""Run with python tests/plugins/bayes_plugin.py.""" - -from _core import PluginTestAccumulator, PluginTestCase, run_pytest_from_main -import numpy as np - - -@PluginTestAccumulator.register("test_bayes_bins") -def test_bayes_bins(self: PluginTestCase): - """Test bins are monotonic, if no, issue w/ config file (training problem)""" - target = "peak_classification_bayes" - plugin = self.st.get_single_plugin(self.run_id, target) - bins = plugin.bayes_config_file["bins"] - waveform_bin_edges = bins[0, :][bins[0, :] > -1] - monotonic = np.all(np.diff(waveform_bin_edges) > 0) - self.assertTrue(monotonic) - - -@PluginTestAccumulator.register("test_bayes_inference") -def test_bayes_inference(self: PluginTestCase): - """Test inference, ln prob <= 0.""" - target = "peak_classification_bayes" - self.st.make(self.run_id, "peaks") - self.st.make(self.run_id, target) - bayes = self.st.get_array(self.run_id, target) - assert not np.any(bayes["ln_prob_s1"] > 0) - - -if __name__ == "__main__": - run_pytest_from_main() diff --git a/tests/plugins/test_plugins.py b/tests/plugins/test_plugins.py index cb1a04e5c..379372d96 100644 --- a/tests/plugins/test_plugins.py +++ b/tests/plugins/test_plugins.py @@ -6,10 +6,9 @@ import inspect # Need import to attach new tests to the PluginTestAccumulator -import bayes_plugin -import event_building import peak_building import posrec_plugins +import event_building import nv_processing import local_minimum_plugin