Shortcuts

Source code for lumin.nn.callbacks.adversarial_callbacks

from typing import List, Optional, Callable
from fastcore.all import is_listy, store_attr
import timeit

from torch import nn, Tensor

from .callback import Callback
from .data_callbacks import TargReplace
from ..models.model_builder import ModelBuilder
from ..models.model import Model
from ...utils.misc import is_partially

__all__ = ['PivotTraining']


[docs]class PivotTraining(Callback): r''' Callback implementation of "Learning to Pivot with Adversarial Networks" (Louppe, Kagan, & Cranmer, 2016) (https://papers.nips.cc/paper/2017/hash/48ab2f9b45957ab574cf005eb8a76760-Abstract.html). The default target data in the :class:`~lumin.nn.data.fold_yielder.FoldYielder` should be the target data for the main model, and it should contain additional columns for target data for the adversary (names should be passed to the `adv_targets` argument.) Once training begins, both the main model and the adversary will be pretrained in isolation. Further training of the main model then starts, with the frozen adversary providing a bonus to the loss value if the adversary cannot predict well its targets based on the prediction of the main model. At a set interval (multiples of per batch/fold/epoch), the adversary is refined for 1 epoch with the main model frozen (if per batch, this can take a long time with no progression indicated to the user). States of the model and the adversary are saved to the savepath after both pretraining and further training. Arguments: n_pretrain_main: number of epochs to pretrain the main model n_pretrain_adv: number of epochs to pretrain the adversary adv_coef: relative weighting for the adversarial bonus (lambda in the paper), code assumes a positive value and subtracts adversarial loss from the main loss adv_model_builder: :class:`~lumin.nn.models.model_builder.ModelBuilder` defining the adversary (note that this should not define main_model+adversary) adv_targets: list of column names in foldfile to use as targets for the adversary adv_update_freq: sets how often the adversary is refined (e.g. once every `adv_update_freq` ticks) adv_update_on:str defines the tick for refining the adversary, can be batch, fold, or epoch. The paper refines once for every batch of training data. main_pretrain_cb_partials: Optional list of partial callbacks to use when pretraining the main model adv_pretrain_cb_partials: Optional list of partial callbacks to use when pretraining the adversary model adv_train_cb_partials: Optional list of partial callbacks to use when refining the adversary model ''' def __init__(self, n_pretrain_main:int, n_pretrain_adv:int, adv_coef:float, adv_model_builder:ModelBuilder, adv_targets:List[str], adv_update_freq:int, adv_update_on:str, main_pretrain_cb_partials:Optional[List[Callable[[],Callback]]]=None, adv_pretrain_cb_partials:Optional[List[Callable[[],Callback]]]=None, adv_train_cb_partials:Optional[List[Callable[[],Callback]]]=None): store_attr(but='adv_update_on') adv_update_on = adv_update_on.lower() if adv_update_on not in ['batch','fold','epoch']: raise ValueError("adv_update_on must be one of ['batch','fold','epoch']") self.adv_update_on = adv_update_on if not is_listy(self.adv_targets): adv_targets = [adv_targets] if self.main_pretrain_cb_partials is None: self.main_pretrain_cb_partials = [] if not is_listy(self.main_pretrain_cb_partials): self.main_pretrain_cb_partials = [self.main_pretrain_cb_partials] if self.adv_pretrain_cb_partials is None: self.adv_pretrain_cb_partials = [] if not is_listy(self.adv_pretrain_cb_partials): self.adv_pretrain_cb_partials = [self.adv_pretrain_cb_partials] if self.adv_train_cb_partials is None: self.adv_train_cb_partials = [] if not is_listy(self.adv_train_cb_partials): self.adv_train_cb_partials = [self.adv_train_cb_partials]
[docs] def on_train_begin(self) -> None: r''' Pretrains main model and adversary, then prepares for further training. Adds prepends training callbacks with a :class:`~lumin.nn.callbacks.data_callbacks.TargReplace` instance to grab both the target and pivot data ''' super().on_train_begin() for c in self.model.fit_params.cbs: if isinstance(c, TargReplace): return # Don't run again (on_train_begin prepends callback to cbs) # Pretrain models print("Pretraining main model") main = Model(self.model.model_builder) cbs = [] for c in self.main_pretrain_cb_partials: cbs.append(c()) model_tmr = timeit.default_timer() main.fit(n_epochs=self.n_pretrain_main, fy=self.model.fit_params.fy, bs=self.model.fit_params.bs, bulk_move=self.model.fit_params.bulk_move, train_on_weights=self.model.fit_params.train_on_weights, trn_idxs=self.model.fit_params.trn_idxs, cbs=cbs, cb_savepath=self.model.fit_params.cb_savepath) print(f"pretraining main model took {timeit.default_timer()-model_tmr:.3f}s\n") main.save(self.model.fit_params.cb_savepath/'pretrain_main.h5') self.model.set_weights(main.get_weights()) print("Pretraining adversary") self.adv = Model(self.adv_model_builder) self.adv.model = nn.Sequential(self.model.model,self.adv.model) self.adv.opt = self.adv_model_builder._build_opt(self.adv.model) self.model.freeze_layers() cbs = [TargReplace(self.adv_targets)] for c in self.adv_pretrain_cb_partials: cbs.append(c()) model_tmr = timeit.default_timer() self.adv.fit(n_epochs=self.n_pretrain_adv, fy=self.model.fit_params.fy, bs=self.model.fit_params.bs, bulk_move=self.model.fit_params.bulk_move, train_on_weights=self.model.fit_params.train_on_weights, trn_idxs=self.model.fit_params.trn_idxs, cbs=cbs, cb_savepath=self.model.fit_params.cb_savepath) print(f"pretraining adversary took {timeit.default_timer()-model_tmr:.3f}s\n") self.adv.save(self.model.fit_params.cb_savepath/'pretrain_adv.h5') # prep for combined training self.adv_loss_func = self.adv_model_builder.loss if is_partially(self.adv_loss_func): self.adv_loss_func = self.adv_loss_func() self.model.fit_params.cbs.insert(0, TargReplace(['targets']+self.adv_targets)) self.model.fit_params.cbs[0].set_model(self.model) self.count = -1 self.adv.freeze_layers() self.model.unfreeze_layers()
[docs] def on_train_end(self) -> None: r''' Save final version of adversary ''' self.adv.save(self.model.fit_params.cb_savepath/'adv.h5')
def _increment(self) -> None: r''' Increments tick and refines adversary if required ''' self.count += 1 if self.count >= self.adv_update_freq: self.count = 0 self.adv.unfreeze_layers() self.model.freeze_layers() cbs = [TargReplace(self.adv_targets)] for c in self.adv_train_cb_partials: cbs.append(c()) self.adv.fit(n_epochs=1, fy=self.model.fit_params.fy, bs=self.model.fit_params.bs, bulk_move=self.model.fit_params.bulk_move, train_on_weights=self.model.fit_params.train_on_weights, trn_idxs=self.model.fit_params.trn_idxs, cbs=cbs, cb_savepath=self.model.fit_params.cb_savepath, visible_bar=False) self.adv.freeze_layers() self.model.unfreeze_layers()
[docs] def on_batch_begin(self) -> None: r''' Slices off adversarial and main-model targets. Increments tick if required. ''' self.adv_y = self.model.fit_params.y[:,-len(self.adv_targets):] self.model.fit_params.y = self.model.fit_params.y[:,:-len(self.adv_targets)] if self.model.fit_params.state == 'train' and self.adv_update_on == 'batch': self._increment()
[docs] def on_fold_begin(self) -> None: r''' Increments tick if required. ''' if self.model.fit_params.state == 'train' and self.adv_update_on == 'fold': self._increment()
[docs] def on_epoch_begin(self) -> None: r''' Increments tick if required. ''' if self.model.fit_params.state == 'train' and self.adv_update_on == 'epoch': self._increment()
def _compute_adv_loss(self) -> Tensor: r''' Computes (weighted) adversarial loss value ''' adv_p = self.adv.model(self.model.fit_params.x) if hasattr(self.adv_loss_func, 'weights'): self.adv_loss_func.weights = self.model.fit_params.w # Proper weighting required else: self.adv_loss_func.weight = self.model.fit_params.w return self.adv_coef*self.adv_loss_func(adv_p, self.adv_y)
[docs] def on_forwards_end(self) -> None: r''' Applies adversarial bonus to main-model loss ''' if self.model.fit_params.state == 'test': return elif self.model.fit_params.state == 'valid': self.adv.model.eval() else: self.adv.model.train() self.model.fit_params.loss_val -= self._compute_adv_loss() # Move to maxima
Read the Docs v: latest
Versions
latest
stable
v0.8.0
v0.7.2
v0.7.1
v0.7.0
v0.6.0
v0.5.1
v0.5.0
v0.4.0.1
v0.3.1
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.

Docs

Access comprehensive developer and user documentation for LUMIN

View Docs

Tutorials

Get tutorials for beginner and advanced researchers demonstrating many of the features of LUMIN

View Tutorials