Skip to content

Commit

Permalink
Pos-rec for merged_s2s
Browse files Browse the repository at this point in the history
  • Loading branch information
dachengx committed Nov 18, 2024
1 parent 5b8a662 commit 486ccd1
Show file tree
Hide file tree
Showing 6 changed files with 73 additions and 51 deletions.
3 changes: 3 additions & 0 deletions straxen/contexts.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@
straxen.PeakletPositionsMLP,
straxen.PeakletPositionsCNF,
straxen.PeakletPositionsNT,
straxen.MergedS2sPositionsMLP,
straxen.MergedS2sPositionsCNF,
straxen.MergedS2sPositionsNT,
straxen.MergedPeakPositionsNT,
],
check_available=("peak_basics", "event_basics"),
Expand Down
12 changes: 12 additions & 0 deletions straxen/plugins/peaks/_peak_positions_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,3 +121,15 @@ class PeakletPositionsBaseNT(PeakPositionsBaseNT):

def compute(self, peaklets):
return super().compute(peaklets)


@export
class MergedS2sPositionsBaseNT(PeakPositionsBaseNT):
"""Pose-rec on merged_s2s instead of peaks."""

__version__ = "0.0.0"
child_plugin = True
depends_on = "merged_s2s"

def compute(self, merged_s2s):
return super().compute(merged_s2s)
74 changes: 28 additions & 46 deletions straxen/plugins/peaks/peak_positions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
import strax
import straxen

from straxen.plugins.defaults import DEFAULT_POSREC_ALGO, FAKE_MERGED_S2_TYPE
from straxen.plugins.defaults import DEFAULT_POSREC_ALGO
from .peaks import Peaks

export, __all__ = strax.exporter()

Expand Down Expand Up @@ -64,60 +65,41 @@ def compute(self, peaklets):


@export
class MergedPeakPositionsNT(strax.Plugin):
class MergedS2sPositionsNT(PeakPositionsNT):

__version__ = "0.0.0"
provides = "merged_s2s_positions"
depends_on = (
"merged_s2s_positions_mlp",
"merged_s2s_positions_cnf",
)

def compute(self, merged_s2s):
return super().compute(merged_s2s)


@export
class MergedPeakPositionsNT(Peaks):

depends_on = ("peaklet_positions", "peaklet_classification", "merged_s2s")
data_kind = "peaks"
__version__ = "0.0.0"
child_plugin = True
save_when = strax.SaveWhen.ALWAYS

depends_on = (
"peaklet_positions",
"peaklet_classification",
"merged_s2s",
"merged_s2s_positions",
)
provides = "peak_positions"

default_reconstruction_algorithm = straxen.URLConfig(
default=DEFAULT_POSREC_ALGO, help="default reconstruction algorithm that provides (x,y)"
)

merge_without_s1 = straxen.URLConfig(
default=True,
infer_type=False,
help=(
"If true, S1s will be igored during the merging. "
"It's now possible for a S1 to be inside a S2 post merging"
),
)

def infer_dtype(self):
dtype = self.deps["peaklet_positions"].dtype_for("peaklet_positions")
return dtype
return self.deps["peaklet_positions"].dtype_for("peaklet_positions")

def compute(self, peaklets, merged_s2s):
# Remove fake merged S2s from dirty hack, see above
merged_s2s = merged_s2s[merged_s2s["type"] != FAKE_MERGED_S2_TYPE]

if self.merge_without_s1:
is_s1_peaklets = peaklets["type"] == 1
_peaklets = peaklets[~is_s1_peaklets]
else:
_peaklets = peaklets
windows = strax.touching_windows(_peaklets, merged_s2s)

_merged_s2 = np.zeros(len(merged_s2s), dtype=peaklets.dtype)
indices = np.full(len(_peaklets), -1)

for i, (start, end) in enumerate(windows):
indices[start:end] = i
for name in peaklets.dtype.names:
_merged_s2[name][i] = np.nanmean(_peaklets[name][start:end], axis=0)

_merged_s2["time"] = merged_s2s["time"]
_merged_s2["endtime"] = strax.endtime(merged_s2s)

# TODO: We have to make sure that the sorting here is the same to in the Peaks plugin
# because maybe different peaklets can have same time
_result = strax.sort_by_time(np.concatenate([_peaklets[indices == -1], _merged_s2]))

if self.merge_without_s1:
_result = strax.sort_by_time(np.concatenate([peaklets[is_s1_peaklets], _result]))

result = np.zeros(len(_result), dtype=self.dtype)
strax.copy_to_buffer(_result, result, "_copy_requested_peak_positions_fields")
return result
_merged_s2s = strax.merge_arrs([merged_s2s], dtype=peaklets.dtype, replacing=True)
return super().compute(peaklets, _merged_s2s)
14 changes: 13 additions & 1 deletion straxen/plugins/peaks/peak_positions_cnf.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
import numpy as np
import strax
import straxen
from straxen.plugins.peaks._peak_positions_base import PeakPositionsBaseNT, PeakletPositionsBaseNT
from straxen.plugins.peaks._peak_positions_base import (
PeakPositionsBaseNT,
PeakletPositionsBaseNT,
MergedS2sPositionsBaseNT,
)

export, __all__ = strax.exporter()

Expand Down Expand Up @@ -253,3 +257,11 @@ class PeakletPositionsCNF(PeakletPositionsBaseNT, PeakPositionsCNF):
provides = "peaklet_positions_cnf"
__version__ = "0.0.0"
child_plugin = True


@export
class MergedS2sPositionsCNF(MergedS2sPositionsBaseNT, PeakPositionsCNF):

provides = "merged_s2s_positions_cnf"
__version__ = "0.0.0"
child_plugin = True
14 changes: 13 additions & 1 deletion straxen/plugins/peaks/peak_positions_mlp.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
import strax
import straxen
from straxen.plugins.peaks._peak_positions_base import PeakPositionsBaseNT, PeakletPositionsBaseNT
from straxen.plugins.peaks._peak_positions_base import (
PeakPositionsBaseNT,
PeakletPositionsBaseNT,
MergedS2sPositionsBaseNT,
)


export, __all__ = strax.exporter()
Expand Down Expand Up @@ -37,3 +41,11 @@ class PeakletPositionsMLP(PeakletPositionsBaseNT, PeakPositionsMLP):
provides = "peaklet_positions_mlp"
__version__ = "0.0.0"
child_plugin = True


@export
class MergedS2sPositionsMLP(MergedS2sPositionsBaseNT, PeakPositionsMLP):

provides = "merged_s2s_positions_mlp"
__version__ = "0.0.0"
child_plugin = True
7 changes: 4 additions & 3 deletions straxen/plugins/peaks/peaks.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import Tuple, Union
import numpy as np
import strax
import straxen
Expand All @@ -18,7 +19,7 @@ class Peaks(strax.Plugin):

__version__ = "0.1.2"

depends_on = ("peaklets", "peaklet_classification", "merged_s2s")
depends_on: Union[Tuple[str, ...], str] = ("peaklets", "peaklet_classification", "merged_s2s")
data_kind = "peaks"
provides = "peaks"
parallel = True
Expand Down Expand Up @@ -66,6 +67,6 @@ def compute(self, peaklets, merged_s2s):
peaks["time"][to_check][1:] >= strax.endtime(peaks)[to_check][:-1]
), "Peaks not disjoint"

result = np.zeros(len(peaks), self.dtype)
strax.copy_to_buffer(peaks, result, "_copy_requested_peak_fields")
result = np.zeros(len(peaks), dtype=self.dtype)
strax.copy_to_buffer(peaks, result, f"_copy_requested_{self.provides[0]}_fields")
return result

0 comments on commit 486ccd1

Please sign in to comment.