diff --git a/saltax/match/utils.py b/saltax/match/utils.py index 71a40eb..9783a51 100644 --- a/saltax/match/utils.py +++ b/saltax/match/utils.py @@ -671,13 +671,6 @@ def get_cut_eff( else: raise NotImplementedError - if indv_cut_type == "n_minus_1": - cut_func = apply_n_minus_1_cuts - elif indv_cut_type == "single": - cut_func = apply_single_cut - else: - raise NotImplementedError - result_dict = {} for cut in all_cut_list: result_dict[cut] = np.zeros(n_bins - 1) @@ -698,13 +691,30 @@ def get_cut_eff( result_dict["all_cuts"][i] = len(selected_events_all_cut) / len(selected_events) result_dict["all_cuts_upper"][i] = interval.high result_dict["all_cuts_lower"][i] = interval.low + for cut_oi in all_cut_list: - selected_events_cut_oi = selected_events[apply_single_cut(selected_events, cut_oi)] - # Efficiency curves with Clopper-Pearson uncertainty estimation - interval = binomtest(len(selected_events_cut_oi), len(selected_events)).proportion_ci() - result_dict[cut_oi][i] = len(selected_events_cut_oi) / len(selected_events) - result_dict[cut_oi + "_upper"][i] = interval.high - result_dict[cut_oi + "_lower"][i] = interval.low + if indv_cut_type == "n_minus_1": + selected_events_cut_oi = selected_events[ + apply_n_minus_1_cuts(selected_events, cut_oi, all_cut_list) + ] + # Efficiency curves with Clopper-Pearson uncertainty estimation + interval = binomtest( + len(selected_events_all_cut), len(selected_events_cut_oi) + ).proportion_ci() + result_dict[cut_oi][i] = len(selected_events_all_cut) / len(selected_events_cut_oi) + result_dict[cut_oi + "_upper"][i] = interval.high + result_dict[cut_oi + "_lower"][i] = interval.low + elif indv_cut_type == "single": + selected_events_cut_oi = selected_events[apply_single_cut(selected_events, cut_oi)] + # Efficiency curves with Clopper-Pearson uncertainty estimation + interval = binomtest( + len(selected_events_cut_oi), len(selected_events) + ).proportion_ci() + result_dict[cut_oi][i] = len(selected_events_cut_oi) / len(selected_events) + result_dict[cut_oi + "_upper"][i] = interval.high + result_dict[cut_oi + "_lower"][i] = interval.low + else: + raise NotImplementedError if plot: colors = plt.cm.rainbow(