-
Notifications
You must be signed in to change notification settings - Fork 33
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Prototype of peaklets-level (x, y) S2 position reconstruction (#1482)
* Prototype of peaklets-level pos-rec * Simplify dependency tree * Add TODO * Pos-rec for merged_s2s * Recover all possible data_types * No 'NT' * Debug * Restructure the dependency tree with straxen style * Still keep CNN and GCN * Debug * Debug * Debug * Debug * Renaming * Debug * Remove CNN and GCN * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * A minor change --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
- Loading branch information
1 parent
081af5d
commit 42908b0
Showing
25 changed files
with
520 additions
and
389 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,8 @@ | ||
from . import merged_s2s | ||
from .merged_s2s import * | ||
|
||
from . import merged_s2_positions_cnf | ||
from .merged_s2_positions_cnf import * | ||
|
||
from . import merged_s2_positions_mlp | ||
from .merged_s2_positions_mlp import * |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
import strax | ||
from straxen.plugins.peaks._peak_positions_base import PeakPositionsBase | ||
|
||
export, __all__ = strax.exporter() | ||
|
||
|
||
@export | ||
class MergedS2Positions(PeakPositionsBase): | ||
|
||
__version__ = "0.0.0" | ||
child_plugin = True | ||
provides = "merged_s2_positions" | ||
depends_on = ( | ||
"merged_s2_positions_mlp", | ||
"merged_s2_positions_cnf", | ||
) | ||
|
||
def compute(self, merged_s2s): | ||
return super().compute(merged_s2s) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
import strax | ||
from straxen.plugins.peaklets.peaklet_positions_cnf import PeakletPositionsCNF | ||
|
||
export, __all__ = strax.exporter() | ||
|
||
|
||
@export | ||
class MergedS2PositionsCNF(PeakletPositionsCNF): | ||
|
||
__version__ = "0.0.0" | ||
child_plugin = True | ||
algorithm = "cnf" | ||
depends_on = "merged_s2s" | ||
provides = "merged_s2_positions_cnf" | ||
|
||
def compute(self, merged_s2s): | ||
return super().compute(merged_s2s) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
import strax | ||
from straxen.plugins.peaklets.peaklet_positions_mlp import PeakletPositionsMLP | ||
|
||
export, __all__ = strax.exporter() | ||
|
||
|
||
@export | ||
class MergedS2PositionsMLP(PeakletPositionsMLP): | ||
|
||
__version__ = "0.0.0" | ||
child_plugin = True | ||
algorithm = "mlp" | ||
depends_on = "merged_s2s" | ||
provides = "merged_s2_positions_mlp" | ||
|
||
def compute(self, merged_s2s): | ||
return super().compute(merged_s2s) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
from typing import Optional | ||
from warnings import warn | ||
|
||
import numpy as np | ||
import strax | ||
import straxen | ||
|
||
export, __all__ = strax.exporter() | ||
|
||
|
||
@export | ||
class PeakletPositionsBase(strax.Plugin): | ||
"""Base class for reconstructions. | ||
This class should only be used when subclassed for the different algorithms. Provides | ||
x_algorithm, y_algorithm for all peaklets > than min_reconstruction_area based on the top array. | ||
""" | ||
|
||
__version__ = "0.0.0" | ||
|
||
depends_on = "peaklets" | ||
algorithm: Optional[str] = None | ||
compressor = "zstd" | ||
|
||
min_reconstruction_area = straxen.URLConfig( | ||
help="Skip reconstruction if area (PE) is less than this", | ||
default=0, | ||
infer_type=False, | ||
) | ||
|
||
n_top_pmts = straxen.URLConfig( | ||
default=straxen.n_top_pmts, infer_type=False, help="Number of top PMTs" | ||
) | ||
|
||
def infer_dtype(self): | ||
if self.algorithm is None: | ||
raise NotImplementedError( | ||
f"Base class should not be used without algorithm as done in {__class__.__name__}" | ||
) | ||
dtype = [ | ||
( | ||
"x_" + self.algorithm, | ||
np.float32, | ||
f"Reconstructed {self.algorithm} S2 X position (cm), uncorrected", | ||
), | ||
( | ||
"y_" + self.algorithm, | ||
np.float32, | ||
f"Reconstructed {self.algorithm} S2 Y position (cm), uncorrected", | ||
), | ||
] | ||
dtype += strax.time_fields | ||
return dtype | ||
|
||
def get_tf_model(self): | ||
"""Simple wrapper to have several tf_model_mlp, tf_model_cnf, .. | ||
point to this same function in the compute method | ||
""" | ||
model = getattr(self, f"tf_model_{self.algorithm}", None) | ||
if model is None: | ||
warn( | ||
f"Setting model to None for {self.__class__.__name__} will " | ||
f"set only nans as output for {self.algorithm}" | ||
) | ||
if isinstance(model, str): | ||
raise ValueError( | ||
f"open files from tf:// protocol! Got {model} " | ||
"instead, see tests/test_posrec.py for examples." | ||
) | ||
return model | ||
|
||
def compute(self, peaklets): | ||
result = np.ones(len(peaklets), dtype=self.dtype) | ||
result["time"], result["endtime"] = peaklets["time"], strax.endtime(peaklets) | ||
|
||
result["x_" + self.algorithm] *= np.nan | ||
result["y_" + self.algorithm] *= np.nan | ||
model = self.get_tf_model() | ||
|
||
if model is None: | ||
# This plugin is disabled since no model is provided | ||
return result | ||
|
||
# Keep large peaklets only | ||
peak_mask = peaklets["area"] > self.min_reconstruction_area | ||
if not np.sum(peak_mask): | ||
# Nothing to do, and .predict crashes on empty arrays | ||
return result | ||
|
||
# Getting actual position reconstruction | ||
area_per_channel_top = peaklets["area_per_channel"][peak_mask, 0 : self.n_top_pmts] | ||
with np.errstate(divide="ignore", invalid="ignore"): | ||
area_per_channel_top = area_per_channel_top / np.max( | ||
area_per_channel_top, axis=1 | ||
).reshape(-1, 1) | ||
area_per_channel_top = area_per_channel_top.reshape(-1, self.n_top_pmts) | ||
output = model.predict(area_per_channel_top, verbose=0) | ||
|
||
# writing output to the result | ||
result["x_" + self.algorithm][peak_mask] = output[:, 0] | ||
result["y_" + self.algorithm][peak_mask] = output[:, 1] | ||
return result |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
import strax | ||
from straxen.plugins.peaks._peak_positions_base import PeakPositionsBase | ||
|
||
export, __all__ = strax.exporter() | ||
|
||
|
||
@export | ||
class PeakletPositions(PeakPositionsBase): | ||
|
||
__version__ = "0.0.0" | ||
child_plugin = True | ||
provides = "peaklet_positions" | ||
depends_on = ( | ||
"peaklet_positions_cnf", | ||
"peaklet_positions_mlp", | ||
) | ||
|
||
def compute(self, peaklets): | ||
return super().compute(peaklets) |
Oops, something went wrong.