From 7da330235591f98eca9ab8f8f76df74ff42853ea Mon Sep 17 00:00:00 2001 From: mhliu Date: Thu, 30 Nov 2023 13:56:55 -0500 Subject: [PATCH] Separate correction for S1PE/S2PE and cS1/cS2 using true/reconstructed position --- appletree/plugins/detector.py | 69 ++++++++++++++++++++++------- appletree/plugins/reconstruction.py | 12 ++--- 2 files changed, 58 insertions(+), 23 deletions(-) diff --git a/appletree/plugins/detector.py b/appletree/plugins/detector.py index 0051744a..d8bf8766 100644 --- a/appletree/plugins/detector.py +++ b/appletree/plugins/detector.py @@ -18,15 +18,33 @@ help="S1 light collection efficiency correction", ), ) -class S1Correction(Plugin): +class S1CorrectionTrue(Plugin): + depends_on = ["x", "y", "z"] + provides = ["s1_correction_true"] + + @partial(jit, static_argnums=(0,)) + def simulate(self, key, parameters, x, y, z): + pos_true = jnp.stack([x, y, z]).T + s1_correction_true = self.s1_correction.apply(pos_true) + return key, s1_correction_true + +@export +@takes_config( + Map( + name="s1_correction", + default="_s1_correction.json", + help="S1 light collection efficiency correction", + ), +) +class S1CorrectionRec(Plugin): depends_on = ["rec_x", "rec_y", "rec_z"] - provides = ["s1_correction"] + provides = ["s1_correction_rec"] @partial(jit, static_argnums=(0,)) def simulate(self, key, parameters, rec_x, rec_y, rec_z): - pos = jnp.stack([rec_x, rec_y, rec_z]).T - s1_correction = self.s1_correction.apply(pos) - return key, s1_correction + pos_rec = jnp.stack([rec_x, rec_y, rec_z]).T + s1_correction_rec = self.s1_correction.apply(pos_rec) + return key, s1_correction_rec @export @@ -37,27 +55,44 @@ def simulate(self, key, parameters, rec_x, rec_y, rec_z): help="S2 light collection efficiency correction", ), ) -class S2Correction(Plugin): +class S2CorrectionTrue(Plugin): + depends_on = ["x", "y"] + provides = ["s2_correction_true"] + + @partial(jit, static_argnums=(0,)) + def simulate(self, key, parameters, x, y): + pos_true = jnp.stack([x, y]).T + s2_correction_true = self.s2_correction.apply(pos_true) + return key, s2_correction_true + +@export +@takes_config( + Map( + name="s2_correction", + default="_s2_correction.json", + help="S2 light collection efficiency correction", + ), +) +class S2CorrectionRec(Plugin): depends_on = ["rec_x", "rec_y"] - provides = ["s2_correction"] + provides = ["s2_correction_rec"] @partial(jit, static_argnums=(0,)) def simulate(self, key, parameters, rec_x, rec_y): - pos = jnp.stack([rec_x, rec_y]).T - s2_correction = self.s2_correction.apply(pos) - return key, s2_correction - + pos_rec = jnp.stack([rec_x, rec_y]).T + s2_correction_rec = self.s2_correction.apply(pos_rec) + return key, s2_correction_rec @export class PhotonDetection(Plugin): - depends_on = ["num_photon", "s1_correction"] + depends_on = ["num_photon", "s1_correction_true"] provides = ["num_s1_phd"] parameters = ("g1", "p_dpe") @partial(jit, static_argnums=(0,)) - def simulate(self, key, parameters, num_photon, s1_correction): + def simulate(self, key, parameters, num_photon, s1_correction_true): g1_true_no_dpe = jnp.clip( - parameters["g1"] * s1_correction / (1.0 + parameters["p_dpe"]), 0, 1.0 + parameters["g1"] * s1_correction_true / (1.0 + parameters["p_dpe"]), 0, 1.0 ) key, num_s1_phd = randgen.binomial(key, g1_true_no_dpe, num_photon) return key, num_s1_phd @@ -106,14 +141,14 @@ def simulate(self, key, parameters, num_electron, drift_survive_prob): @export class S2PE(Plugin): - depends_on = ["num_electron_drifted", "s2_correction"] + depends_on = ["num_electron_drifted", "s2_correction_true"] provides = ["num_s2_pe"] parameters = ("g2", "gas_gain") @partial(jit, static_argnums=(0,)) - def simulate(self, key, parameters, num_electron_drifted, s2_correction): + def simulate(self, key, parameters, num_electron_drifted, s2_correction_true): extraction_eff = parameters["g2"] / parameters["gas_gain"] - g2_true = parameters["g2"] * s2_correction + g2_true = parameters["g2"] * s2_correction_true gas_gain_true = g2_true / extraction_eff key, num_electron_extracted = randgen.binomial(key, extraction_eff, num_electron_drifted) diff --git a/appletree/plugins/reconstruction.py b/appletree/plugins/reconstruction.py index 47c6ac0b..c7aa02a0 100644 --- a/appletree/plugins/reconstruction.py +++ b/appletree/plugins/reconstruction.py @@ -71,21 +71,21 @@ def simulate(self, key, parameters, num_s2_pe): @export class cS1(Plugin): - depends_on = ["s1_area", "s1_correction"] + depends_on = ["s1_area", "s1_correction_rec"] provides = ["cs1"] @partial(jit, static_argnums=(0,)) - def simulate(self, key, parameters, s1_area, s1_correction): - cs1 = s1_area / s1_correction + def simulate(self, key, parameters, s1_area, s1_correction_rec): + cs1 = s1_area / s1_correction_rec return key, cs1 @export class cS2(Plugin): - depends_on = ["s2_area", "s2_correction", "drift_survive_prob"] + depends_on = ["s2_area", "s2_correction_rec", "drift_survive_prob"] provides = ["cs2"] @partial(jit, static_argnums=(0,)) - def simulate(self, key, parameters, s2_area, s2_correction, drift_survive_prob): - cs2 = s2_area / s2_correction / drift_survive_prob + def simulate(self, key, parameters, s2_area, s2_correction_rec, drift_survive_prob): + cs2 = s2_area / s2_correction_rec / drift_survive_prob return key, cs2