lumin.nn.training package
Submodules
lumin.nn.training.fold_train module
lumin.nn.training.fold_train.fold_train_ensemble(fy, n_models, bs, model_builder, callback_partials=None, eval_metrics=None, train_on_weights=True, eval_on_weights=True, patience=10, max_epochs=200, plots=['history', 'realtime'], shuffle_fold=True, shuffle_folds=True, bulk_move=True, savepath=PosixPath('train_weights'), verbose=False, log_output=False, plot_settings=<lumin.plotting.plot_settings.PlotSettings object>, callback_args=None)

Main training method for Model. Trains a specified number of models created by a ModelBuilder on data provided by a FoldYielder, and saves them to savepath. Note that this method does not return the trained models; they are saved and must be loaded later. Instead, it returns the results of model training.

Each Model is trained on N-1 folds of a FoldYielder with N folds, and the remaining fold is used as validation data. Training folds are loaded iteratively, and the model is evaluated after each use of a fold (a sub-epoch), rather than after every use of all folds (an epoch). Training continues until:
- all of the training folds have been used max_epochs times;
- or the validation loss does not decrease for patience training folds (or cycles, if using an AbsCyclicCallback);
- or a callback triggers training to stop, e.g. OneCycle.

Once training is finished, the state with the lowest validation loss is loaded, evaluated, and saved.
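The three stopping rules above can be sketched in plain Python. This is a minimal, hypothetical illustration rather than LUMIN's implementation: train_with_early_stopping, run_sub_epoch and stop_requested are invented names, run_sub_epoch stands in for training on one fold and then measuring the validation loss, and the cycle-based form of patience is not modelled.

```python
from typing import Callable, List, Optional, Tuple


def train_with_early_stopping(
    train_folds: List[int],
    run_sub_epoch: Callable[[int], float],
    patience: int = 10,
    max_epochs: int = 200,
    stop_requested: Optional[Callable[[], bool]] = None,
) -> Tuple[float, int]:
    """Hypothetical sketch of the three stopping rules described above.

    run_sub_epoch(fold_id) is assumed to train on one fold and return the
    validation loss measured straight afterwards (one sub-epoch).
    Returns (best validation loss, index of the sub-epoch that achieved it).
    """
    best_loss, best_sub_epoch = float("inf"), -1
    stale = 0      # sub-epochs since the validation loss last improved
    sub_epoch = 0  # running count of sub-epochs
    for _ in range(max_epochs):  # rule 1: at most max_epochs passes over the folds
        for fold_id in train_folds:
            val_loss = run_sub_epoch(fold_id)
            if val_loss < best_loss:
                # The real method keeps the best model state; here we only record it.
                best_loss, best_sub_epoch, stale = val_loss, sub_epoch, 0
            else:
                stale += 1
            sub_epoch += 1
            if stale >= patience:  # rule 2: early stopping
                return best_loss, best_sub_epoch
            if stop_requested is not None and stop_requested():  # rule 3: a callback asks to stop
                return best_loss, best_sub_epoch
    return best_loss, best_sub_epoch


# Simulated validation losses for three training folds, patience of 3 sub-epochs.
losses = iter([1.0, 0.8, 0.7, 0.75, 0.72, 0.71, 0.9])
best, idx = train_with_early_stopping([0, 1, 2], lambda fold_id: next(losses), patience=3)
print(best, idx)  # 0.7 2
```

Here training stops after the sixth sub-epoch, three sub-epochs after the best loss, even though max_epochs has not been reached.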
Attention
callback_args is now deprecated in favour of callback_partials and will be removed in v0.4.
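callback_partials takes functools.partial objects rather than already-instantiated callbacks, presumably so that each model in the ensemble can receive freshly created callbacks. A minimal sketch of the pattern, assuming a hypothetical HypotheticalLRScheduler class in place of a real LUMIN Callback (its name and lr_range argument are invented for illustration):

```python
from functools import partial
from typing import Callable, List, Tuple


class HypotheticalLRScheduler:
    """Stand-in for a LUMIN Callback subclass; invented for this example."""

    def __init__(self, lr_range: Tuple[float, float]) -> None:
        self.lr_range = lr_range


# Each partial fixes a callback's configuration without creating the callback yet.
callback_partials: List[Callable[[], HypotheticalLRScheduler]] = [
    partial(HypotheticalLRScheduler, lr_range=(1e-4, 1e-2)),
]


def instantiate_callbacks(partials: List[Callable[[], object]]) -> List[object]:
    # Calling each partial builds a brand-new callback object.
    return [p() for p in partials]


cbs_model_0 = instantiate_callbacks(callback_partials)
cbs_model_1 = instantiate_callbacks(callback_partials)
# Same configuration, but separate objects, so no state is shared between models.
print(cbs_model_0[0].lr_range == cbs_model_1[0].lr_range)  # True
print(cbs_model_0[0] is cbs_model_1[0])  # False
```

In practice the list would hold partials of real LUMIN Callback classes (e.g. OneCycle); their actual constructor arguments are documented with the callbacks themselves.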
Parameters
- fy (FoldYielder) – FoldYielder interfacing to training data
- n_models (int) – number of models to train
- bs (int) – batch size: number of data points per iteration
- model_builder (ModelBuilder) – ModelBuilder creating the networks to train
- callback_partials (Optional[List[partial]]) – optional list of functools.partial objects, each of which will instantiate a Callback when called
- eval_metrics (Optional[Dict[str, EvalMetric]]) – dictionary of instantiated EvalMetric objects, keyed by string. At the end of training, validation data and model predictions will be passed to each, and the results printed and saved
- train_on_weights (bool) – if weights are present in training data, whether to pass them to the loss function during training
- eval_on_weights (bool) – if weights are present in validation data, whether to pass them to the loss function during validation
- patience (int) – number of folds (sub-epochs) or cycles to train without a decrease in validation loss before ending training (early stopping)
- max_epochs (int) – maximum number of epochs for which to train
- plots (List[str]) – list of string representations of plots to produce. Currently:
  - ‘history’: loss history of all models after all training has finished
  - ‘realtime’: live loss evolution during training
  - ‘cycle’: call the plot method of the last (if any) AbsCyclicCallback listed in callback_partials after every complete model training
- shuffle_fold (bool) – whether to tell BatchYielder to shuffle data
- shuffle_folds (bool) – whether to shuffle the order of the training folds
- bulk_move (bool) – whether to pass all training data to device at once, or by minibatch. Bulk moving is quicker, but the data may not fit in device memory
- savepath (Path) – path to which to save model weights and results
- verbose (bool) – whether to print out extra information during training
- log_output (bool) – whether to save printed results to a log file rather than printing them
- plot_settings (PlotSettings) – PlotSettings class to control figure appearance
- callback_args (Optional[List[Dict[str, Any]]]) – deprecated in favour of callback_partials
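How n_models, shuffle_folds and the N-1 training folds fit together can be sketched as below. This is a hypothetical illustration: the helper names are invented, and choosing each model's validation fold round-robin (model_num % n_folds) is an assumption, not something stated on this page.

```python
import random
from typing import List, Optional


def val_fold_for_model(model_num: int, n_folds: int) -> int:
    # Assumption: each successive model validates on the next fold (round-robin).
    return model_num % n_folds


def epoch_fold_order(n_folds: int, val_fold: int, shuffle_folds: bool = True,
                     rng: Optional[random.Random] = None) -> List[int]:
    """Hypothetical: order in which the N-1 training folds are visited in one epoch."""
    train_folds = [i for i in range(n_folds) if i != val_fold]  # drop the validation fold
    if shuffle_folds:
        (rng or random.Random()).shuffle(train_folds)  # in-place shuffle of the fold order
    return train_folds


rng = random.Random(0)
for model_num in range(3):  # e.g. n_models=3 with a 5-fold FoldYielder
    val = val_fold_for_model(model_num, n_folds=5)
    print(model_num, val, epoch_fold_order(5, val, rng=rng))
```

With shuffle_folds=False the training folds are visited in index order; with shuffling, each epoch may visit them in a different order, but the validation fold is never used for training.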
Return type
Tuple[List[Dict[str, float]], List[Dict[str, List[float]]], List[Dict[str, float]]]
Returns
- results – list of validation losses and other eval_metrics results, ordered by model training. Can be used to create an Ensemble.
- histories – list of loss histories, ordered by model training.
- cycle_losses – if an AbsCyclicCallback was passed, list of validation losses at the end of each cycle, ordered by model training. Can be passed to Ensemble.
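The returned lists are plain Python containers, so they can be inspected directly. A minimal sketch, assuming each results dict stores the validation loss under a 'loss' key and each history dict stores per-sub-epoch validation losses under 'val_loss' (both key names are assumptions; check the dictionaries actually returned). The helper functions are hypothetical and the input values are dummies, not real training output.

```python
from statistics import mean
from typing import Dict, List


def summarise_results(results: List[Dict[str, float]], key: str = "loss") -> dict:
    """Hypothetical helper: ensemble-level summary of per-model validation results.

    Assumes each dict in `results` holds the validation loss under `key`.
    """
    losses = [r[key] for r in results]
    best = min(range(len(losses)), key=losses.__getitem__)  # index of the lowest loss
    return {"mean_loss": mean(losses), "best_model": best, "best_loss": losses[best]}


def best_sub_epoch(history: Dict[str, List[float]], key: str = "val_loss") -> int:
    """Hypothetical: sub-epoch with the lowest loss in one model's history."""
    values = history[key]
    return min(range(len(values)), key=values.__getitem__)


# Dummy values purely for illustration, not real training output.
dummy_results = [{"loss": 0.5}, {"loss": 0.25}, {"loss": 0.75}]
dummy_histories = [{"val_loss": [1.0, 0.6, 0.7]}]
print(summarise_results(dummy_results))  # {'mean_loss': 0.5, 'best_model': 1, 'best_loss': 0.25}
print([best_sub_epoch(h) for h in dummy_histories])  # [1]
```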