diff --git a/straxen/contexts.py b/straxen/contexts.py index 1d37e8463..7714273ee 100644 --- a/straxen/contexts.py +++ b/straxen/contexts.py @@ -26,6 +26,8 @@ straxen.EnergyEstimates, straxen.EventInfoDouble, straxen.DistinctChannels, + straxen.PeakPositionsMLP, + straxen.PeakPositionsCNF, ], check_available=("peak_basics", "event_basics"), store_run_fields=("name", "number", "start", "end", "livetime", "mode", "source"), diff --git a/straxen/plugins/events/_event_s1_positions_base.py b/straxen/plugins/events/_event_s1_positions_base.py index be45b1958..e48530d86 100644 --- a/straxen/plugins/events/_event_s1_positions_base.py +++ b/straxen/plugins/events/_event_s1_positions_base.py @@ -20,7 +20,6 @@ class EventS1PositionBase(strax.Plugin): algorithm: Optional[str] = None compressor = "zstd" - parallel = True min_s1_area_s1_posrec = straxen.URLConfig( help="Skip reconstruction if area (PE) is less than this", diff --git a/straxen/plugins/events/_event_s2_positions_base.py b/straxen/plugins/events/_event_s2_positions_base.py index 923993acd..14d4b4ade 100644 --- a/straxen/plugins/events/_event_s2_positions_base.py +++ b/straxen/plugins/events/_event_s2_positions_base.py @@ -17,11 +17,10 @@ class EventS2PositionBase(strax.Plugin): algorithm: Optional[str] = None compressor = "zstd" - parallel = True min_reconstruction_area = straxen.URLConfig( help="Skip reconstruction if area (PE) is less than this", - default=10, + default=0, infer_type=False, ) n_top_pmts = straxen.URLConfig( diff --git a/straxen/plugins/events/event_position_uncertainty.py b/straxen/plugins/events/event_position_uncertainty.py index 7f06aa173..047f44a9e 100644 --- a/straxen/plugins/events/event_position_uncertainty.py +++ b/straxen/plugins/events/event_position_uncertainty.py @@ -2,7 +2,7 @@ import strax import straxen from straxen.plugins.defaults import DEFAULT_POSREC_ALGO -from straxen.plugins.peaks.peak_positions_cnf import PeakPositionsCNF +from straxen.plugins.peaklets.peaklet_positions_cnf import PeakletPositionsCNF export, __all__ = strax.exporter() @@ -305,7 +305,7 @@ def compute(self, events): avg_theta = np.arctan2(events[f"{type_}_y_cnf"], events[f"{type_}_x_cnf"]) - theta_diff = PeakPositionsCNF.calculate_theta_diff(theta_array, avg_theta) + theta_diff = PeakletPositionsCNF.calculate_theta_diff(theta_array, avg_theta) # Store uncertainties result[f"{type_}_r_position_uncertainty"] = (r_max - r_min) / 2 diff --git a/straxen/plugins/events/s2_recon_pos_diff.py b/straxen/plugins/events/s2_recon_pos_diff.py index 75b87e8a9..846680faa 100644 --- a/straxen/plugins/events/s2_recon_pos_diff.py +++ b/straxen/plugins/events/s2_recon_pos_diff.py @@ -15,7 +15,6 @@ class S2ReconPosDiff(strax.Plugin): __version__ = "0.0.3" - parallel = True depends_on = "event_basics" provides = "s2_recon_pos_diff" save_when = strax.SaveWhen.EXPLICIT diff --git a/straxen/plugins/merged_s2s/__init__.py b/straxen/plugins/merged_s2s/__init__.py index f4e18c38e..57d81c023 100644 --- a/straxen/plugins/merged_s2s/__init__.py +++ b/straxen/plugins/merged_s2s/__init__.py @@ -1,2 +1,8 @@ from . import merged_s2s from .merged_s2s import * + +from . import merged_s2_positions_cnf +from .merged_s2_positions_cnf import * + +from . import merged_s2_positions_mlp +from .merged_s2_positions_mlp import * diff --git a/straxen/plugins/merged_s2s/merged_s2_positions.py b/straxen/plugins/merged_s2s/merged_s2_positions.py new file mode 100644 index 000000000..5da7db084 --- /dev/null +++ b/straxen/plugins/merged_s2s/merged_s2_positions.py @@ -0,0 +1,19 @@ +import strax +from straxen.plugins.peaks._peak_positions_base import PeakPositionsBase + +export, __all__ = strax.exporter() + + +@export +class MergedS2Positions(PeakPositionsBase): + + __version__ = "0.0.0" + child_plugin = True + provides = "merged_s2_positions" + depends_on = ( + "merged_s2_positions_mlp", + "merged_s2_positions_cnf", + ) + + def compute(self, merged_s2s): + return super().compute(merged_s2s) diff --git a/straxen/plugins/merged_s2s/merged_s2_positions_cnf.py b/straxen/plugins/merged_s2s/merged_s2_positions_cnf.py new file mode 100644 index 000000000..bc3843341 --- /dev/null +++ b/straxen/plugins/merged_s2s/merged_s2_positions_cnf.py @@ -0,0 +1,17 @@ +import strax +from straxen.plugins.peaklets.peaklet_positions_cnf import PeakletPositionsCNF + +export, __all__ = strax.exporter() + + +@export +class MergedS2PositionsCNF(PeakletPositionsCNF): + + __version__ = "0.0.0" + child_plugin = True + algorithm = "cnf" + depends_on = "merged_s2s" + provides = "merged_s2_positions_cnf" + + def compute(self, merged_s2s): + return super().compute(merged_s2s) diff --git a/straxen/plugins/merged_s2s/merged_s2_positions_mlp.py b/straxen/plugins/merged_s2s/merged_s2_positions_mlp.py new file mode 100644 index 000000000..ef7fe7347 --- /dev/null +++ b/straxen/plugins/merged_s2s/merged_s2_positions_mlp.py @@ -0,0 +1,17 @@ +import strax +from straxen.plugins.peaklets.peaklet_positions_mlp import PeakletPositionsMLP + +export, __all__ = strax.exporter() + + +@export +class MergedS2PositionsMLP(PeakletPositionsMLP): + + __version__ = "0.0.0" + child_plugin = True + algorithm = "mlp" + depends_on = "merged_s2s" + provides = "merged_s2_positions_mlp" + + def compute(self, merged_s2s): + return super().compute(merged_s2s) diff --git a/straxen/plugins/peaklets/__init__.py b/straxen/plugins/peaklets/__init__.py index 493a62d0a..9ae66bb4d 100644 --- a/straxen/plugins/peaklets/__init__.py +++ b/straxen/plugins/peaklets/__init__.py @@ -6,3 +6,9 @@ from . import peaklet_classification_som from .peaklet_classification_som import * + +from . import peaklet_positions_cnf +from .peaklet_positions_cnf import * + +from . import peaklet_positions_mlp +from .peaklet_positions_mlp import * diff --git a/straxen/plugins/peaklets/_peaklet_positions_base.py b/straxen/plugins/peaklets/_peaklet_positions_base.py new file mode 100644 index 000000000..9c62521fc --- /dev/null +++ b/straxen/plugins/peaklets/_peaklet_positions_base.py @@ -0,0 +1,105 @@ +from typing import Optional +from warnings import warn + +import numpy as np +import strax +import straxen + +export, __all__ = strax.exporter() + + +@export +class PeakletPositionsBase(strax.Plugin): + """Base class for reconstructions. + + This class should only be used when subclassed for the different algorithms. Provides + x_algorithm, y_algorithm for all peaklets > than min_reconstruction_area based on the top array. + + """ + + __version__ = "0.0.0" + + depends_on = "peaklets" + algorithm: Optional[str] = None + compressor = "zstd" + + min_reconstruction_area = straxen.URLConfig( + help="Skip reconstruction if area (PE) is less than this", + default=0, + infer_type=False, + ) + + n_top_pmts = straxen.URLConfig( + default=straxen.n_top_pmts, infer_type=False, help="Number of top PMTs" + ) + + def infer_dtype(self): + if self.algorithm is None: + raise NotImplementedError( + f"Base class should not be used without algorithm as done in {__class__.__name__}" + ) + dtype = [ + ( + "x_" + self.algorithm, + np.float32, + f"Reconstructed {self.algorithm} S2 X position (cm), uncorrected", + ), + ( + "y_" + self.algorithm, + np.float32, + f"Reconstructed {self.algorithm} S2 Y position (cm), uncorrected", + ), + ] + dtype += strax.time_fields + return dtype + + def get_tf_model(self): + """Simple wrapper to have several tf_model_mlp, tf_model_cnf, .. + + point to this same function in the compute method + + """ + model = getattr(self, f"tf_model_{self.algorithm}", None) + if model is None: + warn( + f"Setting model to None for {self.__class__.__name__} will " + f"set only nans as output for {self.algorithm}" + ) + if isinstance(model, str): + raise ValueError( + f"open files from tf:// protocol! Got {model} " + "instead, see tests/test_posrec.py for examples." + ) + return model + + def compute(self, peaklets): + result = np.ones(len(peaklets), dtype=self.dtype) + result["time"], result["endtime"] = peaklets["time"], strax.endtime(peaklets) + + result["x_" + self.algorithm] *= np.nan + result["y_" + self.algorithm] *= np.nan + model = self.get_tf_model() + + if model is None: + # This plugin is disabled since no model is provided + return result + + # Keep large peaklets only + peak_mask = peaklets["area"] > self.min_reconstruction_area + if not np.sum(peak_mask): + # Nothing to do, and .predict crashes on empty arrays + return result + + # Getting actual position reconstruction + area_per_channel_top = peaklets["area_per_channel"][peak_mask, 0 : self.n_top_pmts] + with np.errstate(divide="ignore", invalid="ignore"): + area_per_channel_top = area_per_channel_top / np.max( + area_per_channel_top, axis=1 + ).reshape(-1, 1) + area_per_channel_top = area_per_channel_top.reshape(-1, self.n_top_pmts) + output = model.predict(area_per_channel_top, verbose=0) + + # writing output to the result + result["x_" + self.algorithm][peak_mask] = output[:, 0] + result["y_" + self.algorithm][peak_mask] = output[:, 1] + return result diff --git a/straxen/plugins/peaklets/peaklet_classification.py b/straxen/plugins/peaklets/peaklet_classification.py index 994215bdd..6385c0a2b 100644 --- a/straxen/plugins/peaklets/peaklet_classification.py +++ b/straxen/plugins/peaklets/peaklet_classification.py @@ -16,7 +16,6 @@ class PeakletClassification(strax.Plugin): provides: Union[str, tuple] = "peaklet_classification" depends_on = "peaklets" - parallel = True dtype = strax.peak_interval_dtype + [("type", np.int8, "Classification of the peak(let)")] s1_risetime_area_parameters = straxen.URLConfig( diff --git a/straxen/plugins/peaklets/peaklet_positions.py b/straxen/plugins/peaklets/peaklet_positions.py new file mode 100644 index 000000000..ebb3b0293 --- /dev/null +++ b/straxen/plugins/peaklets/peaklet_positions.py @@ -0,0 +1,19 @@ +import strax +from straxen.plugins.peaks._peak_positions_base import PeakPositionsBase + +export, __all__ = strax.exporter() + + +@export +class PeakletPositions(PeakPositionsBase): + + __version__ = "0.0.0" + child_plugin = True + provides = "peaklet_positions" + depends_on = ( + "peaklet_positions_cnf", + "peaklet_positions_mlp", + ) + + def compute(self, peaklets): + return super().compute(peaklets) diff --git a/straxen/plugins/peaklets/peaklet_positions_cnf.py b/straxen/plugins/peaklets/peaklet_positions_cnf.py new file mode 100644 index 000000000..d4cef16a0 --- /dev/null +++ b/straxen/plugins/peaklets/peaklet_positions_cnf.py @@ -0,0 +1,242 @@ +import numpy as np +import strax +import straxen +from ._peaklet_positions_base import PeakletPositionsBase + +export, __all__ = strax.exporter() + + +@export +class PeakletPositionsCNF(PeakletPositionsBase): + """Conditional Normalizing Flow for position reconstruction. + + This plugin reconstructs the position of peaklets using a conditional normalizing flow model. + It provides x and y coordinates of the reconstructed position, along with uncertainty contours + and uncertainty estimates in r and theta. For information on the model, see note_. + + .. _note: https://xe1t-wiki.lngs.infn.it/doku.php?id=xenon:xenonnt:juehang:flow_posrec_proposal_sr2 # noqa: E501 + + Configuration options: + - min_reconstruction_area: Minimum area (PE) required for reconstruction + - n_poly: Size of the uncertainty contour + - sig: Confidence level of the contour + - log_area_scale: Scaling parameter for log area + - n_top_pmts: Number of top PMTs + - pred_function: Path to the compiled JAX function for predictions + + """ + + __version__ = "0.0.0" + child_plugin = True + algorithm = "cnf" + provides = "peaklet_positions_cnf" + + n_poly = straxen.URLConfig( + default=16, + infer_type=False, + help="Size of uncertainty contour", + ) + + N_chunk_max = straxen.URLConfig( + default=4096, + infer_type=False, + help="Maximum size of chunk for vectorised JAX function", + ) + + sig = straxen.URLConfig( + default=0.393, + infer_type=False, + help="Confidence level of contour", + ) + + log_area_scale = straxen.URLConfig( + default=10, + infer_type=False, + help="Scaling parameter for log area", + ) + + pred_function = straxen.URLConfig( + default=( + "jax://resource://flow_20240730.tar.gz?" + "n_poly=plugin.n_poly&sig=plugin.sig&fmt=abs_path" + ), + help="Compiled JAX function", + ) + + @staticmethod + def calculate_theta_diff(theta_array, avg_theta): + """Calculate the difference between maximum and minimum angles from an array of angles by + normalizing the angular difference into the range [0, 2π). + + Parameters: + theta_array : np.ndarray + A 2D numpy array where each row represents a set of angles in radians. + avg_theta : np.ndarray + A 1D numpy array representing the average angle in radians for each + row in `theta_array`. + + Returns: + theta_diff : np.ndarray + A 1D numpy array with the difference between the maximum and minimum angles in radians + for each row in `theta_array` + + """ + # Correction to handle circular nature of angles + theta_array_shift = (theta_array - avg_theta[..., np.newaxis] + np.pi) % (2 * np.pi) + theta_min = np.min(theta_array_shift, axis=1) + theta_max = np.max(theta_array_shift, axis=1) + theta_diff = theta_max - theta_min + + return theta_diff + + def infer_dtype(self): + """Define the data type for the output. + + Returns: + dtype: Numpy dtype for the output array + + """ + dtype = [ + ( + ( + f"Reconstructed {self.algorithm} S2 X position (cm), uncorrected", + f"x_{self.algorithm}", + ), + np.float32, + ), + ( + ( + f"Reconstructed {self.algorithm} S2 Y position (cm), uncorrected", + f"y_{self.algorithm}", + ), + np.float32, + ), + ( + ("Position uncertainty contour (cm)", f"position_contour_{self.algorithm}"), + np.float32, + (self.n_poly + 1, 2), + ), + (("Position uncertainty in r (cm)", f"r_uncertainty_{self.algorithm}"), np.float32), + ( + ("Position uncertainty in theta (rad)", f"theta_uncertainty_{self.algorithm}"), + np.float32, + ), + ] + dtype += strax.time_fields + return dtype + + def vectorized_prediction_chunk(self, flow_condition): + """Compute predictions for a chunk of data. + + Args: + flow_condition: Input data for the flow model + + Returns: + xy: Predicted x and y coordinates + contour: Uncertainty contours + + """ + N_entries = flow_condition.shape[0] + if N_entries > self.N_chunk_max: + raise ValueError("Chunk greater than max size") + else: + inputs = np.zeros((self.N_chunk_max, self.n_top_pmts + 1)) + inputs[:N_entries] = flow_condition + xy, contour = self.pred_function(inputs) + return xy[:N_entries], contour[:N_entries] + + def prediction_loop(self, flow_condition): + """Compute predictions for arbitrary-size inputs using a loop. + + Args: + flow_condition: Input data for the flow model + + Returns: + xy: Predicted x and y coordinates + contour: Uncertainty contours + + """ + N_entries = flow_condition.shape[0] + if N_entries <= self.N_chunk_max: + return self.vectorized_prediction_chunk(flow_condition) + N_chunks = N_entries // self.N_chunk_max + + xy_list = [] + contour_list = [] + for i in range(N_chunks): + xy, contour = self.vectorized_prediction_chunk( + flow_condition[i * self.N_chunk_max : (i + 1) * self.N_chunk_max] + ) + xy_list.append(xy) + contour_list.append(contour) + + if N_chunks * self.N_chunk_max < N_entries: + xy, contour = self.vectorized_prediction_chunk( + flow_condition[(i + 1) * self.N_chunk_max :] + ) + xy_list.append(xy) + contour_list.append(contour) + return np.concatenate(xy_list, axis=0), np.concatenate(contour_list, axis=0) + + def compute(self, peaklets): + """Compute the position reconstruction for the given peaklets. + + Args: + peaklets: Input peaklet data + + Returns: + result: Array with reconstructed positions and uncertainties + + """ + # Initialize result array + result = np.ones(len(peaklets), dtype=self.dtype) + result["time"], result["endtime"] = peaklets["time"], strax.endtime(peaklets) + + # Set default values to NaN + result[f"x_{self.algorithm}"] *= np.nan + result[f"y_{self.algorithm}"] *= np.nan + result[f"position_contour_{self.algorithm}"] *= np.nan + result[f"r_uncertainty_{self.algorithm}"] *= np.nan + result[f"theta_uncertainty_{self.algorithm}"] *= np.nan + + # Keep large peaklets only + peaklet_mask = peaklets["area"] > self.min_reconstruction_area + if not np.sum(peaklet_mask): + # Nothing to do, and .predict crashes on empty arrays + return result + + # Prepare input data for the flow model + area_per_channel_top = peaklets["area_per_channel"][peaklet_mask, 0 : self.n_top_pmts] + total_top_areas = np.sum(area_per_channel_top, axis=1) + with np.errstate(divide="ignore", invalid="ignore"): + flow_data = np.concatenate( + [ + area_per_channel_top / total_top_areas[..., np.newaxis], + np.log(total_top_areas[..., np.newaxis]) / self.log_area_scale, + ], + axis=1, + ) + + # Get position reconstruction + xy, contours = self.prediction_loop(flow_data) + + # Write output to the result array + result[f"x_{self.algorithm}"][peaklet_mask] = xy[:, 0] + result[f"y_{self.algorithm}"][peaklet_mask] = xy[:, 1] + result[f"position_contour_{self.algorithm}"][peaklet_mask] = contours + + # Calculate uncertainties in r and theta + r_array = np.linalg.norm(contours, axis=2) + r_min = np.min(r_array, axis=1) + r_max = np.max(r_array, axis=1) + + theta_array = np.arctan2(contours[..., 1], contours[..., 0]) + + avg_theta = np.arctan2(xy[:, 1], xy[:, 0]) + + theta_diff = self.calculate_theta_diff(theta_array, avg_theta) + + result[f"r_uncertainty_{self.algorithm}"][peaklet_mask] = (r_max - r_min) / 2 + result[f"theta_uncertainty_{self.algorithm}"][peaklet_mask] = np.abs(theta_diff) / 2 + + return result diff --git a/straxen/plugins/peaklets/peaklet_positions_mlp.py b/straxen/plugins/peaklets/peaklet_positions_mlp.py new file mode 100644 index 000000000..b4b03b526 --- /dev/null +++ b/straxen/plugins/peaklets/peaklet_positions_mlp.py @@ -0,0 +1,33 @@ +import strax +import straxen +from ._peaklet_positions_base import PeakletPositionsBase + + +export, __all__ = strax.exporter() + + +@export +class PeakletPositionsMLP(PeakletPositionsBase): + """Multilayer Perceptron (MLP) neural net for position reconstruction.""" + + __version__ = "0.0.0" + child_plugin = True + algorithm = "mlp" + provides = "peaklet_positions_mlp" + gc_collect_after_compute = True + + tf_model_mlp = straxen.URLConfig( + default=( + "tf://" + "resource://" + f"cmt://{algorithm}_model" + "?version=ONLINE" + "&run_id=plugin.run_id" + "&fmt=abs_path" + ), + help=( + 'MLP model. Should be opened using the "tf" descriptor. ' + 'Set to "None" to skip computation' + ), + cache=3, + ) diff --git a/straxen/plugins/peaks/__init__.py b/straxen/plugins/peaks/__init__.py index 7bca4104e..217a7dc8d 100644 --- a/straxen/plugins/peaks/__init__.py +++ b/straxen/plugins/peaks/__init__.py @@ -13,15 +13,15 @@ from . import peak_s1_positions_cnn from .peak_s1_positions_cnn import * -from . import peak_positions -from .peak_positions import * - from . import peak_positions_cnf from .peak_positions_cnf import * from . import peak_positions_mlp from .peak_positions_mlp import * +from . import peak_positions +from .peak_positions import * + from . import peak_proximity from .peak_proximity import * diff --git a/straxen/plugins/peaks/_peak_positions_base.py b/straxen/plugins/peaks/_peak_positions_base.py index 84335dae5..14c036bb4 100644 --- a/straxen/plugins/peaks/_peak_positions_base.py +++ b/straxen/plugins/peaks/_peak_positions_base.py @@ -1,110 +1,21 @@ -"""Position reconstruction for Xenon-nT.""" - -from typing import Optional - -import numpy as np import strax -import straxen -from warnings import warn +from .peaks import Peaks export, __all__ = strax.exporter() -DEFAULT_POSREC_ALGO = "cnf" - @export -class PeakPositionsBaseNT(strax.Plugin): - """Base class for reconstructions. - - This class should only be used when subclassed for the different algorithms. Provides - x_algorithm, y_algorithm for all peaks > than min-reconstruction area based on the top array. - - """ +class PeakPositionsBase(Peaks): __version__ = "0.0.0" - - depends_on = "peaks" - algorithm: Optional[str] = None - compressor = "zstd" - parallel = True - - min_reconstruction_area = straxen.URLConfig( - help="Skip reconstruction if area (PE) is less than this", - default=10, - infer_type=False, - ) - - n_top_pmts = straxen.URLConfig( - default=straxen.n_top_pmts, infer_type=False, help="Number of top PMTs" - ) + child_plugin = True + save_when = strax.SaveWhen.ALWAYS def infer_dtype(self): - if self.algorithm is None: - raise NotImplementedError( - f"Base class should not be used without algorithm as done in {__class__.__name__}" - ) - dtype = [ - ( - "x_" + self.algorithm, - np.float32, - f"Reconstructed {self.algorithm} S2 X position (cm), uncorrected", - ), - ( - "y_" + self.algorithm, - np.float32, - f"Reconstructed {self.algorithm} S2 Y position (cm), uncorrected", - ), - ] - dtype += strax.time_fields - return dtype - - def get_tf_model(self): - """Simple wrapper to have several tf_model_mlp, tf_model_cnf, .. - - point to this same function in the compute method - - """ - model = getattr(self, f"tf_model_{self.algorithm}", None) - if model is None: - warn( - f"Setting model to None for {self.__class__.__name__} will " - f"set only nans as output for {self.algorithm}" - ) - if isinstance(model, str): - raise ValueError( - f"open files from tf:// protocol! Got {model} " - "instead, see tests/test_posrec.py for examples." - ) - return model - - def compute(self, peaks): - result = np.ones(len(peaks), dtype=self.dtype) - result["time"], result["endtime"] = peaks["time"], strax.endtime(peaks) - - result["x_" + self.algorithm] *= np.nan - result["y_" + self.algorithm] *= np.nan - model = self.get_tf_model() - - if model is None: - # This plugin is disabled since no model is provided - return result - - # Keep large peaks only - peak_mask = peaks["area"] > self.min_reconstruction_area - if not np.sum(peak_mask): - # Nothing to do, and .predict crashes on empty arrays - return result - - # Getting actual position reconstruction - area_per_channel_top = peaks["area_per_channel"][peak_mask, 0 : self.n_top_pmts] - with np.errstate(divide="ignore", invalid="ignore"): - area_per_channel_top = area_per_channel_top / np.max( - area_per_channel_top, axis=1 - ).reshape(-1, 1) - area_per_channel_top = area_per_channel_top.reshape(-1, self.n_top_pmts) - output = model.predict(area_per_channel_top, verbose=0) + return self.deps[f"peaklet_positions_{self.algorithm}"].dtype_for( + f"peaklet_positions_{self.algorithm}" + ) - # writing output to the result - result["x_" + self.algorithm][peak_mask] = output[:, 0] - result["y_" + self.algorithm][peak_mask] = output[:, 1] - return result + def compute(self, peaklets, merged_s2s): + _merged_s2s = strax.merge_arrs([merged_s2s], dtype=peaklets.dtype, replacing=True) + return super().compute(peaklets, _merged_s2s) diff --git a/straxen/plugins/peaks/_peak_s1_positions_base.py b/straxen/plugins/peaks/_peak_s1_positions_base.py index 668ed4ea0..486629a3b 100644 --- a/straxen/plugins/peaks/_peak_s1_positions_base.py +++ b/straxen/plugins/peaks/_peak_s1_positions_base.py @@ -25,7 +25,6 @@ class PeakS1PositionBase(strax.Plugin): algorithm: Optional[str] = None compressor = "zstd" - parallel = True min_s1_area_s1_posrec = straxen.URLConfig( help="Skip reconstruction if area (PE) is less than this", default=1000, infer_type=False diff --git a/straxen/plugins/peaks/peak_basics.py b/straxen/plugins/peaks/peak_basics.py index 8f604c311..00cd72ba1 100644 --- a/straxen/plugins/peaks/peak_basics.py +++ b/straxen/plugins/peaks/peak_basics.py @@ -16,7 +16,6 @@ class PeakBasics(strax.Plugin): """ __version__ = "0.1.4" - parallel = True depends_on = "peaks" provides = "peak_basics" diff --git a/straxen/plugins/peaks/peak_positions.py b/straxen/plugins/peaks/peak_positions.py index 7407b3710..1cab2b6be 100644 --- a/straxen/plugins/peaks/peak_positions.py +++ b/straxen/plugins/peaks/peak_positions.py @@ -8,7 +8,7 @@ @export -class PeakPositionsNT(strax.MergeOnlyPlugin): +class PeakPositions(strax.MergeOnlyPlugin): """Merge the reconstructed algorithms of the different algorithms into a single one that can be used in Event Basics. @@ -21,13 +21,13 @@ class PeakPositionsNT(strax.MergeOnlyPlugin): """ + __version__ = "0.0.0" provides = "peak_positions" depends_on = ( - "peak_positions_mlp", "peak_positions_cnf", + "peak_positions_mlp", ) save_when = strax.SaveWhen.NEVER - __version__ = "0.0.0" default_reconstruction_algorithm = straxen.URLConfig( default=DEFAULT_POSREC_ALGO, help="default reconstruction algorithm that provides (x,y)" diff --git a/straxen/plugins/peaks/peak_positions_cnf.py b/straxen/plugins/peaks/peak_positions_cnf.py index 8f3a73038..c01c91597 100644 --- a/straxen/plugins/peaks/peak_positions_cnf.py +++ b/straxen/plugins/peaks/peak_positions_cnf.py @@ -1,247 +1,19 @@ -import numpy as np import strax -import straxen -from straxen.plugins.peaks._peak_positions_base import PeakPositionsBaseNT +from ._peak_positions_base import PeakPositionsBase export, __all__ = strax.exporter() @export -class PeakPositionsCNF(PeakPositionsBaseNT): - """Conditional Normalizing Flow for position reconstruction. +class PeakPositionsCNF(PeakPositionsBase): - This plugin reconstructs the position of S2 peaks using a conditional normalizing flow model. - It provides x and y coordinates of the reconstructed position, along with uncertainty contours - and uncertainty estimates in r and theta. For information on the model, see note_. - - .. _note: https://xe1t-wiki.lngs.infn.it/doku.php?id=xenon:xenonnt:juehang:flow_posrec_proposal_sr2 # noqa: E501 - - Depends on: 'peaks' - Provides: 'peak_positions_cnf' - - Configuration options: - - min_reconstruction_area: Minimum area (PE) required for reconstruction - - n_poly: Size of the uncertainty contour - - sig: Confidence level of the contour - - log_area_scale: Scaling parameter for log area - - n_top_pmts: Number of top PMTs - - pred_function: Path to the compiled JAX function for predictions - - """ - - __version__ = "0.0.4" - depends_on = "peaks" - provides = "peak_positions_cnf" + __version__ = "0.0.0" + child_plugin = True algorithm = "cnf" - compressor = "zstd" - parallel = True - - n_poly = straxen.URLConfig( - default=16, - infer_type=False, - help="Size of uncertainty contour", - ) - - N_chunk_max = straxen.URLConfig( - default=4096, - infer_type=False, - help="Maximum size of chunk for vectorised JAX function", - ) - - sig = straxen.URLConfig( - default=0.393, - infer_type=False, - help="Confidence level of contour", - ) - - log_area_scale = straxen.URLConfig( - default=10, - infer_type=False, - help="Scaling parameter for log area", + depends_on = ( + "peaklet_positions_cnf", + "peaklet_classification", + "merged_s2s", + "merged_s2_positions_cnf", ) - - pred_function = straxen.URLConfig( - default=( - "jax://resource://flow_20240730.tar.gz?" - "n_poly=plugin.n_poly&sig=plugin.sig&fmt=abs_path" - ), - help="Compiled JAX function", - ) - - @staticmethod - def calculate_theta_diff(theta_array, avg_theta): - """Calculate the difference between maximum and minimum angles from an array of angles by - normalizing the angular difference into the range [0, 2π). - - Parameters: - theta_array : np.ndarray - A 2D numpy array where each row represents a set of angles in radians. - avg_theta : np.ndarray - A 1D numpy array representing the average angle in radians for each - row in `theta_array`. - - Returns: - theta_diff : np.ndarray - A 1D numpy array with the difference between the maximum and minimum angles in radians - for each row in `theta_array` - - """ - # Correction to handle circular nature of angles - theta_array_shift = (theta_array - avg_theta[..., np.newaxis] + np.pi) % (2 * np.pi) - theta_min = np.min(theta_array_shift, axis=1) - theta_max = np.max(theta_array_shift, axis=1) - theta_diff = theta_max - theta_min - - return theta_diff - - def infer_dtype(self): - """Define the data type for the output. - - Returns: - dtype: Numpy dtype for the output array - - """ - dtype = [ - ( - ( - f"Reconstructed {self.algorithm} S2 X position (cm), uncorrected", - f"x_{self.algorithm}", - ), - np.float32, - ), - ( - ( - f"Reconstructed {self.algorithm} S2 Y position (cm), uncorrected", - f"y_{self.algorithm}", - ), - np.float32, - ), - ( - ("Position uncertainty contour (cm)", f"position_contour_{self.algorithm}"), - np.float32, - (self.n_poly + 1, 2), - ), - (("Position uncertainty in r (cm)", f"r_uncertainty_{self.algorithm}"), np.float32), - ( - ("Position uncertainty in theta (rad)", f"theta_uncertainty_{self.algorithm}"), - np.float32, - ), - ] - dtype += strax.time_fields - return dtype - - def vectorized_prediction_chunk(self, flow_condition): - """Compute predictions for a chunk of data. - - Args: - flow_condition: Input data for the flow model - - Returns: - xy: Predicted x and y coordinates - contour: Uncertainty contours - - """ - N_entries = flow_condition.shape[0] - if N_entries > self.N_chunk_max: - raise ValueError("Chunk greater than max size") - else: - inputs = np.zeros((self.N_chunk_max, self.n_top_pmts + 1)) - inputs[:N_entries] = flow_condition - xy, contour = self.pred_function(inputs) - return xy[:N_entries], contour[:N_entries] - - def prediction_loop(self, flow_condition): - """Compute predictions for arbitrary-size inputs using a loop. - - Args: - flow_condition: Input data for the flow model - - Returns: - xy: Predicted x and y coordinates - contour: Uncertainty contours - - """ - N_entries = flow_condition.shape[0] - if N_entries <= self.N_chunk_max: - return self.vectorized_prediction_chunk(flow_condition) - N_chunks = N_entries // self.N_chunk_max - - xy_list = [] - contour_list = [] - for i in range(N_chunks): - xy, contour = self.vectorized_prediction_chunk( - flow_condition[i * self.N_chunk_max : (i + 1) * self.N_chunk_max] - ) - xy_list.append(xy) - contour_list.append(contour) - - if N_chunks * self.N_chunk_max < N_entries: - xy, contour = self.vectorized_prediction_chunk( - flow_condition[(i + 1) * self.N_chunk_max :] - ) - xy_list.append(xy) - contour_list.append(contour) - return np.concatenate(xy_list, axis=0), np.concatenate(contour_list, axis=0) - - def compute(self, peaks): - """Compute the position reconstruction for the given peaks. - - Args: - peaks: Input peak data - - Returns: - result: Array with reconstructed positions and uncertainties - - """ - # Initialize result array - result = np.ones(len(peaks), dtype=self.dtype) - result["time"], result["endtime"] = peaks["time"], strax.endtime(peaks) - - # Set default values to NaN - result[f"x_{self.algorithm}"] *= np.nan - result[f"y_{self.algorithm}"] *= np.nan - result[f"position_contour_{self.algorithm}"] *= np.nan - result[f"r_uncertainty_{self.algorithm}"] *= np.nan - result[f"theta_uncertainty_{self.algorithm}"] *= np.nan - - # Keep large peaks only - peak_mask = peaks["area"] > self.min_reconstruction_area - if not np.sum(peak_mask): - # Nothing to do, and .predict crashes on empty arrays - return result - - # Prepare input data for the flow model - area_per_channel_top = peaks["area_per_channel"][peak_mask, 0 : self.n_top_pmts] - total_top_areas = np.sum(area_per_channel_top, axis=1) - with np.errstate(divide="ignore", invalid="ignore"): - flow_data = np.concatenate( - [ - area_per_channel_top / total_top_areas[..., np.newaxis], - np.log(total_top_areas[..., np.newaxis]) / self.log_area_scale, - ], - axis=1, - ) - - # Get position reconstruction - xy, contours = self.prediction_loop(flow_data) - - # Write output to the result array - result[f"x_{self.algorithm}"][peak_mask] = xy[:, 0] - result[f"y_{self.algorithm}"][peak_mask] = xy[:, 1] - result[f"position_contour_{self.algorithm}"][peak_mask] = contours - - # Calculate uncertainties in r and theta - r_array = np.linalg.norm(contours, axis=2) - r_min = np.min(r_array, axis=1) - r_max = np.max(r_array, axis=1) - - theta_array = np.arctan2(contours[..., 1], contours[..., 0]) - - avg_theta = np.arctan2(xy[:, 1], xy[:, 0]) - - theta_diff = self.calculate_theta_diff(theta_array, avg_theta) - - result[f"r_uncertainty_{self.algorithm}"][peak_mask] = (r_max - r_min) / 2 - result[f"theta_uncertainty_{self.algorithm}"][peak_mask] = np.abs(theta_diff) / 2 - - return result + provides = "peak_positions_cnf" diff --git a/straxen/plugins/peaks/peak_positions_mlp.py b/straxen/plugins/peaks/peak_positions_mlp.py index 8cbcd9204..83c28f9f0 100644 --- a/straxen/plugins/peaks/peak_positions_mlp.py +++ b/straxen/plugins/peaks/peak_positions_mlp.py @@ -1,31 +1,19 @@ import strax -import straxen -from straxen.plugins.peaks._peak_positions_base import PeakPositionsBaseNT - +from ._peak_positions_base import PeakPositionsBase export, __all__ = strax.exporter() @export -class PeakPositionsMLP(PeakPositionsBaseNT): - """Multilayer Perceptron (MLP) neural net for position reconstruction.""" +class PeakPositionsMLP(PeakPositionsBase): - provides = "peak_positions_mlp" + __version__ = "0.0.0" + child_plugin = True algorithm = "mlp" - gc_collect_after_compute = True - - tf_model_mlp = straxen.URLConfig( - default=( - "tf://" - "resource://" - f"cmt://{algorithm}_model" - "?version=ONLINE" - "&run_id=plugin.run_id" - "&fmt=abs_path" - ), - help=( - 'MLP model. Should be opened using the "tf" descriptor. ' - 'Set to "None" to skip computation' - ), - cache=3, + depends_on = ( + "peaklet_positions_mlp", + "peaklet_classification", + "merged_s2s", + "merged_s2_positions_mlp", ) + provides = "peak_positions_mlp" diff --git a/straxen/plugins/peaks/peaks.py b/straxen/plugins/peaks/peaks.py index 9e5d1e815..e3580bc9e 100644 --- a/straxen/plugins/peaks/peaks.py +++ b/straxen/plugins/peaks/peaks.py @@ -1,3 +1,4 @@ +from typing import Tuple, Union import numpy as np import strax import straxen @@ -18,10 +19,9 @@ class Peaks(strax.Plugin): __version__ = "0.1.2" - depends_on = ("peaklets", "peaklet_classification", "merged_s2s") + depends_on: Union[Tuple[str, ...], str] = ("peaklets", "peaklet_classification", "merged_s2s") data_kind = "peaks" provides = "peaks" - parallel = True compressor = "zstd" save_when = strax.SaveWhen.EXPLICIT @@ -66,6 +66,6 @@ def compute(self, peaklets, merged_s2s): peaks["time"][to_check][1:] >= strax.endtime(peaks)[to_check][:-1] ), "Peaks not disjoint" - result = np.zeros(len(peaks), self.dtype) - strax.copy_to_buffer(peaks, result, "_copy_requested_peak_fields") + result = np.zeros(len(peaks), dtype=self.dtype) + strax.copy_to_buffer(peaks, result, f"_copy_requested_{self.provides[0]}_fields") return result diff --git a/tests/plugins/posrec_plugins.py b/tests/plugins/posrec_plugins.py index 90e0beb4d..9861b9756 100644 --- a/tests/plugins/posrec_plugins.py +++ b/tests/plugins/posrec_plugins.py @@ -1,16 +1,16 @@ """Run with python tests/plugins/posrec_processing.py.""" import os +import numpy as np import strax -import straxen from _core import PluginTestAccumulator, run_pytest_from_main -import numpy as np +from straxen.plugins.peaklets._peaklet_positions_base import PeakletPositionsBase @PluginTestAccumulator.register("test_posrec_set_path") def test_posrec_set_path( self, - target="peak_positions_mlp", + target="peaklet_positions_mlp", config_name="tf_model_mlp", field="x_mlp", ): @@ -42,7 +42,7 @@ def test_posrec_set_path( @PluginTestAccumulator.register("test_posrec_set_to_none") def test_posrec_set_to_none( self, - target="peak_positions_mlp", + target="peaklet_positions_mlp", config_name="tf_model_mlp", field="x_mlp", ): @@ -56,7 +56,7 @@ def test_posrec_set_to_none( @PluginTestAccumulator.register("test_posrec_bad_configs_raising_errors") def test_posrec_bad_configs_raising_errors( self, - target="peak_positions_mlp", + target="peaklet_positions_mlp", config_name="tf_model_mlp", ): """Test that we get the right errors when we set invalid options.""" @@ -73,8 +73,8 @@ def test_posrec_bad_configs_raising_errors( with self.assertRaises(FileNotFoundError): plugin.get_tf_model() - dummy_st.register(straxen.plugins.peak_positions_cnf.PeakPositionsBaseNT) - plugin_name = strax.camel_to_snake("PeakPositionsBaseNT") + dummy_st.register(PeakletPositionsBase) + plugin_name = strax.camel_to_snake(PeakletPositionsBase.__name__) with self.assertRaises(NotImplementedError): dummy_st.get_single_plugin(self.run_id, plugin_name) diff --git a/tests/plugins/s1_posrec_plugins.py b/tests/plugins/s1_posrec_plugins.py index dceb33877..fca6ffbfb 100644 --- a/tests/plugins/s1_posrec_plugins.py +++ b/tests/plugins/s1_posrec_plugins.py @@ -1,9 +1,9 @@ # Run with python tests/plugins/s1_posrec_plugins.py.py import os +import numpy as np import strax -import straxen from _core import PluginTestAccumulator, run_pytest_from_main -import numpy as np +from straxen.plugins.peaklets._peaklet_positions_base import PeakletPositionsBase @PluginTestAccumulator.register("test_posrec_set_path") @@ -76,8 +76,8 @@ def test_posrec_bad_configs_raising_errors( with self.assertRaises(FileNotFoundError): plugin.get_tf_model() - dummy_st.register(straxen.plugins.peak_positions_cnn.PeakPositionsBaseNT) - plugin_name = strax.camel_to_snake("PeakPositionsBaseNT") + dummy_st.register(PeakletPositionsBase) + plugin_name = strax.camel_to_snake(PeakletPositionsBase.__name__) with self.assertRaises(NotImplementedError): dummy_st.get_single_plugin(self.run_id, plugin_name)