Callbacks¶
Callbacks provide a flexible way to customize the behavior of your NeuralNet training without the need to write subclasses.
You will often find callbacks writing to or reading from the history attribute. Therefore, if you would like to log the net’s behavior or do something based on the past behavior, consider using net.history.
This page will not explain all existing callbacks. For that, please look at skorch.callbacks.
Callback base class¶
The base class for each callback is Callback. If you would like to write your own callbacks, you should inherit from this class. A guide and practical example on how to write your own callbacks is shown in this notebook.
In general, remember this about your own callbacks:
They should inherit from the base class.
They should implement at least one of the on_ methods provided by the parent class (see below).
As arguments, the methods first get the NeuralNet instance and, where appropriate, the local data (e.g. the data from the current batch). The method should also have **kwargs in the signature for potentially unused arguments.
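The rules above can be sketched as follows. To keep the example runnable without skorch installed, it uses a stand-in Callback class that only mimics the shape of the real base class, and a fake net whose history is a plain list of dicts; both are assumptions made for illustration. In real code you would import Callback from skorch.callbacks and the net would call the methods for you. LossThresholdTracker and its threshold parameter are hypothetical names.

```python
from types import SimpleNamespace


class Callback:
    """Stand-in for skorch.callbacks.Callback (assumption): hooks are no-ops."""

    def initialize(self):
        return self

    def on_epoch_end(self, net, dataset_train=None, dataset_valid=None, **kwargs):
        pass


class LossThresholdTracker(Callback):
    """Hypothetical callback: remember epochs whose valid loss is below a threshold."""

    def __init__(self, threshold=0.3):
        self.threshold = threshold

    def initialize(self):
        # State that must be reset on re-initialization is set here.
        self.hit_epochs_ = []
        return self

    def on_epoch_end(self, net, **kwargs):
        # The net comes first; unused arguments are absorbed by **kwargs.
        last_row = net.history[-1]
        if last_row["valid_loss"] < self.threshold:
            self.hit_epochs_.append(last_row["epoch"])


# Fake net (assumption): only provides a list-like history.
fake_net = SimpleNamespace(history=[])
cb = LossThresholdTracker(threshold=0.3).initialize()

for epoch, loss in enumerate([0.9, 0.4, 0.25, 0.2], start=1):
    fake_net.history.append({"epoch": epoch, "valid_loss": loss})
    cb.on_epoch_end(fake_net, dataset_train=None, dataset_valid=None)

print(cb.hit_epochs_)  # [3, 4]
```

Because the callback only reads net.history, it does not need the datasets that on_epoch_end receives, which is why they are left to **kwargs.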
Callback methods to override¶
The following methods can be overridden when implementing your own callbacks.
initialize()¶
If you have attributes that should be reset when the model is re-initialized, those attributes should be set in this method.
on_train_begin(net, X, y)¶
Called once at the start of the training process (e.g. when calling fit).
on_train_end(net, X, y)¶
Called once at the end of the training process.
on_epoch_begin(net, dataset_train, dataset_valid)¶
Called once at the start of the epoch, i.e. possibly several times per fit call. Gets training and validation data as additional input.
on_epoch_end(net, dataset_train, dataset_valid)¶
Called once at the end of the epoch, i.e. possibly several times per fit call. Gets training and validation data as additional input.
on_batch_begin(net, batch, training)¶
Called once before each batch of data is processed, i.e. possibly several times per epoch. Gets batch data as additional input. Also includes a bool indicating if this is a training batch or not.
on_batch_end(net, batch, training, loss, y_pred)¶
Called once after each batch of data is processed, i.e. possibly several times per epoch. Gets batch data as additional input.
on_grad_computed(net, named_parameters, Xi, yi)¶
Called once per batch after gradients have been computed but before an update step was performed. Gets the module parameters as additional input as well as the batch data. Useful if you want to tinker with gradients.
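To make the relative order of these methods easier to see, here is a toy loop that records which method would fire when. It is derived from the descriptions above and is not skorch's training code. In particular, it assumes that validation batches are processed after the training batches of the same epoch and that on_grad_computed fires only for training batches. simulate_fit and its parameters are hypothetical names.

```python
def simulate_fit(n_epochs, n_train_batches, n_valid_batches):
    """Return the sequence of callback hooks for a toy training run."""
    calls = ["on_train_begin"]
    for _ in range(n_epochs):
        calls.append("on_epoch_begin")
        for _ in range(n_train_batches):
            calls.append("on_batch_begin(training=True)")
            # Gradients exist, but the update step has not happened yet.
            calls.append("on_grad_computed")
            calls.append("on_batch_end(training=True)")
        for _ in range(n_valid_batches):
            calls.append("on_batch_begin(training=False)")
            calls.append("on_batch_end(training=False)")
        calls.append("on_epoch_end")
    calls.append("on_train_end")
    return calls


calls = simulate_fit(n_epochs=2, n_train_batches=2, n_valid_batches=1)
for name in calls:
    print(name)
```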
Setting callback parameters¶
You can set specific callback parameters using the usual set_params interface on the network by using the callbacks__ prefix and the callback’s name. For example, to change the name under which the validation accuracy is shown during training, you would do:
net = NeuralNetClassifier(...)
net.set_params(callbacks__valid_acc__name="accuracy of valid set")
Changes will be applied on initialization and callbacks that are changed using set_params will be re-initialized.
The name you use to address the callback can be chosen during initialization of the network and defaults to the class name. If there is a conflict, the conflicting names will be made unique by appending a count suffix starting at 1, e.g. EpochScoring_1, EpochScoring_2, etc.
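The naming rule just described can be illustrated with a small plain-Python function. This is a sketch of the rule as stated here, not skorch's own implementation; make_unique is a hypothetical helper name.

```python
from collections import Counter


def make_unique(names):
    """Append _1, _2, ... to every name that occurs more than once."""
    totals = Counter(names)
    seen = Counter()
    unique = []
    for name in names:
        if totals[name] > 1:
            seen[name] += 1
            unique.append(f"{name}_{seen[name]}")
        else:
            # Names without a conflict are left untouched.
            unique.append(name)
    return unique


print(make_unique(["EpochScoring", "PrintLog", "EpochScoring"]))
# ['EpochScoring_1', 'PrintLog', 'EpochScoring_2']
```

With names like these, a parameter of the second scoring callback would be addressed with the callbacks__EpochScoring_2__ prefix.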
Deactivating callbacks¶
If you would like to (temporarily) deactivate a callback, you can do so by setting its parameter to None. E.g., if you have a callback called ‘my_callback’, you can deactivate it like this:
net = NeuralNet(
    module=MyModule,
    callbacks=[('my_callback', MyCallback())],
)
# now deactivate 'my_callback':
net.set_params(callbacks__my_callback=None)
This also works with default callbacks.
Deactivating callbacks can be especially useful when you do a parameter search (say with sklearn GridSearchCV). If, for instance, you use a callback for learning rate scheduling (e.g. via LRScheduler) and want to test its usefulness, you can compare the performance once with and once without the callback.
To completely disable all callbacks, including default callbacks, set callbacks="disable".
Scoring¶
skorch provides two callbacks that calculate scores by default, EpochScoring and BatchScoring. They work basically in the same way, except that EpochScoring calculates scores after each epoch and BatchScoring after each batch. Use the former if averaging of batch-wise scores is imprecise (say for AUC score) and the latter if you are very tight for memory.
In general, these scoring callbacks are useful when the default scores determined by the NeuralNet are not enough. They allow you to easily add new metrics to be logged during training. For an example of how to add a new score to your model, look at this notebook.
The first argument to both callbacks is name and should be a string. This determines the column name of the score shown by the PrintLog after each epoch.
Next comes the scoring parameter. For eager sklearn users, this should be familiar, since it works exactly the same as in sklearn GridSearchCV, RandomizedSearchCV, cross_val_score(), etc. For those who are unfamiliar, here is a short explanation:
If you pass a string, sklearn makes a look-up for a score with that name. Examples would be 'f1' and 'roc_auc'.
If you pass None, the model’s score method is used. By default, NeuralNet doesn’t provide a score method, but you can easily implement your own by subclassing it. If you do, it should take X and y (the target) as input and return a scalar as output. NeuralNetClassifier and NeuralNetRegressor have the same score methods as normal sklearn classifiers and regressors.
Finally, you can pass a function/callable. In that case, this function should have the signature func(net, X, y) and return a scalar.
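As an illustration of the callable variant, here is a minimal scoring function with the func(net, X, y) signature. To stay runnable without skorch, it is exercised with a stand-in model that only offers a predict method; ThresholdModel and accuracy_scorer are hypothetical names, and the data is made up for the example.

```python
def accuracy_scorer(net, X, y):
    """Scoring callable: fraction of predictions that match the target."""
    y_pred = net.predict(X)
    n_correct = sum(int(pred == target) for pred, target in zip(y_pred, y))
    return n_correct / len(y)


class ThresholdModel:
    """Stand-in for a fitted net (assumption): predicts 1 if the input exceeds 0.5."""

    def predict(self, X):
        return [int(x > 0.5) for x in X]


X = [0.1, 0.7, 0.9, 0.4]
y = [0, 1, 0, 0]
print(accuracy_scorer(ThresholdModel(), X, y))  # 0.75
```

A function like this would be passed as the scoring argument of EpochScoring or BatchScoring, together with a name for the resulting column.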
More on sklearn's model evaluation can be found in this notebook.
The lower_is_better parameter determines whether lower scores should be considered as better (e.g. log loss) or worse (e.g. accuracy). This information is used to write a <name>_best value to the net’s history. E.g., if your score is f1 score and is called 'f1', you should set lower_is_better=False. The history will then contain an entry for 'f1', which is the score itself, and an entry for 'f1_best', which says whether this is the best f1 score so far.
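The logic behind the <name>_best entry can be sketched in plain Python. The function below is an illustration of the idea under the assumption that only a strict improvement counts as a new best; it is not skorch's implementation, and mark_best is a hypothetical name.

```python
def mark_best(scores, lower_is_better):
    """For each epoch score, report whether it is the best one seen so far."""
    best = None
    flags = []
    for score in scores:
        if best is None:
            is_best = True
        elif lower_is_better:
            is_best = score < best
        else:
            is_best = score > best
        if is_best:
            best = score
        flags.append(is_best)
    return flags


# f1 score: higher is better, so lower_is_better=False.
print(mark_best([0.5, 0.7, 0.6, 0.8], lower_is_better=False))
# [True, True, False, True]

# log loss: lower is better.
print(mark_best([1.0, 0.8, 0.9, 0.7], lower_is_better=True))
# [True, True, False, True]
```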
on_train is a bool that indicates whether training or validation data should be used to determine the score. By default, it is False, so the validation data is used.
Finally, you may have to provide your own target_extractor. This should be a function or callable that is applied to the target before it is passed to the scoring function. The main reason why we need this is that sometimes, the target is not of a form expected by sklearn and we need to process it before passing it on.
On top of the two described scoring callbacks, skorch also provides PassthroughScoring. This callback does not actually calculate any new scores. Instead it uses an existing score that is calculated for each batch (the train loss, for example) and determines the average of this score, which is then written to the epoch level of the net’s history. This is very useful if the score was already calculated and logged on the batch level and you want to see the averaged score on the epoch level.
For this callback, you only need to provide the name of the score in the history. Moreover, you may again specify lower_is_better and whether the score should be calculated on_train or not.
Note
Both BatchScoring and PassthroughScoring honor the batch size when calculating the average. This can make a difference when not all batch sizes are equal, which is typically the case because the last batch of an epoch contains fewer samples than the rest.
Checkpoint¶
The Checkpoint callback creates a checkpoint of your model after each epoch that meets certain criteria. By default, the condition is that the validation loss has improved; however, you may change this by specifying the monitor parameter. It can take three types of arguments:
None: The model is saved after each epoch;
string: The model checks whether the last entry in the model history for that key is truthy. This is useful in conjunction with scores determined by a scoring callback. They write a <score>_best entry to the history, which can be used for checkpointing. By default, the Checkpoint callback looks at 'valid_loss_best';
function or callable: In that case, the function should take the NeuralNet instance as sole input and return a bool as output.
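The three kinds of monitor argument can be summarized in a small decision function. This is a sketch of the behavior described above, using a fake net whose history is a plain list of dicts; should_checkpoint is a hypothetical name and the function is not skorch's implementation.

```python
from types import SimpleNamespace


def should_checkpoint(monitor, net):
    """Decide whether to save after the current epoch, given a monitor."""
    if monitor is None:
        # None: save after every epoch.
        return True
    if callable(monitor):
        # Callable: gets the net as sole input and returns a bool.
        return bool(monitor(net))
    # String: look up the key in the last history entry and test truthiness.
    return bool(net.history[-1][monitor])


# Fake net (assumption): the last epoch did not improve the validation loss.
fake_net = SimpleNamespace(
    history=[
        {"valid_loss": 0.40, "valid_loss_best": True},
        {"valid_loss": 0.45, "valid_loss_best": False},
    ]
)

print(should_checkpoint(None, fake_net))               # True
print(should_checkpoint("valid_loss_best", fake_net))  # False
print(should_checkpoint(lambda net: net.history[-1]["valid_loss"] < 0.5, fake_net))  # True
```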
To specify where and how your model is saved, change the arguments starting with f_:
f_params: to save model parameters
f_optimizer: to save optimizer state
f_history: to save training history
f_pickle: to pickle the entire model object.
Please refer to Saving and Loading for more information about restoring your network from a checkpoint.
Learning rate schedulers¶
The LRScheduler callback allows the use of the various learning rate schedulers defined in torch.optim.lr_scheduler, such as ReduceLROnPlateau, which allows dynamic learning rate reduction based on a given value to monitor, or CyclicLR, which cycles the learning rate between two boundaries with a constant frequency.
Here’s a network that uses a callback to set a cyclic learning rate:
from skorch import NeuralNet
from skorch.callbacks import LRScheduler
from torch.optim.lr_scheduler import CyclicLR

# MyModule stands for your own torch module class
net = NeuralNet(
    module=MyModule,
    callbacks=[
        ('lr_scheduler',
         LRScheduler(policy=CyclicLR,
                     base_lr=0.001,
                     max_lr=0.01)),
    ],
)
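To give a feel for what cycling between two boundaries means, the following plain-Python sketch computes the basic triangular cyclic schedule for the base_lr and max_lr used above. It only illustrates the idea and does not reproduce everything CyclicLR can do (other modes and options are not modeled). triangular_lr is a hypothetical helper, and step_size, the number of steps in half a cycle, is chosen small here for readability.

```python
import math


def triangular_lr(step, base_lr, max_lr, step_size):
    """Learning rate of a triangular cyclic schedule at a given step.

    The rate rises linearly from base_lr to max_lr over step_size steps,
    then falls back to base_lr over the next step_size steps, and repeats.
    """
    cycle = math.floor(1 + step / (2 * step_size))
    x = abs(step / step_size - 2 * cycle + 1)
    return base_lr + (max_lr - base_lr) * max(0.0, 1.0 - x)


for step in range(9):
    lr = triangular_lr(step, base_lr=0.001, max_lr=0.01, step_size=4)
    print(step, round(lr, 5))
```

With these values the rate starts at 0.001, peaks at 0.01 at step 4, and is back at 0.001 at step 8, after which the cycle repeats.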
As with other callbacks, you can use set_params to set parameters, and thus search learning rate scheduler parameters using GridSearchCV or similar. An example:
from sklearn.model_selection import GridSearchCV

search = GridSearchCV(
    net,
    param_grid={'callbacks__lr_scheduler__max_lr': [0.01, 0.1, 1.0]},
)