# Source code for lumin.optimisation.threshold
from typing import Tuple
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from ..evaluation.ams import calc_ams
from ..plotting.plot_settings import PlotSettings
# Public API of this module: the only name exported by a star-import.
__all__ = ["binary_class_cut_by_ams"]
def binary_class_cut_by_ams(
    df: pd.DataFrame,
    top_perc: float = 5.0,
    min_pred: float = 0.9,
    wgt_factor: float = 1.0,
    br: float = 0.0,
    syst_unc_b: float = 0.0,
    pred_name: str = "pred",
    targ_name: str = "gen_target",
    wgt_name: str = "gen_weight",
    plot_settings: PlotSettings = PlotSettings(),
) -> Tuple[float, float, float]:
    r"""
    Optimise a cut on a signal-background classifier prediction by the Approximate Median Significance (AMS).

    The returned cut is the mean class prediction of the top `top_perc` percent of rows as ranked by AMS,
    which should generalise better than the single prediction giving the maximum AMS.

    For every row whose prediction is at least `min_pred`, the AMS obtained by cutting at that row's
    prediction is computed and stored in an ``ams`` column of `df`; all other rows keep the value -1.
    Note that `df` is modified in place, and that an ``ams`` column which already exists is reused
    as-is rather than recomputed, even if it was produced with different settings.

    The function also prints the resulting cut and AMS values and shows a plot of the distribution of
    the top-ranked predictions.

    Arguments:
        df: Pandas DataFrame containing data
        top_perc: top percentage of events to consider as ranked by AMS
        min_pred: minimum prediction to consider
        wgt_factor: single multiplicative coefficient for rescaling signal and background weights before computing AMS
        br: background offset bias
        syst_unc_b: fractional systematic uncertainty on background
        pred_name: column to use as predictions
        targ_name: column to use as truth labels for signal (1) and background (0)
        wgt_name: column to use as weights for signal and background events
        plot_settings: :class:`~lumin.plotting.plot_settings.PlotSettings` class to control figure appearance

    Returns:
        Optimised cut
        AMS at cut
        Maximum AMS
    """

    # TODO: Multithread AMS calculation
    # Use the caller-supplied column names throughout (previously the target and, for the final AMS,
    # the weight column were hard-coded to their default names).
    sig, bkg = (df[targ_name] == 1), (df[targ_name] == 0)
    preds, wgts = df[pred_name], df[wgt_name]

    def _ams_at(threshold: float) -> float:
        # AMS of the selection "prediction >= threshold" using rescaled signal/background weight sums
        passed = preds >= threshold
        return calc_ams(
            wgt_factor * np.sum(wgts[passed & sig]),
            wgt_factor * np.sum(wgts[passed & bkg]),
            br=br,
            unc_b=syst_unc_b,
        )

    if "ams" not in df.columns:
        # Float sentinel so that the column does not need to be upcast when real AMS values are written
        df["ams"] = -1.0
        above_min = preds >= min_pred
        if above_min.any():
            df.loc[above_min, "ams"] = preds[above_min].apply(_ams_at)

    sort = df.sort_values(by="ams", ascending=False)
    # int() truncates: if top_perc selects fewer than one row, `cuts` is empty and its mean is NaN
    n_top = int(top_perc * len(sort) / 100)
    cuts = sort[pred_name].values[0:n_top]
    cut = np.mean(cuts)
    ams = _ams_at(cut)

    max_ams = sort.iloc[0]["ams"]
    max_ams_cut = sort.iloc[0][pred_name]
    print(f"Mean cut at {cut} corresponds to AMS of {ams}")
    print(f"Maximum AMS for data is {max_ams} at cut of {max_ams_cut}")

    with sns.axes_style(plot_settings.style), sns.color_palette(plot_settings.cat_palette) as palette:
        plt.figure(figsize=(plot_settings.w_small, plot_settings.h_small))
        sns.distplot(cuts, label=f"Top {top_perc}%")
        plt.axvline(x=cut, label="Mean prediction", color=palette[1])
        plt.axvline(x=max_ams_cut, label="Max. AMS", color=palette[2])
        plt.legend(loc=plot_settings.leg_loc, fontsize=plot_settings.leg_sz)
        plt.xticks(fontsize=plot_settings.tk_sz, color=plot_settings.tk_col)
        plt.yticks(fontsize=plot_settings.tk_sz, color=plot_settings.tk_col)
        plt.xlabel("Class prediction", fontsize=plot_settings.lbl_sz, color=plot_settings.lbl_col)
        plt.ylabel(r"$\frac{1}{N}\ \frac{dN}{dp}$", fontsize=plot_settings.lbl_sz, color=plot_settings.lbl_col)
        plt.show()
    return cut, ams, max_ams