# Source code for lumin.optimisation.threshold

import pandas as pd
import numpy as np
from typing import Tuple
import warnings

from ..evaluation.ams import calc_ams
from ..plotting.plot_settings import PlotSettings

import seaborn as sns
import matplotlib.pyplot as plt

# Public names exported by a star-import of this module.
__all__ = ["binary_class_cut_by_ams"]

def _sum_weights_above(preds:np.ndarray, sig_wgts:np.ndarray, bkg_wgts:np.ndarray,
                       thresholds:np.ndarray) -> Tuple[np.ndarray,np.ndarray]:
    r'''
    Sum the signal and background weights of all events passing each threshold, using one sort (O(n log n))
    instead of a full masked sum per threshold.

    Arguments:
        preds: per-event predictions as floats; NaN predictions never pass a threshold (as ``NaN >= t`` is False)
        sig_wgts: per-event signal weights, zero for events that are not signal
        bkg_wgts: per-event background weights, zero for events that are not background
        thresholds: cut values at which to evaluate the sums

    Returns:
        Summed signal weight of events with prediction >= each threshold
        Summed background weight of events with prediction >= each threshold
    '''

    keep = ~np.isnan(preds)
    order = np.argsort(preds[keep], kind='stable')
    sorted_preds = preds[keep][order]
    # Suffix sums: element i is the total weight of events i..end in ascending-prediction order.
    # The trailing zero covers thresholds above every prediction.
    sig_above = np.append(np.cumsum(sig_wgts[keep][order][::-1])[::-1], 0.0)
    bkg_above = np.append(np.cumsum(bkg_wgts[keep][order][::-1])[::-1], 0.0)
    # Index of the first event whose prediction is >= the threshold (ties are included)
    first_passing = np.searchsorted(sorted_preds, thresholds, side='left')
    return sig_above[first_passing], bkg_above[first_passing]


def binary_class_cut_by_ams(df:pd.DataFrame, top_perc:float=5.0, min_pred:float=0.9, wgt_factor:float=1.0,
                            br:float=0.0, syst_unc_b:float=0.0, pred_name:str='pred', targ_name:str='gen_target',
                            wgt_name:str='gen_weight',
                            plot_settings:PlotSettings=PlotSettings()) -> Tuple[float,float,float]:
    r'''
    Optimise a cut on a signal-background classifier prediction using the Approximate Median Significance (AMS).
    Rather than taking the single prediction with the highest AMS, the cut is the mean class prediction of the top
    `top_perc` percent of points as ranked by AMS, which should generalise better.

    Side effects: for every event with prediction >= `min_pred`, the AMS obtained by cutting at that event's prediction
    is written to an `ams` column of `df` (other events get -1). If `df` already has an `ams` column it is reused
    as-is and NOT recomputed, so drop it first if any AMS-related argument has changed. A summary is printed and the
    distribution of the top-ranked cuts is drawn with seaborn/matplotlib.

    Arguments:
        df: Pandas DataFrame containing data
        top_perc: top percentage of events to consider as ranked by AMS (at least one event is always used)
        min_pred: minimum prediction to consider
        wgt_factor: single multiplicative coefficient for rescaling signal and background weights before computing AMS
        br: background offset bias
        syst_unc_b: fractional systematic uncertainty on background
        pred_name: column to use as predictions
        targ_name: column to use as truth labels for signal (1) and background (0)
        wgt_name: column to use as weights for signal and background events
        plot_settings: :class:`~lumin.plotting.plot_settings.PlotSettings` class to control figure appearance

    Returns:
        Optimised cut
        AMS at cut
        Maximum AMS
    '''

    sig, bkg = (df[targ_name] == 1), (df[targ_name] == 0)
    if 'ams' not in df.columns:
        # Float placeholder so filling in AMS values below does not force an int -> float dtype change
        df['ams'] = -1.0
        passing = df[pred_name] >= min_pred
        if passing.any():
            preds = df[pred_name].to_numpy(dtype=float)
            wgts = df[wgt_name].fillna(0).to_numpy(dtype=float)  # NaN weights contribute nothing, as in pandas sums
            sig_above, bkg_above = _sum_weights_above(preds, np.where(sig.to_numpy(), wgts, 0.0),
                                                      np.where(bkg.to_numpy(), wgts, 0.0),
                                                      thresholds=preds[passing.to_numpy()])
            df.loc[passing, 'ams'] = [calc_ams(wgt_factor*s, wgt_factor*b, br=br, unc_b=syst_unc_b)
                                      for s, b in zip(sig_above, bkg_above)]

    ranked = df.sort_values(by='ams', ascending=False)
    # Keep at least one event; an empty selection would make the mean cut NaN
    n_top = max(1, int(top_perc*len(ranked)/100))
    cuts = ranked[pred_name].values[:n_top]
    cut = np.mean(cuts)
    above_cut = df[pred_name] >= cut
    ams = calc_ams(wgt_factor*np.sum(df.loc[above_cut & sig, wgt_name]),
                   wgt_factor*np.sum(df.loc[above_cut & bkg, wgt_name]), br=br, unc_b=syst_unc_b)
    best = ranked.iloc[0]
    print(f'Mean cut at {cut} corresponds to AMS of {ams}')
    print(f'Maximum AMS for data is {best["ams"]} at cut of {best[pred_name]}')

    # NOTE(review): the scraped source lost the argument to sns.axes_style; this assumes PlotSettings exposes a
    # `style` attribute holding either axes_style keyword arguments (dict) or a style name -- confirm against PlotSettings.
    style = plot_settings.style
    axes_style = sns.axes_style(**style) if isinstance(style, dict) else sns.axes_style(style)
    with axes_style, sns.color_palette(plot_settings.cat_palette) as palette:
        plt.figure(figsize=(plot_settings.w_small, plot_settings.h_small))
        sns.distplot(cuts, label=f'Top {top_perc}%')
        plt.axvline(x=cut, label='Mean prediction', color=palette[1])
        plt.axvline(x=best[pred_name], label='Max. AMS', color=palette[2])
        plt.legend(loc=plot_settings.leg_loc, fontsize=plot_settings.leg_sz)
        plt.xticks(fontsize=plot_settings.tk_sz, color=plot_settings.tk_col)
        plt.yticks(fontsize=plot_settings.tk_sz, color=plot_settings.tk_col)
        plt.xlabel('Class prediction', fontsize=plot_settings.lbl_sz, color=plot_settings.lbl_col)
        plt.ylabel(r"$\frac{1}{N}\ \frac{dN}{dp}$", fontsize=plot_settings.lbl_sz, color=plot_settings.lbl_col)
    return cut, ams, best['ams']
