Skip to content

Commit

Permalink
Prototype of peaklets-level (x, y) S2 position reconstruction (#1482)
Browse files Browse the repository at this point in the history
* Prototype of peaklets-level pos-rec

* Simplify dependency tree

* Add TODO

* Pos-rec for merged_s2s

* Recover all possible data_types

* No 'NT'

* Debug

* Restructure the dependency tree with straxen style

* Still keep CNN and GCN

* Debug

* Debug

* Debug

* Debug

* Renaming

* Debug

* Remove CNN and GCN

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* A minor change

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
dachengx and pre-commit-ci[bot] authored Nov 21, 2024
1 parent 081af5d commit 42908b0
Show file tree
Hide file tree
Showing 25 changed files with 520 additions and 389 deletions.
2 changes: 2 additions & 0 deletions straxen/contexts.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
straxen.EnergyEstimates,
straxen.EventInfoDouble,
straxen.DistinctChannels,
straxen.PeakPositionsMLP,
straxen.PeakPositionsCNF,
],
check_available=("peak_basics", "event_basics"),
store_run_fields=("name", "number", "start", "end", "livetime", "mode", "source"),
Expand Down
1 change: 0 additions & 1 deletion straxen/plugins/events/_event_s1_positions_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ class EventS1PositionBase(strax.Plugin):

algorithm: Optional[str] = None
compressor = "zstd"
parallel = True

min_s1_area_s1_posrec = straxen.URLConfig(
help="Skip reconstruction if area (PE) is less than this",
Expand Down
3 changes: 1 addition & 2 deletions straxen/plugins/events/_event_s2_positions_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,10 @@ class EventS2PositionBase(strax.Plugin):

algorithm: Optional[str] = None
compressor = "zstd"
parallel = True

min_reconstruction_area = straxen.URLConfig(
help="Skip reconstruction if area (PE) is less than this",
default=10,
default=0,
infer_type=False,
)
n_top_pmts = straxen.URLConfig(
Expand Down
4 changes: 2 additions & 2 deletions straxen/plugins/events/event_position_uncertainty.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import strax
import straxen
from straxen.plugins.defaults import DEFAULT_POSREC_ALGO
from straxen.plugins.peaks.peak_positions_cnf import PeakPositionsCNF
from straxen.plugins.peaklets.peaklet_positions_cnf import PeakletPositionsCNF


export, __all__ = strax.exporter()
Expand Down Expand Up @@ -305,7 +305,7 @@ def compute(self, events):

avg_theta = np.arctan2(events[f"{type_}_y_cnf"], events[f"{type_}_x_cnf"])

theta_diff = PeakPositionsCNF.calculate_theta_diff(theta_array, avg_theta)
theta_diff = PeakletPositionsCNF.calculate_theta_diff(theta_array, avg_theta)

# Store uncertainties
result[f"{type_}_r_position_uncertainty"] = (r_max - r_min) / 2
Expand Down
1 change: 0 additions & 1 deletion straxen/plugins/events/s2_recon_pos_diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ class S2ReconPosDiff(strax.Plugin):

__version__ = "0.0.3"

parallel = True
depends_on = "event_basics"
provides = "s2_recon_pos_diff"
save_when = strax.SaveWhen.EXPLICIT
Expand Down
6 changes: 6 additions & 0 deletions straxen/plugins/merged_s2s/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,8 @@
from . import merged_s2s
from .merged_s2s import *

from . import merged_s2_positions_cnf
from .merged_s2_positions_cnf import *

from . import merged_s2_positions_mlp
from .merged_s2_positions_mlp import *
19 changes: 19 additions & 0 deletions straxen/plugins/merged_s2s/merged_s2_positions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import strax
from straxen.plugins.peaks._peak_positions_base import PeakPositionsBase

export, __all__ = strax.exporter()


@export
class MergedS2Positions(PeakPositionsBase):

__version__ = "0.0.0"
child_plugin = True
provides = "merged_s2_positions"
depends_on = (
"merged_s2_positions_mlp",
"merged_s2_positions_cnf",
)

def compute(self, merged_s2s):
return super().compute(merged_s2s)
17 changes: 17 additions & 0 deletions straxen/plugins/merged_s2s/merged_s2_positions_cnf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import strax
from straxen.plugins.peaklets.peaklet_positions_cnf import PeakletPositionsCNF

export, __all__ = strax.exporter()


@export
class MergedS2PositionsCNF(PeakletPositionsCNF):

__version__ = "0.0.0"
child_plugin = True
algorithm = "cnf"
depends_on = "merged_s2s"
provides = "merged_s2_positions_cnf"

def compute(self, merged_s2s):
return super().compute(merged_s2s)
17 changes: 17 additions & 0 deletions straxen/plugins/merged_s2s/merged_s2_positions_mlp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import strax
from straxen.plugins.peaklets.peaklet_positions_mlp import PeakletPositionsMLP

export, __all__ = strax.exporter()


@export
class MergedS2PositionsMLP(PeakletPositionsMLP):

__version__ = "0.0.0"
child_plugin = True
algorithm = "mlp"
depends_on = "merged_s2s"
provides = "merged_s2_positions_mlp"

def compute(self, merged_s2s):
return super().compute(merged_s2s)
6 changes: 6 additions & 0 deletions straxen/plugins/peaklets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,9 @@

from . import peaklet_classification_som
from .peaklet_classification_som import *

from . import peaklet_positions_cnf
from .peaklet_positions_cnf import *

from . import peaklet_positions_mlp
from .peaklet_positions_mlp import *
105 changes: 105 additions & 0 deletions straxen/plugins/peaklets/_peaklet_positions_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
from typing import Optional
from warnings import warn

import numpy as np
import strax
import straxen

export, __all__ = strax.exporter()


@export
class PeakletPositionsBase(strax.Plugin):
"""Base class for reconstructions.
This class should only be used when subclassed for the different algorithms. Provides
x_algorithm, y_algorithm for all peaklets > than min_reconstruction_area based on the top array.
"""

__version__ = "0.0.0"

depends_on = "peaklets"
algorithm: Optional[str] = None
compressor = "zstd"

min_reconstruction_area = straxen.URLConfig(
help="Skip reconstruction if area (PE) is less than this",
default=0,
infer_type=False,
)

n_top_pmts = straxen.URLConfig(
default=straxen.n_top_pmts, infer_type=False, help="Number of top PMTs"
)

def infer_dtype(self):
if self.algorithm is None:
raise NotImplementedError(
f"Base class should not be used without algorithm as done in {__class__.__name__}"
)
dtype = [
(
"x_" + self.algorithm,
np.float32,
f"Reconstructed {self.algorithm} S2 X position (cm), uncorrected",
),
(
"y_" + self.algorithm,
np.float32,
f"Reconstructed {self.algorithm} S2 Y position (cm), uncorrected",
),
]
dtype += strax.time_fields
return dtype

def get_tf_model(self):
"""Simple wrapper to have several tf_model_mlp, tf_model_cnf, ..
point to this same function in the compute method
"""
model = getattr(self, f"tf_model_{self.algorithm}", None)
if model is None:
warn(
f"Setting model to None for {self.__class__.__name__} will "
f"set only nans as output for {self.algorithm}"
)
if isinstance(model, str):
raise ValueError(
f"open files from tf:// protocol! Got {model} "
"instead, see tests/test_posrec.py for examples."
)
return model

def compute(self, peaklets):
result = np.ones(len(peaklets), dtype=self.dtype)
result["time"], result["endtime"] = peaklets["time"], strax.endtime(peaklets)

result["x_" + self.algorithm] *= np.nan
result["y_" + self.algorithm] *= np.nan
model = self.get_tf_model()

if model is None:
# This plugin is disabled since no model is provided
return result

# Keep large peaklets only
peak_mask = peaklets["area"] > self.min_reconstruction_area
if not np.sum(peak_mask):
# Nothing to do, and .predict crashes on empty arrays
return result

# Getting actual position reconstruction
area_per_channel_top = peaklets["area_per_channel"][peak_mask, 0 : self.n_top_pmts]
with np.errstate(divide="ignore", invalid="ignore"):
area_per_channel_top = area_per_channel_top / np.max(
area_per_channel_top, axis=1
).reshape(-1, 1)
area_per_channel_top = area_per_channel_top.reshape(-1, self.n_top_pmts)
output = model.predict(area_per_channel_top, verbose=0)

# writing output to the result
result["x_" + self.algorithm][peak_mask] = output[:, 0]
result["y_" + self.algorithm][peak_mask] = output[:, 1]
return result
1 change: 0 additions & 1 deletion straxen/plugins/peaklets/peaklet_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ class PeakletClassification(strax.Plugin):

provides: Union[str, tuple] = "peaklet_classification"
depends_on = "peaklets"
parallel = True
dtype = strax.peak_interval_dtype + [("type", np.int8, "Classification of the peak(let)")]

s1_risetime_area_parameters = straxen.URLConfig(
Expand Down
19 changes: 19 additions & 0 deletions straxen/plugins/peaklets/peaklet_positions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import strax
from straxen.plugins.peaks._peak_positions_base import PeakPositionsBase

export, __all__ = strax.exporter()


@export
class PeakletPositions(PeakPositionsBase):

__version__ = "0.0.0"
child_plugin = True
provides = "peaklet_positions"
depends_on = (
"peaklet_positions_cnf",
"peaklet_positions_mlp",
)

def compute(self, peaklets):
return super().compute(peaklets)
Loading

0 comments on commit 42908b0

Please sign in to comment.