skorch.callbacks¶
This module serves to elevate callbacks in submodules to the skorch.callbacks namespace. Remember to define __all__ in each submodule.
-
class skorch.callbacks.Callback
Base class for callbacks.
All custom callbacks should inherit from this class. The subclass may override any of the on_... methods. It is, however, not necessary to override all of them, since it’s okay if they don’t have any effect.
Classes that inherit from this also gain the get_params and set_params methods.
Methods
initialize(): (Re-)Set the initial state of the callback.
on_batch_begin(net[, X, y, training]): Called at the beginning of each batch.
on_batch_end(net[, X, y, training]): Called at the end of each batch.
on_epoch_begin(net[, dataset_train, …]): Called at the beginning of each epoch.
on_epoch_end(net[, dataset_train, dataset_valid]): Called at the end of each epoch.
on_grad_computed(net, named_parameters[, X, …]): Called once per batch after gradients have been computed but before an update step was performed.
on_train_begin(net[, X, y]): Called at the beginning of training.
on_train_end(net[, X, y]): Called at the end of training.
get_params
set_params
initialize()
(Re-)Set the initial state of the callback. Use this e.g. if the callback tracks some state that should be reset when the model is re-initialized.
This method should return self.
on_batch_begin(net, X=None, y=None, training=None, **kwargs)
Called at the beginning of each batch.
on_batch_end(net, X=None, y=None, training=None, **kwargs)
Called at the end of each batch.
on_epoch_begin(net, dataset_train=None, dataset_valid=None, **kwargs)
Called at the beginning of each epoch.
on_epoch_end(net, dataset_train=None, dataset_valid=None, **kwargs)
Called at the end of each epoch.
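As an illustration of the hooks listed above, here is a minimal sketch of a custom callback. The class name BatchCounter and its batches_seen_ attribute are hypothetical and not part of skorch. The stand-in base class in the except branch is also an assumption, included only so that the sketch runs where skorch is not installed.

```python
try:
    from skorch.callbacks import Callback
except ImportError:
    # Stand-in base class (NOT skorch's) so the sketch runs without skorch.
    class Callback:
        def initialize(self):
            return self

        def on_batch_end(self, net, X=None, y=None, training=None, **kwargs):
            pass


class BatchCounter(Callback):
    """Hypothetical callback that counts the training batches it has seen."""

    def initialize(self):
        # Reset tracked state whenever the net is (re-)initialized,
        # and return self as the base class documentation asks.
        self.batches_seen_ = 0
        return self

    def on_batch_end(self, net, X=None, y=None, training=None, **kwargs):
        # Only the hooks that matter need to be overridden.
        if training:
            self.batches_seen_ += 1


counter = BatchCounter().initialize()
counter.on_batch_end(net=None, training=True)   # counted
counter.on_batch_end(net=None, training=False)  # validation batch, ignored
print(counter.batches_seen_)
```

In real use the callback would be passed to the net through its callbacks argument rather than called by hand.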
-
-
class skorch.callbacks.EpochTimer(**kwargs)
Measures the duration of each epoch and writes it to the history with the name dur.
Methods
initialize(): (Re-)Set the initial state of the callback.
on_batch_begin(net[, X, y, training]): Called at the beginning of each batch.
on_batch_end(net[, X, y, training]): Called at the end of each batch.
on_epoch_begin(net, **kwargs): Called at the beginning of each epoch.
on_epoch_end(net, **kwargs): Called at the end of each epoch.
on_grad_computed(net, named_parameters[, X, …]): Called once per batch after gradients have been computed but before an update step was performed.
on_train_begin(net[, X, y]): Called at the beginning of training.
on_train_end(net[, X, y]): Called at the end of training.
get_params
set_params
-
class skorch.callbacks.PrintLog(keys_ignored=None, sink=<built-in function print>, tablefmt='simple', floatfmt='.4f', stralign='right')
Print useful information from the model’s history as a table.
By default, PrintLog prints everything from the history except for 'batches'.
To determine the best loss, PrintLog looks for keys that end in '_best' and associates them with the corresponding loss. E.g., 'train_loss_best' will be matched with 'train_loss'. The Scoring callback takes care of creating those entries, which is why PrintLog works best in conjunction with that callback.
PrintLog treats keys with the 'event_' prefix in a special way. They are assumed to contain information about occasionally occurring events. False or None entries (indicating that an event did not occur) are not printed, resulting in empty cells in the table, and True entries are printed as a + symbol. PrintLog groups all event columns together and pushes them to the right, just before the 'dur' column.
Note: PrintLog will not produce good output if the number of columns varies between epochs, e.g. if the valid loss is only present on every other epoch.
Parameters:
- keys_ignored : str or list of str (default=None)
Key or list of keys that should not be part of the printed table. Note that in addition to the keys provided by the user, keys such as those starting with 'event_' or ending in '_best' are ignored by default.
- sink : callable (default=print)
The target that the output string is sent to. By default, the output is printed to stdout, but the sink could also be a logger, etc.
- tablefmt : str (default='simple')
The format of the table. See the documentation of the tabulate package for more detail. Can be 'plain', 'grid', 'pipe', 'html', 'latex', among others.
- floatfmt : str (default='.4f')
The number formatting. See the documentation of the tabulate package for more details.
- stralign : str (default='right')
The alignment of columns with strings. Can be 'left', 'center', 'right', or None (disable alignment). Default is 'right' (to be consistent with numerical columns).
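The column rules described above can be sketched in plain Python. This is a simplified illustration, not PrintLog's implementation: the helper names visible_columns and render_cell are hypothetical, the alphabetical ordering of regular columns is an assumption, and the real callback delegates table layout to the tabulate package.

```python
def visible_columns(row, keys_ignored=()):
    """Pick and order the history keys that would be shown as columns."""
    ignored = set(keys_ignored) | {"batches"}
    keys = [k for k in row if k not in ignored and not k.endswith("_best")]
    regular = sorted(k for k in keys if not k.startswith("event_") and k != "dur")
    events = sorted(k for k in keys if k.startswith("event_"))
    tail = ["dur"] if "dur" in keys else []
    # Event columns are grouped and pushed right, just before 'dur'.
    return regular + events + tail


def render_cell(key, value, floatfmt=".4f"):
    """Render one cell: events become '+' or empty, floats are formatted."""
    if key.startswith("event_"):
        return "+" if value else ""
    if isinstance(value, float):
        return format(value, floatfmt)
    return str(value)


row = {
    "epoch": 1,
    "train_loss": 0.5,
    "train_loss_best": True,
    "valid_loss": 0.61234,
    "event_cp": True,
    "dur": 0.2,
    "batches": [],
}
columns = visible_columns(row)
cells = [render_cell(key, row[key]) for key in columns]
print(columns)
print(cells)
```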
Methods
format_row(row, key, color): For a given row from the table, format it (i.e. floating points and color if applicable).
initialize(): (Re-)Set the initial state of the callback.
on_batch_begin(net[, X, y, training]): Called at the beginning of each batch.
on_batch_end(net[, X, y, training]): Called at the end of each batch.
on_epoch_begin(net[, dataset_train, …]): Called at the beginning of each epoch.
on_epoch_end(net, **kwargs): Called at the end of each epoch.
on_grad_computed(net, named_parameters[, X, …]): Called once per batch after gradients have been computed but before an update step was performed.
on_train_begin(net[, X, y]): Called at the beginning of training.
on_train_end(net[, X, y]): Called at the end of training.
get_params
set_params
table
format_row(row, key, color)
For a given row from the table, format it (i.e. floating points and color if applicable).
-
class skorch.callbacks.ProgressBar(batches_per_epoch='auto', detect_notebook=True, postfix_keys=None)
Display a progress bar for each epoch.
The progress bar includes elapsed and estimated remaining time for the current epoch, the number of batches processed, and other user-defined metrics. The progress bar is erased once the epoch is completed.
ProgressBar needs to know the total number of batches per epoch in order to display a meaningful progress bar. By default, this number is determined automatically using the dataset length and the batch size. If this heuristic does not work for some reason, you may either specify the number of batches explicitly or let the ProgressBar count the actual number of batches in the previous epoch.
For Jupyter notebooks a non-ASCII progress bar can be printed instead. To use this feature, you need to have ipywidgets installed.
Parameters:
- batches_per_epoch : int, str (default='auto')
Either a concrete number or a string specifying the method used to determine the number of batches per epoch automatically. 'auto' means that the number is computed from the length of the dataset and the batch size. 'count' means that the number is determined by counting the batches in the previous epoch. Note that this will leave you without a progress bar at the first epoch.
- detect_notebook : bool (default=True)
If enabled, the progress bar determines if its current environment is a Jupyter notebook and switches to a non-ASCII progress bar.
- postfix_keys : list of str (default=['train_loss', 'valid_loss'])
You can use this list to specify additional info displayed in the progress bar such as metrics and losses. A prerequisite to this is that these values are residing in the history on batch level already, i.e. they must be accessible via
>>> net.history[-1, 'batches', -1, key]
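The 'auto' setting derives the batch count from the dataset length and the batch size. A minimal sketch of that arithmetic follows; the helper name batches_per_epoch is hypothetical, and it assumes that training and validation use the same batch size and that a final partial batch is still counted.

```python
import math


def batches_per_epoch(n_train, batch_size, n_valid=0):
    """Estimate the batches in one epoch (training plus validation)."""
    return math.ceil(n_train / batch_size) + math.ceil(n_valid / batch_size)


train_only = batches_per_epoch(1000, 128)
with_valid = batches_per_epoch(1000, 128, n_valid=200)
print(train_only, with_valid)
```

If the estimate is wrong for a given data loader, passing an explicit integer or 'count' avoids relying on this heuristic.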
Methods
initialize(): (Re-)Set the initial state of the callback.
on_batch_begin(net[, X, y, training]): Called at the beginning of each batch.
on_batch_end(net, **kwargs): Called at the end of each batch.
on_epoch_begin(net[, dataset_train, …]): Called at the beginning of each epoch.
on_epoch_end(net, **kwargs): Called at the end of each epoch.
on_grad_computed(net, named_parameters[, X, …]): Called once per batch after gradients have been computed but before an update step was performed.
on_train_begin(net[, X, y]): Called at the beginning of training.
on_train_end(net[, X, y]): Called at the end of training.
get_params
in_ipynb
set_params
-
class skorch.callbacks.LRScheduler(policy='WarmRestartLR', monitor='train_loss', **kwargs)
Callback that sets the learning rate of each parameter group according to some policy.
Parameters:
- policy : str or _LRScheduler class (default='WarmRestartLR')
Learning rate policy name or scheduler to be used.
- monitor : str or callable (default='train_loss')
Value of the history to monitor or function/callable. In the latter case, the callable receives the net instance as argument and is expected to return the score (float) used to determine the learning rate adjustment.
- kwargs
Additional arguments passed to the lr scheduler.
Attributes:
- kwargs
Methods
initialize(): (Re-)Set the initial state of the callback.
on_batch_begin(net[, X, y, training]): Called at the beginning of each batch.
on_batch_end(net, training, **kwargs): Called at the end of each batch.
on_epoch_begin(net[, dataset_train, …]): Called at the beginning of each epoch.
on_epoch_end(net, **kwargs): Called at the end of each epoch.
on_grad_computed(net, named_parameters[, X, …]): Called once per batch after gradients have been computed but before an update step was performed.
on_train_begin(net, **kwargs): Called at the beginning of training.
on_train_end(net[, X, y]): Called at the end of training.
simulate(steps, initial_lr): Simulates the learning rate scheduler.
get_params
set_params
-
class skorch.callbacks.WarmRestartLR(optimizer, min_lr=1e-06, max_lr=0.05, base_period=10, period_mult=2, last_epoch=-1)
Stochastic Gradient Descent with Warm Restarts (SGDR) scheduler.
This scheduler sets the learning rate of each parameter group according to the stochastic gradient descent with warm restarts (SGDR) policy. This policy simulates periodic warm restarts of SGD, where in each restart the learning rate is initialized to some value and is scheduled to decrease.
Parameters:
- optimizer : torch.optim.Optimizer instance
Optimizer algorithm.
- min_lr : float or list of float (default=1e-6)
Minimum allowed learning rate during each period for all param groups (float) or each group (list).
- max_lr : float or list of float (default=0.05)
Maximum allowed learning rate during each period for all param groups (float) or each group (list).
- base_period : int (default=10)
Initial restart period to be multiplied at each restart.
- period_mult : int (default=2)
Multiplicative factor to increase the period between restarts.
- last_epoch : int (default=-1)
The index of the last valid epoch.
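To make the schedule concrete, the following sketch computes the learning rate for a given epoch using the cosine annealing formula from the SGDR paper, with the parameter names documented above. It is an illustration only: the helper name sgdr_lr is hypothetical, and skorch's own implementation may differ in details such as the exact epoch at which a restart falls.

```python
import math


def sgdr_lr(epoch, min_lr=1e-6, max_lr=0.05, base_period=10, period_mult=2):
    """Learning rate at a 0-based epoch under cosine annealing with restarts."""
    period = base_period
    position = epoch
    # Skip over completed periods; each new period is period_mult times longer.
    while position >= period:
        position -= period
        period *= period_mult
    cosine = math.cos(math.pi * position / period)
    return min_lr + 0.5 * (max_lr - min_lr) * (1 + cosine)


schedule = [sgdr_lr(epoch) for epoch in range(31)]
for epoch in (0, 5, 9, 10, 29, 30):
    print(epoch, round(schedule[epoch], 6))
```

With the defaults, the rate starts at max_lr, decays towards min_lr over 10 epochs, jumps back to max_lr at epoch 10, and then decays over a period of 20 epochs.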
References
[1] Ilya Loshchilov and Frank Hutter, 2017, “Stochastic Gradient Descent with Warm Restarts”, ICLR. https://arxiv.org/pdf/1608.03983.pdf
Methods
load_state_dict(state_dict): Loads the scheduler’s state.
state_dict(): Returns the state of the scheduler as a dict.
get_lr
step
-
class skorch.callbacks.GradientNormClipping(gradient_clip_value=None, gradient_clip_norm_type=2)
Clips gradient norm of a module’s parameters.
The norm is computed over all gradients together, as if they were concatenated into a single vector. Gradients are modified in-place.
See torch.nn.utils.clip_grad_norm_() for more information.
Parameters:
- gradient_clip_value : float (default=None)
If not None, clip the norm of all model parameter gradients to this value. The type of the norm is determined by the gradient_clip_norm_type parameter and defaults to L2.
- gradient_clip_norm_type : float (default=2)
Norm to use when gradient clipping is active. The default is to use L2-norm. Can be 'inf' for infinity norm.
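The clipping rule can be sketched without PyTorch. In the illustration below, plain Python lists stand in for gradient tensors, and the helper name clip_grad_norm is hypothetical; it follows the behaviour described above (one joint norm over all gradients, in-place scaling) rather than reproducing the library code. The small epsilon in the denominator is an assumption to guard against division by zero.

```python
def clip_grad_norm(grads, max_norm, norm_type=2.0):
    """Scale all gradients in place so that their joint norm is at most max_norm.

    grads is a list of lists of floats, one inner list per parameter.
    Returns the joint norm measured before clipping.
    """
    flat = [abs(g) for grad in grads for g in grad]
    if norm_type == float("inf"):
        total_norm = max(flat)
    else:
        total_norm = sum(g ** norm_type for g in flat) ** (1.0 / norm_type)
    scale = max_norm / (total_norm + 1e-6)
    if scale < 1.0:
        for grad in grads:
            for i in range(len(grad)):
                grad[i] *= scale
    return total_norm


grads = [[3.0], [4.0]]          # joint L2 norm is 5
norm_before = clip_grad_norm(grads, max_norm=1.0)
norm_after = sum(g ** 2 for grad in grads for g in grad) ** 0.5
print(norm_before, grads, norm_after)
```

Because a single scale factor is applied to every gradient, the direction of the overall update is preserved and only its length changes.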
Methods
initialize(): (Re-)Set the initial state of the callback.
on_batch_begin(net[, X, y, training]): Called at the beginning of each batch.
on_batch_end(net[, X, y, training]): Called at the end of each batch.
on_epoch_begin(net[, dataset_train, …]): Called at the beginning of each epoch.
on_epoch_end(net[, dataset_train, dataset_valid]): Called at the end of each epoch.
on_grad_computed(_, named_parameters, **kwargs): Called once per batch after gradients have been computed but before an update step was performed.
on_train_begin(net[, X, y]): Called at the beginning of training.
on_train_end(net[, X, y]): Called at the end of training.
get_params
set_params
-
class skorch.callbacks.BatchScoring(scoring, lower_is_better=True, on_train=False, name=None, target_extractor=<function to_numpy>, use_caching=True)
Callback that performs generic scoring on batches.
This callback determines the score after each batch and stores it in the net’s history in the column given by name. At the end of the epoch, the average of the scores is determined and also stored in the history. Furthermore, it is determined whether this average score is the best score yet and that information is also stored in the history.
In contrast to EpochScoring, this callback determines the score for each batch and then averages the score at the end of the epoch. This can be disadvantageous for some scores if the batch size is small, e.g. area under the ROC will return incorrect scores in this case. Therefore, it is recommended to use EpochScoring unless you really need the scores for each batch.
If y is None, the scoring function with signature (model, X, y) must be able to handle X as a Tensor and y=None.
Parameters:
- scoring : None, str, or callable
If None, use the score method of the model. If str, it should be a valid sklearn metric (e.g. “f1_score”, “accuracy_score”). If a callable, it should have the signature (model, X, y), and it should return a scalar. This works analogously to the scoring parameter in sklearn’s GridSearchCV et al.
- lower_is_better : bool (default=True)
Whether lower (e.g. log loss) or higher (e.g. accuracy) scores are better.
- on_train : bool (default=False)
Whether the score should be computed on train or validation data.
- name : str or None (default=None)
If not an explicit string, tries to infer the name from the scoring argument.
- target_extractor : callable (default=to_numpy)
This is called on y before it is passed to scoring.
- use_caching : bool (default=True)
Re-use the model’s prediction for computing the loss to calculate the score. Turning this off will result in an additional inference step for each batch.
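The epoch-level average of per-batch scores depends on how batches of different sizes are combined. The sketch below contrasts a plain mean with a mean weighted by batch size. The helper name epoch_average is hypothetical, and whether BatchScoring weights by batch size internally is an assumption made here for illustration.

```python
def epoch_average(batch_scores, batch_sizes):
    """Average per-batch scores, weighting each batch by its size."""
    total = sum(batch_sizes)
    return sum(score * size for score, size in zip(batch_scores, batch_sizes)) / total


# Per-batch accuracy: 9 of 10 correct, then 1 of 2 correct in a short last batch.
scores = [0.9, 0.5]
sizes = [10, 2]

unweighted = sum(scores) / len(scores)
weighted = epoch_average(scores, sizes)
print(unweighted, weighted)
```

For accuracy, the weighted average equals the accuracy over all 12 samples (10 correct out of 12), whereas the plain mean overstates the influence of the short batch. Metrics such as area under the ROC cannot be recovered from per-batch values by either kind of averaging, which is the case the text above warns about.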
Methods
initialize(): (Re-)Set the initial state of the callback.
on_batch_begin(net[, X, y, training]): Called at the beginning of each batch.
on_batch_end(net, X, y, training, **kwargs): Called at the end of each batch.
on_epoch_begin(net[, dataset_train, …]): Called at the beginning of each epoch.
on_epoch_end(net, **kwargs): Called at the end of each epoch.
on_grad_computed(net, named_parameters[, X, …]): Called once per batch after gradients have been computed but before an update step was performed.
on_train_begin(net, X, y, **kwargs): Called at the beginning of training.
on_train_end(net[, X, y]): Called at the end of training.
get_avg_score
get_params
set_params
-
class skorch.callbacks.EpochScoring(scoring, lower_is_better=True, on_train=False, name=None, target_extractor=<function to_numpy>, use_caching=True)
Callback that performs generic scoring on predictions.
At the end of each epoch, this callback makes a prediction on train or validation data, determines the score for that prediction and whether it is the best yet, and stores the result in the net’s history.
In case you already computed a score value for each batch you can omit the score computation step by returning the value from the history. For example:
>>> import numpy as np
>>> def my_score(net, X=None, y=None):
...     losses = net.history[-1, 'batches', :, 'my_score']
...     batch_sizes = net.history[-1, 'batches', :, 'valid_batch_size']
...     return np.average(losses, weights=batch_sizes)
>>> net = MyNet(callbacks=[
...     ('my_score', EpochScoring(my_score, name='my_score'))])
If you fit with a custom dataset, this callback should work as expected as long as use_caching=True, which enables the collection of y values from the dataset. If you decide to disable the caching of predictions and y values, you need to write your own scoring function that is able to deal with the dataset and returns a scalar, for example:
>>> def ds_accuracy(net, ds, y=None):
...     # assume ds yields (X, y), e.g. torchvision.datasets.MNIST
...     y_true = [y for _, y in ds]
...     y_pred = net.predict(ds)
...     return sklearn.metrics.accuracy_score(y_true, y_pred)
>>> net = MyNet(callbacks=[
...     EpochScoring(ds_accuracy, use_caching=False)])
>>> ds = torchvision.datasets.MNIST(root=mnist_path)
>>> net.fit(ds)
Parameters:
- scoring : None, str, or callable (default=None)
If None, use the score method of the model. If str, it should be a valid sklearn scorer (e.g. “f1”, “accuracy”). If a callable, it should have the signature (model, X, y), and it should return a scalar. This works analogously to the scoring parameter in sklearn’s GridSearchCV et al.
- lower_is_better : bool (default=True)
Whether lower scores should be considered better or worse.
- on_train : bool (default=False)
Whether the score should be computed on train or validation data.
- name : str or None (default=None)
If not an explicit string, tries to infer the name from the scoring argument.
- target_extractor : callable (default=to_numpy)
This is called on y before it is passed to scoring.
- use_caching : bool (default=True)
Collect labels and predictions (y_true and y_pred) over the course of one epoch and use the cached values for computing the score. The cached values are shared between all EpochScoring instances. Disabling this will result in an additional inference step for each epoch and an inability to use arbitrary datasets as input (since we don’t know how to extract y_true from an arbitrary dataset).
Methods
get_test_data(dataset_train, dataset_valid): Return data needed to perform scoring.
initialize(): (Re-)Set the initial state of the callback.
on_batch_begin(net[, X, y, training]): Called at the beginning of each batch.
on_batch_end(net, y, y_pred, training, **kwargs): Called at the end of each batch.
on_epoch_begin(net, dataset_train, …): Called at the beginning of each epoch.
on_epoch_end(net, dataset_train, …): Called at the end of each epoch.
on_grad_computed(net, named_parameters[, X, …]): Called once per batch after gradients have been computed but before an update step was performed.
on_train_begin(net, X, y, **kwargs): Called at the beginning of training.
on_train_end(*args, **kwargs): Called at the end of training.
get_params
set_params
get_test_data(dataset_train, dataset_valid)
Return data needed to perform scoring.
This is a convenience method that handles picking of train/valid, different types of input data, use of cache, etc. for you.
Parameters:
- dataset_train
Incoming training data or dataset.
- dataset_valid
Incoming validation data or dataset.
Returns:
- X_test
Input data used for making the prediction.
- y_test
Target ground truth. If caching was enabled, return cached y_test.
- y_pred : list
The predicted targets. If caching was disabled, the list is empty. If caching was enabled, the list contains the batches of the predictions. It may thus be necessary to concatenate the output before working with it:
y_pred = np.concatenate(y_pred)
initialize()
(Re-)Set the initial state of the callback. Use this e.g. if the callback tracks some state that should be reset when the model is re-initialized.
This method should return self.
-
class skorch.callbacks.Checkpoint(monitor='valid_loss_best', f_params='params.pt', f_optimizer='optimizer.pt', f_history='history.json', f_pickle=None, fn_prefix='', dirname='', event_name='event_cp', sink=<function noop>)
Save the model during training if the given metric improved.
By default, this callback works in conjunction with the validation scoring callback, which creates the valid_loss_best value in the history that Checkpoint uses to determine if this epoch is save-worthy.
You can also specify your own metric to monitor or supply a callback that dynamically evaluates whether the model should be saved in this epoch.
Some or all of the following can be saved:
- model parameters (see f_params parameter);
- optimizer state (see f_optimizer parameter);
- training history (see f_history parameter);
- entire model object (see f_pickle parameter).
You can implement your own save protocol by subclassing Checkpoint and overriding save_model().
This callback writes a bool flag to the history column event_cp indicating whether a checkpoint was created or not.
Example:
>>> net = MyNet(callbacks=[Checkpoint()])
>>> net.fit(X, y)
Example using a custom monitor where models are saved only in epochs where the validation and the train losses are best:
>>> monitor = lambda net: all(net.history[-1, (
...     'train_loss_best', 'valid_loss_best')])
>>> net = MyNet(callbacks=[Checkpoint(monitor=monitor)])
>>> net.fit(X, y)
Parameters:
- monitor : str, function, None
Value of the history to monitor or callback that determines whether this epoch should lead to a checkpoint. The callback takes the network instance as parameter.
In case monitor is set to None, the callback will save the network at every epoch.
Note: If you supply a lambda expression as monitor, you cannot pickle the wrapper anymore as lambdas cannot be pickled. You can mitigate this problem by using importable functions instead.
- f_params : file-like object, str, None (default='params.pt')
File path to the file or file-like object where the model parameters should be saved. Pass None to disable saving model parameters.
If the value is a string you can also use format specifiers to, for example, indicate the current epoch. Accessible format values are net, last_epoch and last_batch. Example to include last epoch number in file name:
>>> cb = Checkpoint(f_params="params_{last_epoch[epoch]}.pt")
- f_optimizer : file-like object, str, None (default='optimizer.pt')
File path to the file or file-like object where the optimizer state should be saved. Pass None to disable saving the optimizer state.
Supports the same format specifiers as f_params.
- f_history : file-like object, str, None (default='history.json')
File path to the file or file-like object where the model training history should be saved. Pass None to disable saving history.
- f_pickle : file-like object, str, None (default=None)
File path to the file or file-like object where the entire model object should be pickled. Pass None to disable pickling.
Supports the same format specifiers as f_params.
- fn_prefix : str (default='')
Prefix for filenames. If f_params, f_optimizer, f_history, or f_pickle are strings, they will be prefixed by fn_prefix.
- dirname : str (default='')
Directory where files are stored.
- event_name : str (default='event_cp')
Name of event to be placed in history when checkpoint is triggered. Pass None to disable placing events in history.
- sink : callable (default=noop)
The target that the information about created checkpoints is sent to. This can be a logger or print function (to send to stdout). By default the output is discarded.
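The format specifiers, fn_prefix and dirname combine into the final file path. The sketch below shows one plausible way this could work using only str.format and os.path.join. The helper name format_target is hypothetical, the dictionary stands in for a history row, and the order in which prefix and directory are applied is an assumption rather than a statement about Checkpoint's internals.

```python
import os


def format_target(f, last_epoch, fn_prefix="", dirname=""):
    """Build a checkpoint path from a filename template.

    last_epoch is a dict standing in for the most recent history row.
    """
    name = fn_prefix + f.format(last_epoch=last_epoch)
    return os.path.join(dirname, name)


last_epoch = {"epoch": 3, "valid_loss": 0.42}
plain = format_target("params_{last_epoch[epoch]}.pt", last_epoch)
nested = format_target(
    "params_{last_epoch[epoch]}.pt", last_epoch, fn_prefix="best_", dirname="exp1"
)
print(plain)
print(nested)
```

Including the epoch number in the file name keeps one file per checkpointed epoch instead of overwriting a single file.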
Attributes:
- f_history_
Methods
get_formatted_files(net): Returns a dictionary of formatted filenames.
initialize(): (Re-)Set the initial state of the callback.
on_batch_begin(net[, X, y, training]): Called at the beginning of each batch.
on_batch_end(net[, X, y, training]): Called at the end of each batch.
on_epoch_begin(net[, dataset_train, …]): Called at the beginning of each epoch.
on_epoch_end(net, **kwargs): Called at the end of each epoch.
on_grad_computed(net, named_parameters[, X, …]): Called once per batch after gradients have been computed but before an update step was performed.
on_train_begin(net[, X, y]): Called at the beginning of training.
on_train_end(net[, X, y]): Called at the end of training.
save_model(net): Save the model.
get_params
set_params
-
class skorch.callbacks.EarlyStopping(monitor='valid_loss', patience=5, threshold=0.0001, threshold_mode='rel', lower_is_better=True, sink=<built-in function print>)
Callback for stopping training when scores don’t improve.
Stop training early if the monitored metric did not improve by at least threshold within patience epochs.
Parameters:
- monitor : str (default='valid_loss')
Value of the history to monitor to decide whether to stop training or not. The value is expected to be a double and is commonly provided by scoring callbacks such as skorch.callbacks.EpochScoring.
- lower_is_better : bool (default=True)
Whether lower scores should be considered better or worse.
- patience : int (default=5)
Number of epochs to wait for improvement of the monitor value until the training process is stopped.
- threshold : float (default=1e-4)
Ignore score improvements smaller than threshold.
- threshold_mode : str (default='rel')
One of rel, abs. Decides whether the threshold value is interpreted in absolute terms or as a fraction of the best score so far (relative).
- sink : callable (default=print)
The target that the information about early stopping is sent to. By default, the output is printed to stdout, but the sink could also be a logger or noop().
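The interplay of patience, threshold and threshold_mode can be illustrated with a small stand-alone function. This is a sketch of the documented rule, not the callback's code: the helper name early_stop_epoch is hypothetical, and the treatment of the first score and the strictness of the comparison are assumptions.

```python
def early_stop_epoch(scores, patience=5, threshold=1e-4,
                     threshold_mode="rel", lower_is_better=True):
    """Return the 1-based epoch at which training would stop, or None."""
    best = None
    misses = 0
    for epoch, score in enumerate(scores, start=1):
        if best is None:
            improved = True  # the first score always counts as an improvement
        else:
            # 'rel' scales the threshold by the best score so far.
            margin = threshold * abs(best) if threshold_mode == "rel" else threshold
            if lower_is_better:
                improved = score < best - margin
            else:
                improved = score > best + margin
        if improved:
            best, misses = score, 0
        else:
            misses += 1
            if misses >= patience:
                return epoch
    return None


plateau = [1.0, 0.9, 0.9, 0.9, 0.9]
stopped_at = early_stop_epoch(plateau, patience=2)
still_improving = early_stop_epoch([3.0, 2.0, 1.0], patience=2)
print(stopped_at, still_improving)
```

In the plateau example the loss last improves in epoch 2, so with patience=2 the sketch stops after epoch 4.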
Methods
initialize(): (Re-)Set the initial state of the callback.
on_batch_begin(net[, X, y, training]): Called at the beginning of each batch.
on_batch_end(net[, X, y, training]): Called at the end of each batch.
on_epoch_begin(net[, dataset_train, …]): Called at the beginning of each epoch.
on_epoch_end(net, **kwargs): Called at the end of each epoch.
on_grad_computed(net, named_parameters[, X, …]): Called once per batch after gradients have been computed but before an update step was performed.
on_train_begin(net, **kwargs): Called at the beginning of training.
on_train_end(net[, X, y]): Called at the end of training.
get_params
set_params
-
class skorch.callbacks.Freezer(*args, **kwargs)
Freeze matching parameters. By default, the parameters are frozen at the start of the first epoch. Use the at parameter to choose a different point in time, either an epoch number or a callable.
See ParamMapper for details.
Methods
initialize(): (Re-)Set the initial state of the callback.
on_batch_begin(net[, X, y, training]): Called at the beginning of each batch.
on_batch_end(net[, X, y, training]): Called at the end of each batch.
on_epoch_begin(net, **kwargs): Called at the beginning of each epoch.
on_epoch_end(net[, dataset_train, dataset_valid]): Called at the end of each epoch.
on_grad_computed(net, named_parameters[, X, …]): Called once per batch after gradients have been computed but before an update step was performed.
on_train_begin(net[, X, y]): Called at the beginning of training.
on_train_end(net[, X, y]): Called at the end of training.
filter_parameters
get_params
named_parameters
set_params
-
class skorch.callbacks.Unfreezer(*args, **kwargs)
Inverse operation of Freezer.
Methods
initialize(): (Re-)Set the initial state of the callback.
on_batch_begin(net[, X, y, training]): Called at the beginning of each batch.
on_batch_end(net[, X, y, training]): Called at the end of each batch.
on_epoch_begin(net, **kwargs): Called at the beginning of each epoch.
on_epoch_end(net[, dataset_train, dataset_valid]): Called at the end of each epoch.
on_grad_computed(net, named_parameters[, X, …]): Called once per batch after gradients have been computed but before an update step was performed.
on_train_begin(net[, X, y]): Called at the beginning of training.
on_train_end(net[, X, y]): Called at the end of training.
filter_parameters
get_params
named_parameters
set_params
-
class skorch.callbacks.Initializer(*args, **kwargs)
Apply any function on matching parameters in the first epoch.
Examples
Use Initializer to initialize all dense layer weights with values sampled from a uniform distribution at the beginning of the first epoch:
>>> from functools import partial
>>> init_fn = partial(torch.nn.init.uniform_, a=-1e-3, b=1e-3)
>>> cb = Initializer('dense*.weight', fn=init_fn)
>>> net = Net(myModule, callbacks=[cb])
Methods
initialize(): (Re-)Set the initial state of the callback.
on_batch_begin(net[, X, y, training]): Called at the beginning of each batch.
on_batch_end(net[, X, y, training]): Called at the end of each batch.
on_epoch_begin(net, **kwargs): Called at the beginning of each epoch.
on_epoch_end(net[, dataset_train, dataset_valid]): Called at the end of each epoch.
on_grad_computed(net, named_parameters[, X, …]): Called once per batch after gradients have been computed but before an update step was performed.
on_train_begin(net[, X, y]): Called at the beginning of training.
on_train_end(net[, X, y]): Called at the end of training.
filter_parameters
get_params
named_parameters
set_params
-
class skorch.callbacks.ParamMapper(patterns, fn=<function noop>, at=1, schedule=None)
Map arbitrary functions over module parameters filtered by pattern matching.
In the simplest case the function is only applied once at the beginning of a given epoch (at on_epoch_begin) but more complex execution schemes (e.g. periodic application) are possible using at and schedule.
Parameters:
- patterns : str or callable or list
The pattern(s) to match parameter names against. Patterns are UNIX globbing patterns as understood by fnmatch(). Patterns can also be callables which will get called with the parameter name and are regarded as a match when the callable returns a truthy value.
This parameter also supports lists of str or callables so that one ParamMapper can match a group of parameters.
Example: 'linear*.weight' or ['linear0.*', 'linear1.bias'] or lambda name: name.startswith('linear').
- fn : function
The function to apply to each parameter separately.
- at : int or callable
If an integer, it is the epoch number at which the function fn is applied to the parameters. If a callable, it receives net as its argument and fn is applied to the parameters once at returns True.
- schedule : callable or None
If specified this callable supersedes the static at/fn combination by dynamically returning the function that is applied on the matched parameters. This way you can, for example, create a schedule that periodically freezes and unfreezes layers.
The callable’s signature is schedule(net: NeuralNet) -> callable.
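The pattern rules for the patterns parameter can be tried out with the standard library alone. The sketch below uses fnmatch for glob patterns and calls any callable pattern with the parameter name. The helper name matches and the sample parameter names are hypothetical, and this is an illustration of the documented matching rules rather than ParamMapper's own code.

```python
from fnmatch import fnmatch


def matches(name, patterns):
    """Return True if a parameter name matches any glob or callable pattern."""
    if isinstance(patterns, str) or callable(patterns):
        patterns = [patterns]  # accept a single pattern as well as a list
    for pattern in patterns:
        if callable(pattern):
            if pattern(name):
                return True
        elif fnmatch(name, pattern):
            return True
    return False


# Hypothetical parameter names, as a module's named_parameters() might yield.
names = ["linear0.weight", "linear0.bias", "linear1.weight", "linear1.bias", "conv.weight"]

by_glob = [n for n in names if matches(n, "linear*.weight")]
by_list = [n for n in names if matches(n, ["linear0.*", "linear1.bias"])]
by_callable = [n for n in names if matches(n, lambda name: name.startswith("linear"))]
print(by_glob)
print(by_list)
print(by_callable)
```

Checking patterns against the module's parameter names in this way before training can catch a pattern that silently matches nothing.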
Notes
When starting the training process after saving and loading a model, ParamMapper might re-initialize parts of your model when the history is not saved along with the model. To avoid this, in case you use ParamMapper (or subclasses, e.g. Initializer) and want to save your model, make sure to either (a) use pickle, (b) save and load the history or (c) remove the parameter mapper callbacks before continuing training.
Examples
Initialize a layer on first epoch before the first training step:
>>> from functools import partial
>>> init = partial(torch.nn.init.uniform_, a=0, b=1)
>>> cb = ParamMapper('linear*.weight', at=1, fn=init)
>>> net = Net(myModule, callbacks=[cb])
Reset layer initialization if train loss reaches a certain value (e.g. re-initialize on overfit):
>>> at = lambda net: net.history[-1, 'train_loss'] < 0.1
>>> init = partial(torch.nn.init.uniform_, a=0, b=1)
>>> cb = ParamMapper('linear0.weight', at=at, fn=init)
>>> net = Net(myModule, callbacks=[cb])
Periodically freeze and unfreeze all embedding layers:
>>> def my_sched(net):
...     if len(net.history) % 2 == 0:
...         return skorch.utils.freeze_parameter
...     else:
...         return skorch.utils.unfreeze_parameter
>>> cb = ParamMapper('embedding*.weight', schedule=my_sched)
>>> net = Net(myModule, callbacks=[cb])
Methods
initialize(): (Re-)Set the initial state of the callback.
on_batch_begin(net[, X, y, training]): Called at the beginning of each batch.
on_batch_end(net[, X, y, training]): Called at the end of each batch.
on_epoch_begin(net, **kwargs): Called at the beginning of each epoch.
on_epoch_end(net[, dataset_train, dataset_valid]): Called at the end of each epoch.
on_grad_computed(net, named_parameters[, X, …]): Called once per batch after gradients have been computed but before an update step was performed.
on_train_begin(net[, X, y]): Called at the beginning of training.
on_train_end(net[, X, y]): Called at the end of training.
filter_parameters
get_params
named_parameters
set_params
class skorch.callbacks.LoadInitState(checkpoint)
Loads the model, optimizer, and history from a checkpoint into a NeuralNet when training begins.
Parameters:
- checkpoint : Checkpoint
  Checkpoint to get filenames from.
Examples
Consider running the following example multiple times:
>>> cp = Checkpoint(monitor='valid_loss_best')
>>> load_state = LoadInitState(cp)
>>> net = NeuralNet(..., callbacks=[cp, load_state])
>>> net.fit(X, y)
On the first run, the Checkpoint saves the model, optimizer, and history whenever the validation loss reaches a new minimum. Because no files exist on disk when the first run starts, LoadInitState does not load anything. When the example is run a second time, LoadInitState loads the best model from the first run and training continues from there.
Methods
initialize(): (Re-)Set the initial state of the callback.
on_batch_begin(net[, X, y, training]): Called at the beginning of each batch.
on_batch_end(net[, X, y, training]): Called at the end of each batch.
on_epoch_begin(net[, dataset_train, …]): Called at the beginning of each epoch.
on_epoch_end(net[, dataset_train, dataset_valid]): Called at the end of each epoch.
on_grad_computed(net, named_parameters[, X, …]): Called once per batch after gradients have been computed but before an update step was performed.
on_train_begin(net[, X, y]): Called at the beginning of training.
on_train_end(net[, X, y]): Called at the end of training.
get_params
set_params
class skorch.callbacks.TrainEndCheckpoint(f_params='params.pt', f_optimizer='optimizer.pt', f_history='history.json', f_pickle=None, fn_prefix='train_end_', dirname='', sink=<function noop>)
Saves the model parameters, optimizer state, and history at the end of training. The default fn_prefix is 'train_end_'.
Parameters:
- f_params : file-like object, str, None (default='params.pt')
  Path to the file, or file-like object, where the model parameters should be saved. Pass None to disable saving model parameters.
  If the value is a string, you can also use format specifiers to, for example, indicate the current epoch. Accessible format values are net, last_epoch and last_batch. Example to include the last epoch number in the file name:
  >>> cb = Checkpoint(f_params="params_{last_epoch[epoch]}.pt")
- f_optimizer : file-like object, str, None (default='optimizer.pt')
  Path to the file, or file-like object, where the optimizer state should be saved. Pass None to disable saving the optimizer state.
  Supports the same format specifiers as f_params.
- f_history : file-like object, str, None (default='history.json')
  Path to the file, or file-like object, where the model training history should be saved. Pass None to disable saving history.
- f_pickle : file-like object, str, None (default=None)
  Path to the file, or file-like object, where the entire model object should be pickled. Pass None to disable pickling.
  Supports the same format specifiers as f_params.
- fn_prefix : str (default='train_end_')
  Prefix for filenames. If f_params, f_optimizer, f_history, or f_pickle are strings, they will be prefixed by fn_prefix.
- dirname : str (default='')
  Directory where files are stored.
- sink : callable (default=noop)
  The target that the information about created checkpoints is sent to. This can be a logger or the print function (to send to stdout). By default the output is discarded.
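How string targets, fn_prefix, and dirname might combine into a final path can be sketched with the standard library alone. The helper format_target below is hypothetical: it only mirrors the behaviour described in the parameter list (format specifiers filled in, the prefix prepended to string targets, files placed in dirname) and should not be read as skorch's implementation.

```python
import os


def format_target(f, fn_prefix="train_end_", dirname="", **fmt_values):
    """Hypothetical helper: resolve a save target to a path.

    fmt_values stands in for the documented format values
    (net, last_epoch, last_batch).
    """
    if f is None:
        # None disables saving for this target.
        return None
    if not isinstance(f, str):
        # File-like objects are passed through unchanged.
        return f
    # Fill in format specifiers, add the prefix, then place in dirname.
    return os.path.join(dirname, fn_prefix + f.format(**fmt_values))


# A dict stands in for the last row of the training history.
path = format_target(
    "params_{last_epoch[epoch]}.pt",
    dirname="exp1",
    last_epoch={"epoch": 7},
)
default_history = format_target("history.json")

print(path)
print(default_history)
```

The `{last_epoch[epoch]}` specifier works because Python's str.format supports item lookup inside replacement fields.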
Examples
Consider running the following example multiple times:
>>> train_end_cp = TrainEndCheckpoint(dirname='exp1')
>>> load_state = LoadInitState(train_end_cp)
>>> net = NeuralNet(..., callbacks=[train_end_cp, load_state])
>>> net.fit(X, y)
After the first run, model parameters, optimizer state, and history are saved into a directory named exp1. On the next run, LoadInitState will load the state from the first run and continue training.
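The save-at-end, load-at-start cycle described above can be imitated without skorch. The following is a rough sketch under simplifying assumptions: only a JSON history is stored, each "epoch" is a dictionary appended to a list, and the functions load_init_state and run_once are hypothetical stand-ins for what the two callbacks do together.

```python
import json
import os
import tempfile


def load_init_state(history_path):
    """Return the saved history if a file exists, else start empty."""
    if not os.path.exists(history_path):
        return []  # first run: nothing on disk yet
    with open(history_path) as f:
        return json.load(f)


def run_once(history_path, epochs=2):
    """Simulate one call to fit: load state, train, save at the end."""
    history = load_init_state(history_path)  # at the beginning of training
    start = len(history)
    for epoch in range(start + 1, start + epochs + 1):
        history.append({"epoch": epoch})  # stand-in for one training epoch
    with open(history_path, "w") as f:  # at the end of training
        json.dump(history, f)
    return history


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "train_end_history.json")
    first = run_once(path)
    second = run_once(path)

print([row["epoch"] for row in first])
print([row["epoch"] for row in second])
```

In this sketch the second run continues the epoch count where the first one stopped, which is the effect the example aims for.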
Methods
initialize(): (Re-)Set the initial state of the callback.
on_batch_begin(net[, X, y, training]): Called at the beginning of each batch.
on_batch_end(net[, X, y, training]): Called at the end of each batch.
on_epoch_begin(net[, dataset_train, …]): Called at the beginning of each epoch.
on_epoch_end(net[, dataset_train, dataset_valid]): Called at the end of each epoch.
on_grad_computed(net, named_parameters[, X, …]): Called once per batch after gradients have been computed but before an update step was performed.
on_train_begin(net[, X, y]): Called at the beginning of training.
on_train_end(net, **kwargs): Called at the end of training.
get_params
set_params