Skip to content

Commit

Permalink
Calculate n_competing only of type 0, 1, 2
Browse files Browse the repository at this point in the history
  • Loading branch information
dachengx committed Dec 10, 2024
1 parent 715d888 commit c954cff
Showing 1 changed file with 12 additions and 7 deletions.
19 changes: 12 additions & 7 deletions straxen/plugins/peaks/peak_proximity.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ class PeakProximity(strax.OverlapWindowPlugin):
"""Look for peaks around a peak to determine how many peaks are in proximity (in time) of a
peak."""

__version__ = "0.4.0"
__version__ = "0.5.0"

depends_on = "peak_basics"
dtype = [
Expand Down Expand Up @@ -56,9 +56,12 @@ def get_window_size(self):
return self.peak_max_proximity_time

def compute(self, peaks):
windows = strax.touching_windows(peaks, peaks, window=self.nearby_window)
# have to consider the peak with type 20
n_left, n_tot = self.find_n_competing(peaks, windows, fraction=self.min_area_fraction)
# later we can even remove type 0
mask = np.isin(peaks["type"], [0, 1, 2])
windows = strax.touching_windows(peaks[mask], peaks, window=self.nearby_window)
n_left, n_tot = self.find_n_competing(
peaks[mask], peaks, windows, fraction=self.min_area_fraction
)

t_to_prev_peak = np.ones(len(peaks), dtype=np.int64) * self.peak_max_proximity_time
t_to_prev_peak[1:] = peaks["time"][1:] - peaks["endtime"][:-1]
Expand All @@ -78,15 +81,17 @@ def compute(self, peaks):

@staticmethod
@numba.jit(nopython=True, nogil=True, cache=True)
def find_n_competing(peaks, windows, fraction):
def find_n_competing(peaks_in_roi, peaks, windows, fraction):
n_left = np.zeros(len(peaks), dtype=np.int32)
n_tot = n_left.copy()
areas_in_roi = peaks_in_roi["area"]
areas = peaks["area"]

dig = np.searchsorted(peaks_in_roi["center_time"], peaks["center_time"])
for i, peak in enumerate(peaks):
left_i, right_i = windows[i]
threshold = areas[i] * fraction
n_left[i] = np.sum(areas[left_i:i] > threshold)
n_tot[i] = n_left[i] + np.sum(areas[i + 1 : right_i] > threshold)
n_left[i] = np.sum(areas_in_roi[left_i : dig[i]] > threshold)
n_tot[i] = n_left[i] + np.sum(areas_in_roi[dig[i] : right_i] > threshold)

return n_left, n_tot

0 comments on commit c954cff

Please sign in to comment.