Source code for lumin.optimisation.threshold

import pandas as pd
import numpy as np
from typing import Tuple
import warnings

from ..evaluation.ams import calc_ams
from ..plotting.plot_settings import PlotSettings

import seaborn as sns
import matplotlib.pyplot as plt

__all__ = ['binary_class_cut_by_ams']

def binary_class_cut_by_ams(df:pd.DataFrame, top_perc:float=5.0, min_pred:float=0.9,
                            wgt_factor:float=1.0, br:float=0.0, syst_unc_b:float=0.0,
                            pred_name:str='pred', targ_name:str='gen_target', wgt_name:str='gen_weight',
                            plot_settings:PlotSettings=PlotSettings()) -> Tuple[float,float,float]:
    r'''
    Optimise a cut on a signal-background classifier prediction by the Approximate Median Significance (AMS).
    Rather than taking the single prediction which maximises the AMS, the cut is the mean class prediction of the
    top `top_perc` percent of events as ranked by AMS, which should generalise better.

    The per-event AMS is stored in an 'ams' column which is added to `df` in place.
    If `df` already contains an 'ams' column, it is reused and the AMS scan is not recomputed.
    The selected predictions are plotted, and the mean cut and maximum AMS are printed.

    Arguments:
        df: Pandas DataFrame containing data
        top_perc: top percentage of events to consider as ranked by AMS
        min_pred: minimum prediction to consider; events below it keep an AMS of -1
        wgt_factor: single multiplicative coefficient for rescaling signal and background weights before computing AMS
        br: background offset bias
        syst_unc_b: fractional systematic uncertainty on background
        pred_name: column to use as predictions
        targ_name: column to use as truth labels for signal (1) and background (0)
        wgt_name: column to use as weights for signal and background events
        plot_settings: :class:`~lumin.plotting.plot_settings.PlotSettings` class to control figure appearance

    Returns:
        Optimised cut
        AMS at cut
        Maximum AMS
    '''

    # TODO: Multithread AMS calculation
    preds = df[pred_name]
    # Honour targ_name rather than assuming the truth column is called 'gen_target'
    sig, bkg = (df[targ_name] == 1), (df[targ_name] == 0)

    def _ams_at(threshold:float) -> float:
        # AMS obtained when keeping every event whose prediction is at or above the threshold
        passed = preds >= threshold
        return calc_ams(wgt_factor*np.sum(df.loc[passed & sig, wgt_name]),
                        wgt_factor*np.sum(df.loc[passed & bkg, wgt_name]),
                        br=br, unc_b=syst_unc_b)

    if 'ams' not in df.columns:
        # Float sentinel so that the column can hold the float AMS values without a dtype change
        df['ams'] = -1.0
        consider = preds >= min_pred
        if consider.any():  # Nothing to scan when no event passes min_pred
            df.loc[consider, 'ams'] = preds[consider].apply(_ams_at)

    sort = df.sort_values(by='ams', ascending=False)
    cuts = sort[pred_name].values[0:int(top_perc*len(sort)/100)]

    cut = np.mean(cuts)
    ams = _ams_at(cut)  # Uses wgt_name, consistent with the per-event scan above
    best = sort.iloc[0]

    print(f'Mean cut at {cut} corresponds to AMS of {ams}')
    print(f'Maximum AMS for data is {best["ams"]} at cut of {best[pred_name]}')
    # NOTE(review): the style argument is assumed to be plot_settings.style - confirm against PlotSettings
    with sns.axes_style(plot_settings.style), sns.color_palette(plot_settings.cat_palette) as palette:
        plt.figure(figsize=(plot_settings.w_small, plot_settings.h_small))
        sns.distplot(cuts, label=f'Top {top_perc}%')
        plt.axvline(x=cut, label='Mean prediction', color=palette[1])
        plt.axvline(x=best[pred_name], label='Max. AMS', color=palette[2])
        plt.legend(loc=plot_settings.leg_loc, fontsize=plot_settings.leg_sz)
        plt.xticks(fontsize=plot_settings.tk_sz, color=plot_settings.tk_col)
        plt.yticks(fontsize=plot_settings.tk_sz, color=plot_settings.tk_col)
        plt.xlabel('Class prediction', fontsize=plot_settings.lbl_sz, color=plot_settings.lbl_col)
        plt.ylabel(r"$\frac{1}{N}\ \frac{dN}{dp}$", fontsize=plot_settings.lbl_sz, color=plot_settings.lbl_col)
    return cut, ams, best["ams"]
Read the Docs v: v0.7.1
On Read the Docs
Project Home

Free document hosting provided by Read the Docs.


Access comprehensive developer and user documentation for LUMIN

View Docs


Get tutorials for beginner and advanced researchers demonstrating many of the features of LUMIN

View Tutorials