skorch.callbacks

This module serves to elevate callbacks in submodules to the skorch.callbacks namespace. Remember to define __all__ in each submodule.

class skorch.callbacks.BatchScoring(scoring, lower_is_better=True, on_train=False, name=None, target_extractor=<function to_numpy>, use_caching=True)[source]

Callback that performs generic scoring on batches.

This callback determines the score after each batch and stores it in the net’s history in the column given by name. At the end of the epoch, the average of the scores is determined and also stored in the history. Furthermore, it is determined whether this average score is the best score yet and that information is also stored in the history.

In contrast to EpochScoring, this callback determines the score for each batch and then averages the score at the end of the epoch. This can be disadvantageous for some scores if the batch size is small – e.g. area under the ROC will return incorrect scores in this case. Therefore, it is recommended to use EpochScoring unless you really need the scores for each batch.

If y is None, the scoring function with signature (model, X, y) must be able to handle X as a Tensor and y=None.

Parameters
scoring: None, str, or callable

If None, use the score method of the model. If str, it should be a valid sklearn metric (e.g. “f1_score”, “accuracy_score”). If a callable, it should have the signature (model, X, y), and it should return a scalar. This works analogously to the scoring parameter in sklearn’s GridSearchCV et al.

lower_is_better: bool (default=True)

Whether lower (e.g. log loss) or higher (e.g. accuracy) scores are better.

on_train: bool (default=False)

Whether this should be called during train or validation.

name: str or None (default=None)

If not an explicit string, tries to infer the name from the scoring argument.

target_extractor: callable (default=to_numpy)

This is called on y before it is passed to scoring.

use_caching: bool (default=True)

Re-use the model’s prediction for computing the loss to calculate the score. Turning this off will result in an additional inference step for each batch. Note that the net may override the use of caching.
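
For illustration, here is a minimal sketch using a callable score; MyModule, X, and y are placeholders for your own module and data:

>>> from sklearn.metrics import accuracy_score
>>> from skorch import NeuralNetClassifier
>>> from skorch.callbacks import BatchScoring
>>> def batch_accuracy(net, X, y):
...     # y has already been passed through target_extractor (a numpy array)
...     return accuracy_score(y, net.predict(X))
>>> net = NeuralNetClassifier(
...     MyModule,
...     callbacks=[BatchScoring(batch_accuracy, lower_is_better=False, name='acc')],
... )
>>> net.fit(X, y)  # per-batch scores and their epoch average end up in net.history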

Methods

initialize()

(Re-)Set the initial state of the callback.

on_batch_begin(net[, batch, training])

Called at the beginning of each batch.

on_batch_end(net, batch, training, **kwargs)

Called at the end of each batch.

on_epoch_begin(net[, dataset_train, ...])

Called at the beginning of each epoch.

on_epoch_end(net, **kwargs)

Called at the end of each epoch.

on_grad_computed(net, named_parameters[, X, ...])

Called once per batch after gradients have been computed but before an update step was performed.

on_train_begin(net, X, y, **kwargs)

Called at the beginning of training.

on_train_end(net[, X, y])

Called at the end of training.

get_avg_score

get_params

set_params

on_batch_end(net, batch, training, **kwargs)[source]

Called at the end of each batch.

on_epoch_end(net, **kwargs)[source]

Called at the end of each epoch.

class skorch.callbacks.Callback[source]

Base class for callbacks.

All custom callbacks should inherit from this class. The subclass may override any of the on_... methods. It is, however, not necessary to override all of them, since it’s okay if they don’t have any effect.

Classes that inherit from this also gain the get_params and set_params methods.
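
As a minimal sketch, a custom callback that records a flag in the history whenever the train loss falls below a (hypothetical) threshold could look like this; MyModule, X, and y are placeholders:

>>> from skorch import NeuralNetClassifier
>>> from skorch.callbacks import Callback
>>> class LossThresholdMarker(Callback):
...     def __init__(self, threshold=0.1):
...         self.threshold = threshold
...     def on_epoch_end(self, net, **kwargs):
...         # record a bool flag for this epoch; the 'event_' prefix gets
...         # special treatment by PrintLog
...         reached = net.history[-1, 'train_loss'] < self.threshold
...         net.history.record('event_below_threshold', reached)
>>> net = NeuralNetClassifier(MyModule, callbacks=[LossThresholdMarker(0.05)])
>>> net.fit(X, y)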

Methods

initialize()

(Re-)Set the initial state of the callback.

on_batch_begin(net[, batch, training])

Called at the beginning of each batch.

on_batch_end(net[, batch, training])

Called at the end of each batch.

on_epoch_begin(net[, dataset_train, ...])

Called at the beginning of each epoch.

on_epoch_end(net[, dataset_train, dataset_valid])

Called at the end of each epoch.

on_grad_computed(net, named_parameters[, X, ...])

Called once per batch after gradients have been computed but before an update step was performed.

on_train_begin(net[, X, y])

Called at the beginning of training.

on_train_end(net[, X, y])

Called at the end of training.

get_params

set_params

initialize()[source]

(Re-)Set the initial state of the callback. Use this e.g. if the callback tracks some state that should be reset when the model is re-initialized.

This method should return self.

on_batch_begin(net, batch=None, training=None, **kwargs)[source]

Called at the beginning of each batch.

on_batch_end(net, batch=None, training=None, **kwargs)[source]

Called at the end of each batch.

on_epoch_begin(net, dataset_train=None, dataset_valid=None, **kwargs)[source]

Called at the beginning of each epoch.

on_epoch_end(net, dataset_train=None, dataset_valid=None, **kwargs)[source]

Called at the end of each epoch.

on_grad_computed(net, named_parameters, X=None, y=None, training=None, **kwargs)[source]

Called once per batch after gradients have been computed but before an update step was performed.

on_train_begin(net, X=None, y=None, **kwargs)[source]

Called at the beginning of training.

on_train_end(net, X=None, y=None, **kwargs)[source]

Called at the end of training.

class skorch.callbacks.Checkpoint(monitor='valid_loss_best', f_params='params.pt', f_optimizer='optimizer.pt', f_criterion='criterion.pt', f_history='history.json', f_pickle=None, fn_prefix='', dirname='', event_name='event_cp', sink=<function noop>, load_best=False, use_safetensors=False, **kwargs)[source]

Save the model during training if the given metric improved.

By default, this callback works in conjunction with the validation scoring callback, since that callback creates a valid_loss_best value in the history, which Checkpoint uses to determine whether this epoch is save-worthy.

You can also specify your own metric to monitor or supply a callback that dynamically evaluates whether the model should be saved in this epoch.

As checkpointing is often used in conjunction with early stopping there is a need to restore the state of the model to the best checkpoint after training is done. The checkpoint callback will do this for you if you wish.

Some or all of the following can be saved:

  • model parameters (see f_params parameter);

  • optimizer state (see f_optimizer parameter);

  • criterion state (see f_criterion parameter);

  • training history (see f_history parameter);

  • entire model object (see f_pickle parameter).

If you’ve created a custom module, e.g. net.mymodule_, you can save that as well by passing f_mymodule.

You can implement your own save protocol by subclassing Checkpoint and overriding save_model().
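
For example, a hypothetical subclass might extend the default save protocol like this (the class name and file name are illustrative only):

>>> import pickle
>>> class HistoryPickleCheckpoint(Checkpoint):
...     def save_model(self, net):
...         super().save_model(net)  # keep the default saving behavior
...         # additionally pickle the history to a fixed file
...         with open('history.pkl', 'wb') as f:
...             pickle.dump(net.history, f)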

This callback writes a bool flag to the history column event_cp indicating whether a checkpoint was created or not.

Example:

>>> net = MyNet(callbacks=[Checkpoint()])
>>> net.fit(X, y)

Example using a custom monitor where models are saved only in epochs where the validation and the train losses are best:

>>> monitor = lambda net: all(net.history[-1, (
...     'train_loss_best', 'valid_loss_best')])
>>> net = MyNet(callbacks=[Checkpoint(monitor=monitor)])
>>> net.fit(X, y)
Parameters
monitor: str, function, None

Value of the history to monitor or callback that determines whether this epoch should lead to a checkpoint. The callback takes the network instance as parameter.

In case monitor is set to None, the callback will save the network at every epoch.

Note: If you supply a lambda expression as monitor, you cannot pickle the wrapper anymore as lambdas cannot be pickled. You can mitigate this problem by using importable functions instead.

f_params: file-like object, str, None (default=’params.pt’)

File path to the file or file-like object where the model parameters should be saved. Pass None to disable saving model parameters.

If the value is a string you can also use format specifiers to, for example, indicate the current epoch. Accessible format values are net, last_epoch and last_batch. Example to include last epoch number in file name:

>>> cb = Checkpoint(f_params="params_{last_epoch[epoch]}.pt")
f_optimizer: file-like object, str, None (default=’optimizer.pt’)

File path to the file or file-like object where the optimizer state should be saved. Pass None to disable saving the optimizer state.

Supports the same format specifiers as f_params.

f_criterion: file-like object, str, None (default=’criterion.pt’)

File path to the file or file-like object where the criterion state should be saved. Pass None to disable saving the criterion state.

Supports the same format specifiers as f_params.

f_history: file-like object, str, None (default=’history.json’)

File path to the file or file-like object where the model training history should be saved. Pass None to disable saving history.

f_pickle: file-like object, str, None (default=None)

File path to the file or file-like object where the entire model object should be pickled. Pass None to disable pickling.

Supports the same format specifiers as f_params.

fn_prefix: str (default=’’)

Prefix for filenames. If f_params, f_optimizer, f_history, or f_pickle are strings, they will be prefixed by fn_prefix.

dirname: str (default=’’)

Directory where files are stored.

event_name: str, (default=’event_cp’)

Name of event to be placed in history when checkpoint is triggered. Pass None to disable placing events in history.

sink: callable (default=noop)

The target that the information about created checkpoints is sent to. This can be a logger or print function (to send to stdout). By default the output is discarded.

load_best: bool (default=False)

Load the best checkpoint automatically once training ended. This can be particularly helpful in combination with early stopping as it allows for scoring with the best model, even when early stopping ended training a number of epochs later. Note that this will only work when monitor != None.

use_safetensors: bool (default=False)

Whether to use the safetensors library to persist the state. By default, PyTorch is used, which in turn uses pickle under the hood. When enabling safetensors, be aware that only PyTorch tensors can be stored. Therefore, certain attributes like the optimizer cannot be saved.

Attributes
f_history_

Methods

get_formatted_files(net)

Returns a dictionary of formatted filenames

initialize()

(Re-)Set the initial state of the callback.

on_batch_begin(net[, batch, training])

Called at the beginning of each batch.

on_batch_end(net[, batch, training])

Called at the end of each batch.

on_epoch_begin(net[, dataset_train, ...])

Called at the beginning of each epoch.

on_epoch_end(net, **kwargs)

Called at the end of each epoch.

on_grad_computed(net, named_parameters[, X, ...])

Called once per batch after gradients have been computed but before an update step was performed.

on_train_begin(net[, X, y])

Called at the beginning of training.

on_train_end(net, **kwargs)

Called at the end of training.

save_model(net)

Save the model.

get_params

set_params

get_formatted_files(net)[source]

Returns a dictionary of formatted filenames

initialize()[source]

(Re-)Set the initial state of the callback. Use this e.g. if the callback tracks some state that should be reset when the model is re-initialized.

This method should return self.

on_epoch_end(net, **kwargs)[source]

Called at the end of each epoch.

on_train_end(net, **kwargs)[source]

Called at the end of training.

save_model(net)[source]

Save the model.

This function saves some or all of the following:

  • model parameters;

  • optimizer state;

  • criterion state;

  • training history;

  • custom modules;

  • entire model object.

class skorch.callbacks.EarlyStopping(monitor='valid_loss', patience=5, threshold=0.0001, threshold_mode='rel', lower_is_better=True, sink=<built-in function print>, load_best=False)[source]

Callback for stopping training when scores don’t improve.

Stop training early if a specified monitor metric did not improve in patience number of epochs by at least threshold.

Parameters
monitor: str (default=’valid_loss’)

Value of the history to monitor to decide whether to stop training or not. The value is expected to be double and is commonly provided by scoring callbacks such as skorch.callbacks.EpochScoring.

lower_is_better: bool (default=True)

Whether lower scores should be considered better or worse.

patience: int (default=5)

Number of epochs to wait for improvement of the monitor value until the training process is stopped.

threshold: float (default=1e-4)

Ignore score improvements smaller than threshold.

threshold_mode: str (default=’rel’)

One of rel, abs. Decides whether the threshold value is interpreted in absolute terms or as a fraction of the best score so far (relative)

sink: callable (default=print)

The target that the information about early stopping is sent to. By default, the output is printed to stdout, but the sink could also be a logger or noop().

load_best: bool (default=False)

Whether to restore module weights from the epoch with the best value of the monitored quantity. If False, the module weights obtained at the last step of training are used. Note that only the module is restored. Use the Checkpoint callback with the load_best argument set to True if you need to restore the whole object.
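
A minimal sketch of how the callback is typically configured (MyModule, X, and y are placeholders):

>>> from skorch import NeuralNetClassifier
>>> from skorch.callbacks import EarlyStopping
>>> early_stop = EarlyStopping(monitor='valid_loss', patience=10, threshold=1e-4)
>>> net = NeuralNetClassifier(MyModule, max_epochs=100, callbacks=[early_stop])
>>> net.fit(X, y)  # stops early once valid_loss has not improved for 10 epochs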

Methods

initialize()

(Re-)Set the initial state of the callback.

on_batch_begin(net[, batch, training])

Called at the beginning of each batch.

on_batch_end(net[, batch, training])

Called at the end of each batch.

on_epoch_begin(net[, dataset_train, ...])

Called at the beginning of each epoch.

on_epoch_end(net, **kwargs)

Called at the end of each epoch.

on_grad_computed(net, named_parameters[, X, ...])

Called once per batch after gradients have been computed but before an update step was performed.

on_train_begin(net, **kwargs)

Called at the beginning of training.

on_train_end(net, **kwargs)

Called at the end of training.

get_params

set_params

on_epoch_end(net, **kwargs)[source]

Called at the end of each epoch.

on_train_begin(net, **kwargs)[source]

Called at the beginning of training.

on_train_end(net, **kwargs)[source]

Called at the end of training.

class skorch.callbacks.EpochScoring(scoring, lower_is_better=True, on_train=False, name=None, target_extractor=<function to_numpy>, use_caching=True)[source]

Callback that performs generic scoring on predictions.

At the end of each epoch, this callback makes a prediction on train or validation data, determines the score for that prediction and whether it is the best yet, and stores the result in the net’s history.

In case you already computed a score value for each batch, you can omit the score computation step by returning the value from the history. For example:

>>> def my_score(net, X=None, y=None):
...     losses = net.history[-1, 'batches', :, 'my_score']
...     batch_sizes = net.history[-1, 'batches', :, 'valid_batch_size']
...     return np.average(losses, weights=batch_sizes)
>>> net = MyNet(callbacks=[
...     ('my_score', EpochScoring(my_score, name='my_score'))])

If you fit with a custom dataset, this callback should work as expected as long as use_caching=True which enables the collection of y values from the dataset. If you decide to disable the caching of predictions and y values, you need to write your own scoring function that is able to deal with the dataset and returns a scalar, for example:

>>> def ds_accuracy(net, ds, y=None):
...     # assume ds yields (X, y), e.g. torchvision.datasets.MNIST
...     y_true = [y for _, y in ds]
...     y_pred = net.predict(ds)
...     return sklearn.metrics.accuracy_score(y_true, y_pred)
>>> net = MyNet(callbacks=[
...     EpochScoring(ds_accuracy, use_caching=False)])
>>> ds = torchvision.datasets.MNIST(root=mnist_path)
>>> net.fit(ds)
Parameters
scoring: None, str, or callable (default=None)

If None, use the score method of the model. If str, it should be a valid sklearn scorer (e.g. “f1”, “accuracy”). If a callable, it should have the signature (model, X, y), and it should return a scalar. This works analogously to the scoring parameter in sklearn’s GridSearchCV et al.

lower_is_better: bool (default=True)

Whether lower scores should be considered better or worse.

on_train: bool (default=False)

Whether this should be called during train or validation.

name: str or None (default=None)

If not an explicit string, tries to infer the name from the scoring argument.

target_extractor: callable (default=to_numpy)

This is called on y before it is passed to scoring.

use_caching: bool (default=True)

Collect labels and predictions (y_true and y_pred) over the course of one epoch and use the cached values for computing the score. The cached values are shared between all EpochScoring instances. Disabling this will result in an additional inference step for each epoch and an inability to use arbitrary datasets as input (since we don’t know how to extract y_true from an arbitrary dataset). Note that the net may override the use of caching.
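
For illustration, a minimal sketch that tracks accuracy on the training data and F1 on the validation split (MyModule, X, and y are placeholders):

>>> from skorch import NeuralNetClassifier
>>> from skorch.callbacks import EpochScoring
>>> train_acc = EpochScoring(
...     'accuracy', on_train=True, name='train_acc', lower_is_better=False)
>>> valid_f1 = EpochScoring('f1', name='valid_f1', lower_is_better=False)
>>> net = NeuralNetClassifier(MyModule, callbacks=[train_acc, valid_f1])
>>> net.fit(X, y)  # the new scores appear as columns in net.history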

Methods

get_test_data(dataset_train, dataset_valid, ...)

Return data needed to perform scoring.

initialize()

(Re-)Set the initial state of the callback.

on_batch_begin(net[, batch, training])

Called at the beginning of each batch.

on_batch_end(net, batch, y_pred, training, ...)

Called at the end of each batch.

on_epoch_begin(net, dataset_train, ...)

Called at the beginning of each epoch.

on_epoch_end(net, dataset_train, ...)

Called at the end of each epoch.

on_grad_computed(net, named_parameters[, X, ...])

Called once per batch after gradients have been computed but before an update step was performed.

on_train_begin(net, X, y, **kwargs)

Called at the beginning of training.

on_train_end(*args, **kwargs)

Called at the end of training.

get_params

set_params

get_test_data(dataset_train, dataset_valid, use_caching)[source]

Return data needed to perform scoring.

This is a convenience method that handles picking of train/valid, different types of input data, use of cache, etc. for you.

Parameters
dataset_train

Incoming training data or dataset.

dataset_valid

Incoming validation data or dataset.

use_caching: bool

Whether caching of inference is being used.

Returns
X_test

Input data used for making the prediction.

y_test

Target ground truth. If caching was enabled, return cached y_test.

y_pred: list

The predicted targets. If caching was disabled, the list is empty. If caching was enabled, the list contains the batches of the predictions. It may thus be necessary to concatenate the output before working with it: y_pred = np.concatenate(y_pred)

initialize()[source]

(Re-)Set the initial state of the callback. Use this e.g. if the callback tracks some state that should be reset when the model is re-initialized.

This method should return self.

on_batch_end(net, batch, y_pred, training, **kwargs)[source]

Called at the end of each batch.

on_epoch_begin(net, dataset_train, dataset_valid, **kwargs)[source]

Called at the beginning of each epoch.

on_epoch_end(net, dataset_train, dataset_valid, **kwargs)[source]

Called at the end of each epoch.

on_train_end(*args, **kwargs)[source]

Called at the end of training.

class skorch.callbacks.EpochTimer(**kwargs)[source]

Measures the duration of each epoch and writes it to the history with the name dur.

Methods

initialize()

(Re-)Set the initial state of the callback.

on_batch_begin(net[, batch, training])

Called at the beginning of each batch.

on_batch_end(net[, batch, training])

Called at the end of each batch.

on_epoch_begin(net, **kwargs)

Called at the beginning of each epoch.

on_epoch_end(net, **kwargs)

Called at the end of each epoch.

on_grad_computed(net, named_parameters[, X, ...])

Called once per batch after gradients have been computed but before an update step was performed.

on_train_begin(net[, X, y])

Called at the beginning of training.

on_train_end(net[, X, y])

Called at the end of training.

get_params

set_params

on_epoch_begin(net, **kwargs)[source]

Called at the beginning of each epoch.

on_epoch_end(net, **kwargs)[source]

Called at the end of each epoch.

class skorch.callbacks.Freezer(*args, **kwargs)[source]

Freeze matching parameters at the start of the first epoch. You may specify a specific point in time (either by epoch number or using a callable) when the parameters are frozen using the at parameter.

See ParamMapper for details.
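
For example, a minimal sketch; the parameter patterns are assumptions about your module’s layer names:

>>> from skorch import NeuralNetClassifier
>>> from skorch.callbacks import Freezer
>>> # freeze all embedding weights from the start of training
>>> freezer = Freezer('embedding*.weight')
>>> # or freeze a specific layer only from epoch 5 onwards
>>> late_freezer = Freezer('linear0.*', at=5)
>>> net = NeuralNetClassifier(MyModule, callbacks=[freezer])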

Methods

initialize()

(Re-)Set the initial state of the callback.

on_batch_begin(net[, batch, training])

Called at the beginning of each batch.

on_batch_end(net[, batch, training])

Called at the end of each batch.

on_epoch_begin(net, **kwargs)

Called at the beginning of each epoch.

on_epoch_end(net[, dataset_train, dataset_valid])

Called at the end of each epoch.

on_grad_computed(net, named_parameters[, X, ...])

Called once per batch after gradients have been computed but before an update step was performed.

on_train_begin(net[, X, y])

Called at the beginning of training.

on_train_end(net[, X, y])

Called at the end of training.

filter_parameters

get_params

named_parameters

set_params

class skorch.callbacks.GradientNormClipping(gradient_clip_value=None, gradient_clip_norm_type=2)[source]

Clips gradient norm of a module’s parameters.

The norm is computed over all gradients together, as if they were concatenated into a single vector. Gradients are modified in-place.

See torch.nn.utils.clip_grad_norm_() for more information.

Parameters
gradient_clip_value: float (default=None)

If not None, clip the norm of all model parameter gradients to this value. The type of the norm is determined by the gradient_clip_norm_type parameter and defaults to L2.

gradient_clip_norm_type: float (default=2)

Norm to use when gradient clipping is active. The default is to use L2-norm. Can be ‘inf’ for infinity norm.
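
A minimal sketch (MyModule is a placeholder):

>>> from skorch import NeuralNetClassifier
>>> from skorch.callbacks import GradientNormClipping
>>> # clip the total L2 norm of all gradients to 1.0 before each optimizer step
>>> clipper = GradientNormClipping(gradient_clip_value=1.0)
>>> net = NeuralNetClassifier(MyModule, callbacks=[clipper])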

Methods

initialize()

(Re-)Set the initial state of the callback.

on_batch_begin(net[, batch, training])

Called at the beginning of each batch.

on_batch_end(net[, batch, training])

Called at the end of each batch.

on_epoch_begin(net[, dataset_train, ...])

Called at the beginning of each epoch.

on_epoch_end(net[, dataset_train, dataset_valid])

Called at the end of each epoch.

on_grad_computed(_, named_parameters, **kwargs)

Called once per batch after gradients have been computed but before an update step was performed.

on_train_begin(net[, X, y])

Called at the beginning of training.

on_train_end(net[, X, y])

Called at the end of training.

get_params

set_params

on_grad_computed(_, named_parameters, **kwargs)[source]

Called once per batch after gradients have been computed but before an update step was performed.

class skorch.callbacks.Initializer(*args, **kwargs)[source]

Apply any function on matching parameters in the first epoch.

Examples

Use Initializer to initialize all dense layer weights with values sampled from a uniform distribution at the beginning of the first epoch:

>>> init_fn = partial(torch.nn.init.uniform_, a=-1e-3, b=1e-3)
>>> cb = Initializer('dense*.weight', fn=init_fn)
>>> net = Net(myModule, callbacks=[cb])

Methods

initialize()

(Re-)Set the initial state of the callback.

on_batch_begin(net[, batch, training])

Called at the beginning of each batch.

on_batch_end(net[, batch, training])

Called at the end of each batch.

on_epoch_begin(net, **kwargs)

Called at the beginning of each epoch.

on_epoch_end(net[, dataset_train, dataset_valid])

Called at the end of each epoch.

on_grad_computed(net, named_parameters[, X, ...])

Called once per batch after gradients have been computed but before an update step was performed.

on_train_begin(net[, X, y])

Called at the beginning of training.

on_train_end(net[, X, y])

Called at the end of training.

filter_parameters

get_params

named_parameters

set_params

class skorch.callbacks.InputShapeSetter(param_name='input_dim', input_dim_fn=None, module_name='module')[source]

Sets the input dimension of the PyTorch module to the input dimension of the training data. By default the last dimension of X (X.shape[-1]) will be used.

This can be of use when the shape of X is not known beforehand, e.g. when using a skorch model within an sklearn pipeline and grid-searching feature transformers, or using feature selection methods.

Basic usage:

>>> class MyModule(torch.nn.Module):
...     def __init__(self, input_dim=1):
...         super().__init__()
...         self.layer = torch.nn.Linear(input_dim, 3)
... # ...
>>> X1 = np.zeros((100, 5))
>>> X2 = np.zeros((100, 3))
>>> y = np.zeros(100)
>>> net = NeuralNetClassifier(MyModule, callbacks=[InputShapeSetter()])
>>> net.fit(X1, y)  # self.module_.layer.in_features == 5
>>> net.fit(X2, y)  # self.module_.layer.in_features == 3
Parameters
param_name: str (default=’input_dim’)

The name of the parameter your module uses to define the input dimension in its __init__ method.

input_dim_fn: callable, None (default=None)

In case your X value is more complex and deriving the input dimension is not as easy as X.shape[-1] you can pass a callable to this parameter which takes X and returns the input dimension.

module_name: str (default=’module’)

Only needs to be changed when you are using more than one module in your skorch model (e.g., in case of GANs).
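
For instance, a sketch using input_dim_fn when X is a dict of arrays; the key 'data' is an assumption about your input:

>>> def get_input_dim(X):
...     # derive the input dimension from the relevant entry of the dict
...     return X['data'].shape[-1]
>>> net = NeuralNetClassifier(
...     MyModule,
...     callbacks=[InputShapeSetter(input_dim_fn=get_input_dim)],
... )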

Methods

initialize()

(Re-)Set the initial state of the callback.

on_batch_begin(net[, batch, training])

Called at the beginning of each batch.

on_batch_end(net[, batch, training])

Called at the end of each batch.

on_epoch_begin(net[, dataset_train, ...])

Called at the beginning of each epoch.

on_epoch_end(net[, dataset_train, dataset_valid])

Called at the end of each epoch.

on_grad_computed(net, named_parameters[, X, ...])

Called once per batch after gradients have been computed but before an update step was performed.

on_train_begin(net, X, y, **kwargs)

Called at the beginning of training.

on_train_end(net[, X, y])

Called at the end of training.

get_input_dim

get_params

set_params

on_train_begin(net, X, y, **kwargs)[source]

Called at the beginning of training.

class skorch.callbacks.LRScheduler(policy='WarmRestartLR', monitor='train_loss', event_name='event_lr', step_every='epoch', **kwargs)[source]

Callback that sets the learning rate of each parameter group according to some policy.

Parameters
policy: str or _LRScheduler class (default=’WarmRestartLR’)

Learning rate policy name or scheduler to be used.

monitor: str or callable (default=’train_loss’)

Value of the history to monitor or function/callable. In the latter case, the callable receives the net instance as argument and is expected to return the score (float) used to determine the learning rate adjustment.

event_name: str, (default=’event_lr’)

Name of event to be placed in history when the scheduler takes a step. Pass None to disable placing events in history. Note: This feature works only for pytorch version >=1.4

step_every: str, (default=’epoch’)

Value for when to apply the learning scheduler step. Can be either ‘batch’ or ‘epoch’.

kwargs

Additional arguments passed to the lr scheduler.
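
For illustration, a minimal sketch that steps a ReduceLROnPlateau scheduler on the validation loss at the end of every epoch (MyModule is a placeholder):

>>> from skorch import NeuralNetClassifier
>>> from skorch.callbacks import LRScheduler
>>> from torch.optim.lr_scheduler import ReduceLROnPlateau
>>> # patience is forwarded to the underlying PyTorch scheduler via **kwargs
>>> scheduler = LRScheduler(policy=ReduceLROnPlateau, monitor='valid_loss', patience=3)
>>> net = NeuralNetClassifier(MyModule, callbacks=[scheduler])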

Attributes
kwargs

Methods

initialize()

(Re-)Set the initial state of the callback.

on_batch_begin(net[, batch, training])

Called at the beginning of each batch.

on_batch_end(net, training, **kwargs)

Called at the end of each batch.

on_epoch_begin(net[, dataset_train, ...])

Called at the beginning of each epoch.

on_epoch_end(net, **kwargs)

Called at the end of each epoch.

on_grad_computed(net, named_parameters[, X, ...])

Called once per batch after gradients have been computed but before an update step was performed.

on_train_begin(net, **kwargs)

Called at the beginning of training.

on_train_end(net[, X, y])

Called at the end of training.

simulate(steps, initial_lr)

Simulates the learning rate scheduler.

get_params

set_params

initialize()[source]

(Re-)Set the initial state of the callback. Use this e.g. if the callback tracks some state that should be reset when the model is re-initialized.

This method should return self.

on_batch_end(net, training, **kwargs)[source]

Called at the end of each batch.

on_epoch_end(net, **kwargs)[source]

Called at the end of each epoch.

on_train_begin(net, **kwargs)[source]

Called at the beginning of training.

simulate(steps, initial_lr)[source]

Simulates the learning rate scheduler.

Parameters
steps: int

Number of steps to simulate

initial_lr: float

Initial learning rate

Returns
lrs: numpy ndarray

Simulated learning rates

class skorch.callbacks.LoadInitState(checkpoint, use_safetensors=False)[source]

Loads the model, optimizer, and history from a checkpoint into a NeuralNet when training begins.

Parameters
checkpoint: Checkpoint

Checkpoint to get filenames from.

use_safetensors: bool (default=False)

Whether to use the safetensors library to load the state. By default, PyTorch is used, which in turn uses pickle under the hood. When the state was saved using safetensors (e.g. by enabling it in the Checkpoint), you should set this to True.

Examples

Consider running the following example multiple times:

>>> cp = Checkpoint(monitor='valid_loss_best')
>>> load_state = LoadInitState(cp)
>>> net = NeuralNet(..., callbacks=[cp, load_state])
>>> net.fit(X, y)

On the first run, the Checkpoint saves the model, optimizer, and history when the validation loss is minimized. During the first run, there are no files on disk, thus LoadInitState will not load anything. When running the example a second time, LoadInitState will load the best model from the first run and continue training from there.

Methods

initialize()

(Re-)Set the initial state of the callback.

on_batch_begin(net[, batch, training])

Called at the beginning of each batch.

on_batch_end(net[, batch, training])

Called at the end of each batch.

on_epoch_begin(net[, dataset_train, ...])

Called at the beginning of each epoch.

on_epoch_end(net[, dataset_train, dataset_valid])

Called at the end of each epoch.

on_grad_computed(net, named_parameters[, X, ...])

Called once per batch after gradients have been computed but before an update step was performed.

on_train_begin(net[, X, y])

Called at the beginning of training.

on_train_end(net[, X, y])

Called at the end of training.

get_params

set_params

initialize()[source]

(Re-)Set the initial state of the callback. Use this e.g. if the callback tracks some state that should be reset when the model is re-initialized.

This method should return self.

on_train_begin(net, X=None, y=None, **kwargs)[source]

Called at the beginning of training.

class skorch.callbacks.MlflowLogger(run=None, client=None, create_artifact=True, terminate_after_train=True, log_on_batch_end=False, log_on_epoch_end=True, batch_suffix=None, epoch_suffix=None, keys_ignored=None)[source]

Logs results from history and artifacts to Mlflow

“MLflow is an open source platform for managing the end-to-end machine learning lifecycle” (MLflow: A Tool for Managing the Machine Learning Lifecycle)

Use this callback to automatically log your metrics and create/log artifacts to mlflow.

The best way to log additional information is to log directly to the experiment object, or to subclass this callback and extend the on_* methods.

To use this logger, you first have to install Mlflow:

$ python -m pip install mlflow
Parameters
run: mlflow.entities.Run (default=None)

Instantiated mlflow.entities.Run class. By default (if set to None), mlflow.active_run() is used to get the current run.

client: mlflow.tracking.MlflowClient (default=None)

Instantiated mlflow.tracking.MlflowClient class. By default (if set to None), MlflowClient() is used.

create_artifact: bool (default=True)

Whether to create artifacts for the network’s params, optimizer, criterion and history. See Saving and Loading

terminate_after_train: bool (default=True)

Whether to terminate the Run object once training finishes.

log_on_batch_end: bool (default=False)

Whether to log loss and other metrics on batch level.

log_on_epoch_end: bool (default=True)

Whether to log loss and other metrics on epoch level.

batch_suffix: str (default=None)

A string that will be appended to all logged keys. By default (if set to None) '_batch' is used if batch and epoch logging are both enabled and no suffix is used otherwise.

epoch_suffix: str (default=None)

A string that will be appended to all logged keys. By default (if set to None) '_epoch' is used if batch and epoch logging are both enabled and no suffix is used otherwise.

keys_ignored: str or list of str (default=None)

Key or list of keys that should not be logged to Mlflow. Note that in addition to the keys provided by the user, keys such as those starting with 'event_' or ending on '_best' are ignored by default.

Examples

Mlflow fluent API:

>>> import mlflow
>>> net = NeuralNetClassifier(..., callbacks=[MlflowLogger()])
>>> with mlflow.start_run():
...     net.fit(X, y)

Custom run and client:

>>> from mlflow.tracking import MlflowClient
>>> client = MlflowClient()
>>> experiment = client.get_experiment_by_name('Default')
>>> run = client.create_run(experiment.experiment_id)
>>> net = NeuralNetClassifier(..., callbacks=[MlflowLogger(run, client)])
>>> net.fit(X, y)

Methods

initialize()

(Re-)Set the initial state of the callback.

on_batch_begin(net[, batch, training])

Called at the beginning of each batch.

on_batch_end(net, training, **kwargs)

Called at the end of each batch.

on_epoch_begin(net[, dataset_train, ...])

Called at the beginning of each epoch.

on_epoch_end(net, **kwargs)

Called at the end of each epoch.

on_grad_computed(net, named_parameters[, X, ...])

Called once per batch after gradients have been computed but before an update step was performed.

on_train_begin(net, **kwargs)

Called at the beginning of training.

on_train_end(net, **kwargs)

Called at the end of training.

get_params

set_params

initialize()[source]

(Re-)Set the initial state of the callback. Use this e.g. if the callback tracks some state that should be reset when the model is re-initialized.

This method should return self.

on_batch_end(net, training, **kwargs)[source]

Called at the end of each batch.

on_epoch_end(net, **kwargs)[source]

Called at the end of each epoch.

on_train_begin(net, **kwargs)[source]

Called at the beginning of training.

on_train_end(net, **kwargs)[source]

Called at the end of training.

class skorch.callbacks.NeptuneLogger(run, *, log_on_batch_end=False, close_after_train=True, keys_ignored=None, base_namespace='training')[source]

Logs model metadata and training metrics to Neptune.

Neptune is a lightweight experiment-tracking tool. You can read more about it here: https://neptune.ai

Use this callback to automatically log all interesting values from your net’s history to Neptune.

The best way to log additional information is to log directly to the run object.

To monitor resource consumption, install psutil:

$ python -m pip install psutil

You can view example experiment logs here: https://app.neptune.ai/o/common/org/skorch-integration/e/SKOR-32/all

Parameters
run: neptune.Run or neptune.handler.Handler

Instantiated Run or Handler class.

log_on_batch_end: bool (default=False)

Whether to log loss and other metrics on batch level.

close_after_train: bool (default=True)

Whether to close the Run object once training finishes. Set this parameter to False if you want to continue logging to the same run or if you use it as a context manager.

keys_ignored: str or list of str (default=None)

Key or list of keys that should not be logged to Neptune. Note that in addition to the keys provided by the user, keys such as those starting with 'event_' or ending on '_best' are ignored by default.

base_namespace: str

Namespace (folder) under which all metadata logged by the NeptuneLogger will be stored. Defaults to “training”.

Examples

$ # Install Neptune
$ python -m pip install neptune

>>> # Create a Neptune run
>>> import neptune
>>> from neptune.types import File
>>> # This example uses the API token for anonymous users.
>>> # For your own projects, use the token associated with your neptune.ai account.
>>> run = neptune.init_run(
...     api_token=neptune.ANONYMOUS_API_TOKEN,
...     project='shared/skorch-integration',
...     name='skorch-basic-example',
...     source_files=['skorch_example.py'],
... )
>>> # Create a NeptuneLogger callback
>>> neptune_logger = NeptuneLogger(run, close_after_train=False)
>>> # Pass the logger to the net callbacks argument
>>> net = NeuralNetClassifier(
...           ClassifierModule,
...           max_epochs=20,
...           lr=0.01,
...           callbacks=[neptune_logger, Checkpoint(dirname="./checkpoints")])
>>> net.fit(X, y)
>>> # Save the checkpoints to Neptune
>>> neptune_logger.run["checkpoints"].upload_files("./checkpoints")
>>> # Log additional metrics after training has finished
>>> from sklearn.metrics import roc_auc_score
>>> y_proba = net.predict_proba(X)
>>> auc = roc_auc_score(y, y_proba[:, 1])
>>> neptune_logger.run["roc_auc_score"].log(auc)
>>> # Log charts, such as an ROC curve
>>> from sklearn.metrics import RocCurveDisplay
>>> roc_plot = RocCurveDisplay.from_estimator(net, X, y)
>>> neptune_logger.run["roc_curve"].upload(File.as_html(roc_plot.figure_))
>>> # Log the net object after training
>>> net.save_params(f_params='basic_model.pkl')
>>> neptune_logger.run["basic_model"].upload(File('basic_model.pkl'))
>>> # Close the run
>>> neptune_logger.run.stop()

Methods

initialize()

(Re-)Set the initial state of the callback.

on_batch_begin(net[, batch, training])

Called at the beginning of each batch.

on_batch_end(net, **kwargs)

Called at the end of each batch.

on_epoch_begin(net[, dataset_train, ...])

Called at the beginning of each epoch.

on_epoch_end(net, **kwargs)

Automatically log values from the last history step.

on_grad_computed(net, named_parameters[, X, ...])

Called once per batch after gradients have been computed but before an update step was performed.

on_train_begin(net, X, y, **kwargs)

Called at the beginning of training.

on_train_end(net, **kwargs)

Called at the end of training.

get_params

set_params

initialize()[source]

(Re-)Set the initial state of the callback. Use this e.g. if the callback tracks some state that should be reset when the model is re-initialized.

This method should return self.

on_batch_end(net, **kwargs)[source]

Called at the end of each batch.

on_epoch_end(net, **kwargs)[source]

Automatically log values from the last history step.

on_train_begin(net, X, y, **kwargs)[source]

Called at the beginning of training.

on_train_end(net, **kwargs)[source]

Called at the end of training.

class skorch.callbacks.ParamMapper(patterns, fn=<function noop>, at=1, schedule=None)[source]

Map arbitrary functions over module parameters filtered by pattern matching.

In the simplest case the function is only applied once at the beginning of a given epoch (at on_epoch_begin) but more complex execution schemes (e.g. periodic application) are possible using at and schedule.

Parameters
patterns: str or callable or list

The pattern(s) to match parameter names against. Patterns are UNIX globbing patterns as understood by fnmatch(). Patterns can also be callables which will get called with the parameter name and are regarded as a match when the callable returns a truthy value.

This parameter also supports lists of str or callables so that one ParamMapper can match a group of parameters.

Example: 'linear*.weight' or ['linear0.*', 'linear1.bias'] or lambda name: name.startswith('linear').

fn: function

The function to apply to each parameter separately.

at: int or callable

If you specify an integer, it represents the epoch number at which the function fn is applied to the parameters. If at is a callable, it receives the net as its argument, and fn is applied to the parameters once at returns True.

schedule: callable or None

If specified this callable supersedes the static at/fn combination by dynamically returning the function that is applied on the matched parameters. This way you can, for example, create a schedule that periodically freezes and unfreezes layers.

The callable’s signature is schedule(net: NeuralNet) -> callable.

Notes

When starting the training process after saving and loading a model, ParamMapper might re-initialize parts of your model when the history is not saved along with the model. To avoid this, in case you use ParamMapper (or subclasses, e.g. Initializer) and want to save your model make sure to either (a) use pickle, (b) save and load the history or (c) remove the parameter mapper callbacks before continuing training.

Examples

Initialize a layer on first epoch before the first training step:

>>> init = partial(torch.nn.init.uniform_, a=0, b=1)
>>> cb = ParamMapper('linear*.weight', at=1, fn=init)
>>> net = Net(myModule, callbacks=[cb])

Reset layer initialization if train loss reaches a certain value (e.g. re-initialize on overfit):

>>> at = lambda net: net.history[-1, 'train_loss'] < 0.1
>>> init = partial(torch.nn.init.uniform_, a=0, b=1)
>>> cb = ParamMapper('linear0.weight', at=at, fn=init)
>>> net = Net(myModule, callbacks=[cb])

Periodically freeze and unfreeze all embedding layers:

>>> def my_sched(net):
...    if len(net.history) % 2 == 0:
...        return skorch.utils.freeze_parameter
...    else:
...        return skorch.utils.unfreeze_parameter
>>> cb = ParamMapper('embedding*.weight', schedule=my_sched)
>>> net = Net(myModule, callbacks=[cb])

Methods

initialize()

(Re-)Set the initial state of the callback.

on_batch_begin(net[, batch, training])

Called at the beginning of each batch.

on_batch_end(net[, batch, training])

Called at the end of each batch.

on_epoch_begin(net, **kwargs)

Called at the beginning of each epoch.

on_epoch_end(net[, dataset_train, dataset_valid])

Called at the end of each epoch.

on_grad_computed(net, named_parameters[, X, ...])

Called once per batch after gradients have been computed but before an update step was performed.

on_train_begin(net[, X, y])

Called at the beginning of training.

on_train_end(net[, X, y])

Called at the end of training.

filter_parameters

get_params

named_parameters

set_params

initialize()[source]

(Re-)Set the initial state of the callback. Use this e.g. if the callback tracks some state that should be reset when the model is re-initialized.

This method should return self.

on_epoch_begin(net, **kwargs)[source]

Called at the beginning of each epoch.

class skorch.callbacks.PassthroughScoring(name, lower_is_better=True, on_train=False)[source]

Creates scores on epoch level based on batch level scores

This callback doesn’t calculate any new scores but instead passes through a score that was created on the batch level. Based on that score, an average across the batches is created (honoring the batch sizes) and recorded in the history for the given epoch.

Use this callback when there already is a score calculated on the batch level. If that score has yet to be calculated, use BatchScoring instead.

Parameters
name: str

Name of the score recorded on a batch level in the history.

lower_is_better: bool (default=True)

Whether lower (e.g. log loss) or higher (e.g. accuracy) scores are better.

on_train: bool (default=False)

Whether this should be called during train or validation.
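
A minimal sketch; it assumes that some other callback (or a custom train step) already records a batch-level score called 'my_score', and MyModule, X, and y are placeholders:

>>> from skorch import NeuralNetClassifier
>>> from skorch.callbacks import PassthroughScoring
>>> cb = PassthroughScoring(name='my_score', lower_is_better=False)
>>> net = NeuralNetClassifier(MyModule, callbacks=[cb])
>>> net.fit(X, y)
>>> net.history[-1, 'my_score']  # batch-size weighted average over the epoch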

Methods

initialize()

(Re-)Set the initial state of the callback.

on_batch_begin(net[, batch, training])

Called at the beginning of each batch.

on_batch_end(net[, batch, training])

Called at the end of each batch.

on_epoch_begin(net[, dataset_train, ...])

Called at the beginning of each epoch.

on_epoch_end(net, **kwargs)

Called at the end of each epoch.

on_grad_computed(net, named_parameters[, X, ...])

Called once per batch after gradients have been computed but before an update step was performed.

on_train_begin(net[, X, y])

Called at the beginning of training.

on_train_end(net[, X, y])

Called at the end of training.

get_avg_score

get_params

set_params

initialize()[source]

(Re-)Set the initial state of the callback. Use this e.g. if the callback tracks some state that should be reset when the model is re-initialized.

This method should return self.

on_epoch_end(net, **kwargs)[source]

Called at the end of each epoch.

class skorch.callbacks.PrintLog(keys_ignored=None, sink=<built-in function print>, tablefmt='simple', floatfmt='.4f', stralign='right')[source]

Print useful information from the model’s history as a table.

By default, PrintLog prints everything from the history except for 'batches'.

To determine the best loss, PrintLog looks for keys that end on '_best' and associates them with the corresponding loss. E.g., 'train_loss_best' will be matched with 'train_loss'. The skorch.callbacks.EpochScoring callback takes care of creating those entries, which is why PrintLog works best in conjunction with that callback.

PrintLog treats keys with the 'event_' prefix in a special way. They are assumed to contain information about occasionally occurring events. The False or None entries (indicating that an event did not occur) are not printed, resulting in empty cells in the table, and True entries are printed with a + symbol. PrintLog groups all event columns together and pushes them to the right, just before the 'dur' column.

Note: PrintLog will not result in good outputs if the number of columns varies between epochs, e.g. if the valid loss is only present on every other epoch.

Parameters
keys_ignored: str or list of str (default=None)

Key or list of keys that should not be part of the printed table. Note that in addition to the keys provided by the user, keys such as those starting with 'event_' or ending on '_best' are ignored by default.

sink: callable (default=print)

The target that the output string is sent to. By default, the output is printed to stdout, but the sink could also be a logger, etc.

tablefmt: str (default=’simple’)

The format of the table. See the documentation of the tabulate package for more detail. Can be ‘plain’, ‘grid’, ‘pipe’, ‘html’, ‘latex’, among others.

floatfmt: str (default=’.4f’)

The number formatting. See the documentation of the tabulate package for more details.

stralign: str (default=’right’)

The alignment of columns with strings. Can be ‘left’, ‘center’, ‘right’, or None (disable alignment). Default is ‘right’ (to be consistent with numerical columns).
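
For illustration, a minimal sketch that sends the table to a logger instead of stdout and hides the 'dur' column; the logger name and MyModule are placeholders, and the callback is passed under the name 'print_log' so that it should take the place of the default PrintLog instance:

>>> import logging
>>> from skorch import NeuralNetClassifier
>>> from skorch.callbacks import PrintLog
>>> logger = logging.getLogger('training')
>>> print_log = PrintLog(sink=logger.info, keys_ignored=['dur'])
>>> net = NeuralNetClassifier(MyModule, callbacks=[('print_log', print_log)])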

Methods

format_row(row, key, color)

For a given row from the table, format it (i.e. floating points and color if applicable).

initialize()

(Re-)Set the initial state of the callback.

on_batch_begin(net[, batch, training])

Called at the beginning of each batch.

on_batch_end(net[, batch, training])

Called at the end of each batch.

on_epoch_begin(net[, dataset_train, ...])

Called at the beginning of each epoch.

on_epoch_end(net, **kwargs)

Called at the end of each epoch.

on_grad_computed(net, named_parameters[, X, ...])

Called once per batch after gradients have been computed but before an update step was performed.

on_train_begin(net[, X, y])

Called at the beginning of training.

on_train_end(net[, X, y])

Called at the end of training.

get_params

set_params

table

format_row(row, key, color)[source]

For a given row from the table, format it (i.e. floating points and color if applicable).

initialize()[source]

(Re-)Set the initial state of the callback. Use this e.g. if the callback tracks some state that should be reset when the model is re-initialized.

This method should return self.

on_epoch_end(net, **kwargs)[source]

Called at the end of each epoch.

class skorch.callbacks.ProgressBar(batches_per_epoch='auto', detect_notebook=True, postfix_keys=None)[source]

Display a progress bar for each epoch.

The progress bar includes elapsed and estimated remaining time for the current epoch, the number of batches processed, and other user-defined metrics. The progress bar is erased once the epoch is completed.

ProgressBar needs to know the total number of batches per epoch in order to display a meaningful progress bar. By default, this number is determined automatically using the dataset length and the batch size. If this heuristic does not work for some reason, you may either specify the number of batches explicitly or let the ProgressBar count the actual number of batches in the previous epoch.

For jupyter notebooks a non-ASCII progress bar can be printed instead. To use this feature, you need to have ipywidgets installed.

Parameters
batches_per_epoch: int, str (default=’auto’)

Either a concrete number or a string specifying the method used to determine the number of batches per epoch automatically. 'auto' means that the number is computed from the length of the dataset and the batch size. 'count' means that the number is determined by counting the batches in the previous epoch. Note that this will leave you without a progress bar at the first epoch.

detect_notebook: bool (default=True)

If enabled, the progress bar determines if its current environment is a jupyter notebook and switches to a non-ASCII progress bar.

postfix_keys: list of str (default=[‘train_loss’, ‘valid_loss’])

You can use this list to specify additional info displayed in the progress bar such as metrics and losses. A prerequisite to this is that these values are residing in the history on batch level already, i.e. they must be accessible via

>>> net.history[-1, 'batches', -1, key]
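
For example, a minimal sketch (MyModule is a placeholder):

>>> from skorch import NeuralNetClassifier
>>> from skorch.callbacks import ProgressBar
>>> # count the batches during the first epoch and show the batch-level train loss
>>> pbar = ProgressBar(batches_per_epoch='count', postfix_keys=['train_loss'])
>>> net = NeuralNetClassifier(MyModule, callbacks=[pbar])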

Methods

initialize()

(Re-)Set the initial state of the callback.

on_batch_begin(net[, batch, training])

Called at the beginning of each batch.

on_batch_end(net, **kwargs)

Called at the end of each batch.

on_epoch_begin(net[, dataset_train, ...])

Called at the beginning of each epoch.

on_epoch_end(net, **kwargs)

Called at the end of each epoch.

on_grad_computed(net, named_parameters[, X, ...])

Called once per batch after gradients have been computed but before an update step was performed.

on_train_begin(net[, X, y])

Called at the beginning of training.

on_train_end(net[, X, y])

Called at the end of training.

get_params

in_ipynb

set_params

on_batch_end(net, **kwargs)[source]

Called at the end of each batch.

on_epoch_begin(net, dataset_train=None, dataset_valid=None, **kwargs)[source]

Called at the beginning of each epoch.

on_epoch_end(net, **kwargs)[source]

Called at the end of each epoch.

class skorch.callbacks.SacredLogger(experiment, log_on_batch_end=False, log_on_epoch_end=True, batch_suffix=None, epoch_suffix=None, keys_ignored=None)[source]

Logs results from history to Sacred.

Sacred is a tool to help you configure, organize, log and reproduce experiments. Developed at IDSIA. See https://github.com/IDSIA/sacred.

Use this callback to automatically log all interesting values from your net’s history to Sacred.

If you want to log additional information, you can simply add it to History. See the documentation on Callbacks, and Scoring for more information. Alternatively you can subclass this callback and extend the on_* methods.

To use this logger, you first have to install Sacred:

python -m pip install sacred

You might also install pymongo to use a mongodb backend. See the upstream documentation for more details. Once you have installed it, you can set up a simple experiment and pass this Logger as a callback to your skorch estimator, as shown in the example below.

Parameters
experiment: sacred.Experiment

Instantiated Experiment class.

log_on_batch_end: bool (default=False)

Whether to log loss and other metrics on batch level.

log_on_epoch_end: bool (default=True)

Whether to log loss and other metrics on epoch level.

batch_suffix: str (default=None)

A string that will be appended to all logged keys. By default (if set to None) “_batch” is used if batch and epoch logging are both enabled and no suffix is used otherwise.

epoch_suffix: str (default=None)

A string that will be appended to all logged keys. By default (if set to None) “_epoch” is used if batch and epoch logging are both enabled and no suffix is used otherwise.

keys_ignored: str or list of str (default=None)

Key or list of keys that should not be logged to Sacred. Note that in addition to the keys provided by the user, keys such as those starting with 'event_' or ending on '_best' are ignored by default.

Examples

>>> # contents of sacred-experiment.py
>>> import numpy as np
>>> from sacred import Experiment
>>> from sklearn.datasets import make_classification
>>> from skorch.callbacks.logging import SacredLogger
>>> from skorch.callbacks.scoring import EpochScoring
>>> from skorch import NeuralNetClassifier
>>> from skorch.toy import make_classifier
>>> ex = Experiment()
>>> @ex.config
>>> def my_config():
...     max_epochs = 20
...     lr = 0.01
>>> X, y = make_classification()
>>> X, y = X.astype(np.float32), y.astype(np.int64)
>>> @ex.automain
>>> def main(_run, max_epochs, lr):
...     # Take care to add additional scoring callbacks *before* the logger.
...     net = NeuralNetClassifier(
...         make_classifier(),
...         max_epochs=max_epochs,
...         lr=0.01,
...         callbacks=[EpochScoring("f1"), SacredLogger(_run)]
...     )
...     # now fit your estimator to your data
...     net.fit(X, y)

Then call this from the command line, e.g. like this:

python sacred-experiment.py with max_epochs=15

You can also change other options on the command line and optionally specify a backend.

Methods

initialize()

(Re-)Set the initial state of the callback.

on_batch_begin(net[, batch, training])

Called at the beginning of each batch.

on_batch_end(net, **kwargs)

Called at the end of each batch.

on_epoch_begin(net[, dataset_train, ...])

Called at the beginning of each epoch.

on_epoch_end(net, **kwargs)

Automatically log values from the last history step.

on_grad_computed(net, named_parameters[, X, ...])

Called once per batch after gradients have been computed but before an update step was performed.

on_train_begin(net[, X, y])

Called at the beginning of training.

on_train_end(net[, X, y])

Called at the end of training.

get_params

set_params

initialize()[source]

(Re-)Set the initial state of the callback. Use this e.g. if the callback tracks some state that should be reset when the model is re-initialized.

This method should return self.

on_batch_end(net, **kwargs)[source]

Called at the end of each batch.

on_epoch_end(net, **kwargs)[source]

Automatically log values from the last history step.

class skorch.callbacks.TensorBoard(writer, close_after_train=True, keys_ignored=None, key_mapper=<function rename_tensorboard_key>)[source]

Logs results from history to TensorBoard

“TensorBoard provides the visualization and tooling needed for machine learning experimentation” (official docs).

Use this callback to automatically log all interesting values from your net’s history to tensorboard after each epoch.

Parameters
writer: torch.utils.tensorboard.writer.SummaryWriter

Instantiated SummaryWriter class.

close_after_train: bool (default=True)

Whether to close the SummaryWriter object once training finishes. Set this parameter to False if you want to continue logging with the same writer or if you use it as a context manager.

keys_ignored: str or list of str (default=None)

Key or list of keys that should not be logged to tensorboard. Note that in addition to the keys provided by the user, keys such as those starting with 'event_' or ending on '_best' are ignored by default.

key_mapper: callable or function (default=rename_tensorboard_key)

This function maps a key name from the history to a tag in tensorboard. This is useful because tensorboard can automatically group similar tags if their names start with the same prefix, followed by a forward slash. By default, this callback will prefix all keys that start with “train” or “valid” with the “Loss/” prefix.

Examples

Here is the standard way of using the callback:

>>> # Example: normal usage
>>> from skorch.callbacks import TensorBoard
>>> from torch.utils.tensorboard import SummaryWriter
>>> writer = SummaryWriter(...)
>>> net = NeuralNet(..., callbacks=[TensorBoard(writer)])
>>> net.fit(X, y)

The best way to log additional information is to subclass this callback and add your code to one of the on_* methods.

>>> # Example: log the bias parameter as a histogram
>>> def extract_bias(module):
...     return module.hidden.bias
>>> # override on_epoch_end
>>> class MyTensorBoard(TensorBoard):
...     def on_epoch_end(self, net, **kwargs):
...         bias = extract_bias(net.module_)
...         epoch = net.history[-1, 'epoch']
...         self.writer.add_histogram('bias', bias, global_step=epoch)
...         super().on_epoch_end(net, **kwargs)  # call super last
>>> # other code
>>> net = NeuralNet(..., callbacks=[MyTensorBoard(writer)])

Methods

add_scalar_maybe(history, key, tag[, ...])

Add a scalar value from the history to TensorBoard

initialize()

(Re-)Set the initial state of the callback.

on_batch_begin(net[, batch, training])

Called at the beginning of each batch.

on_batch_end(net, **kwargs)

Called at the end of each batch.

on_epoch_begin(net[, dataset_train, ...])

Called at the beginning of each epoch.

on_epoch_end(net, **kwargs)

Automatically log values from the last history step.

on_grad_computed(net, named_parameters[, X, ...])

Called once per batch after gradients have been computed but before an update step was performed.

on_train_begin(net[, X, y])

Called at the beginning of training.

on_train_end(net, **kwargs)

Called at the end of training.

get_params

set_params

add_scalar_maybe(history, key, tag, global_step=None)[source]

Add a scalar value from the history to TensorBoard

Will catch errors like missing keys or wrong value types.

Parameters
historyskorch.History

History object saved as attribute on the neural net.

keystr

Key of the desired value in the history.

tagstr

Name of the tag used in TensorBoard.

global_stepint or None

Global step value to record.
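
For illustration, a subclass can call add_scalar_maybe to log an extra history column under its own tag. This is a minimal sketch assuming a hypothetical 'my_score' column that some other callback has written to the history:

>>> # Sketch: log the hypothetical 'my_score' column on every epoch
>>> class ExtraScalarTensorBoard(TensorBoard):
...     def on_epoch_end(self, net, **kwargs):
...         epoch = net.history[-1, 'epoch']
...         self.add_scalar_maybe(net.history, key='my_score',
...                               tag='Score/my_score', global_step=epoch)
...         super().on_epoch_end(net, **kwargs)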

initialize()[source]

(Re-)Set the initial state of the callback. Use this e.g. if the callback tracks some state that should be reset when the model is re-initialized.

This method should return self.

on_batch_end(net, **kwargs)[source]

Called at the end of each batch.

on_epoch_end(net, **kwargs)[source]

Automatically log values from the last history step.

on_train_end(net, **kwargs)[source]

Called at the end of training.

class skorch.callbacks.TrainEndCheckpoint(f_params='params.pt', f_optimizer='optimizer.pt', f_criterion='criterion.pt', f_history='history.json', f_pickle=None, fn_prefix='train_end_', dirname='', use_safetensors=False, sink=<function noop>, **kwargs)[source]

Saves the model parameters, optimizer state, and history at the end of training. The default fn_prefix is 'train_end_'.

Parameters
f_paramsfile-like object, str, None (default=’params.pt’)

File path to the file or file-like object where the model parameters should be saved. Pass None to disable saving model parameters.

If the value is a string, you can also use format specifiers to, for example, indicate the current epoch. Accessible format values are net, last_epoch and last_batch. Example to include the last epoch number in the file name:

>>> cb = Checkpoint(f_params="params_{last_epoch[epoch]}.pt")

f_optimizerfile-like object, str, None (default=’optimizer.pt’)

File path to the file or file-like object where the optimizer state should be saved. Pass None to disable saving the optimizer state.

Supports the same format specifiers as f_params.

f_criterionfile-like object, str, None (default=’criterion.pt’)

File path to the file or file-like object where the criterion state should be saved. Pass None to disable saving the criterion state.

Supports the same format specifiers as f_params.

f_historyfile-like object, str, None (default=’history.json’)

File path to the file or file-like object where the model training history should be saved. Pass None to disable saving history.

f_picklefile-like object, str, None (default=None)

File path to the file or file-like object where the entire model object should be pickled. Pass None to disable pickling.

Supports the same format specifiers as f_params.

fn_prefix: str (default=’train_end_’)

Prefix for filenames. If f_params, f_optimizer, f_history, or f_pickle are strings, they will be prefixed by fn_prefix.

dirname: str (default=’’)

Directory where files are stored.

use_safetensorsbool (default=False)

Whether to use the safetensors library to persist the state. By default, PyTorch is used, which in turn uses pickle under the hood. When enabling safetensors, be aware that only PyTorch tensors can be stored. Therefore, certain attributes like the optimizer cannot be saved. A short sketch follows the parameter list below.

sinkcallable (default=noop)

The target that the information about created checkpoints is sent to. This can be a logger or print function (to send to stdout). By default the output is discarded.
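
As a minimal sketch of the use_safetensors restriction mentioned above, disable the optimizer target, since optimizer state is not a plain collection of tensors (the directory name is illustrative):

>>> # Sketch: persist only tensor state via safetensors, skip the optimizer
>>> cp = TrainEndCheckpoint(dirname='exp_safetensors', use_safetensors=True, f_optimizer=None)
>>> net = NeuralNet(..., callbacks=[cp])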

Examples

Consider running the following example multiple times:

>>> train_end_cp = TrainEndCheckpoint(dirname='exp1')
>>> load_state = LoadInitState(train_end_cp)
>>> net = NeuralNet(..., callbacks=[train_end_cp, load_state])
>>> net.fit(X, y)

After the first run, model parameters, optimizer state, and history are saved into a directory named exp1. On the next run, LoadInitState will load the state from the first run and continue training.
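
To restore that state without refitting (for example, in a separate inference script), the checkpoint can also be handed to load_params on a freshly initialized net. A minimal sketch, assuming the same exp1 directory and the default file names:

>>> # Sketch: load the train-end state into a new net and predict
>>> new_net = NeuralNet(...)
>>> new_net.initialize()
>>> new_net.load_params(checkpoint=train_end_cp)
>>> y_pred = new_net.predict(X)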

Methods

initialize()

(Re-)Set the initial state of the callback.

on_batch_begin(net[, batch, training])

Called at the beginning of each batch.

on_batch_end(net[, batch, training])

Called at the end of each batch.

on_epoch_begin(net[, dataset_train, ...])

Called at the beginning of each epoch.

on_epoch_end(net[, dataset_train, dataset_valid])

Called at the end of each epoch.

on_grad_computed(net, named_parameters[, X, ...])

Called once per batch after gradients have been computed but before an update step was performed.

on_train_begin(net[, X, y])

Called at the beginning of training.

on_train_end(net, **kwargs)

Called at the end of training.

get_params

set_params

initialize()[source]

(Re-)Set the initial state of the callback. Use this e.g. if the callback tracks some state that should be reset when the model is re-initialized.

This method should return self.

on_train_end(net, **kwargs)[source]

Called at the end of training.

class skorch.callbacks.Unfreezer(*args, **kwargs)[source]

Inverse operation of Freezer: sets requires_grad back to True for the matched parameters, making them trainable again. A usage sketch follows below.
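
A typical use is to pair it with Freezer so that some parameters start out frozen and become trainable again from a later epoch on. The following is a minimal sketch; it assumes a module with a submodule named embedding and that Unfreezer forwards its arguments (parameter name patterns and the at epoch) to ParamMapper:

>>> # Sketch: keep the embedding frozen at first, unfreeze it from epoch 5 on
>>> from skorch.callbacks import Freezer, Unfreezer
>>> net = NeuralNet(
...     ...,
...     callbacks=[
...         Freezer('embedding.*'),
...         Unfreezer('embedding.*', at=5),
...     ],
... )
>>> net.fit(X, y)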

Methods

initialize()

(Re-)Set the initial state of the callback.

on_batch_begin(net[, batch, training])

Called at the beginning of each batch.

on_batch_end(net[, batch, training])

Called at the end of each batch.

on_epoch_begin(net, **kwargs)

Called at the beginning of each epoch.

on_epoch_end(net[, dataset_train, dataset_valid])

Called at the end of each epoch.

on_grad_computed(net, named_parameters[, X, ...])

Called once per batch after gradients have been computed but before an update step was performed.

on_train_begin(net[, X, y])

Called at the beginning of training.

on_train_end(net[, X, y])

Called at the end of training.

filter_parameters

get_params

named_parameters

set_params

class skorch.callbacks.WandbLogger(wandb_run, save_model=True, keys_ignored=None)[source]

Logs best model and metrics to Weights & Biases

Use this callback to automatically log best trained model, all metrics from your net’s history, model topology and computer resources to Weights & Biases after each epoch.

Every file saved in wandb_run.dir is automatically logged to W&B servers.

See example run

Parameters
wandb_runwandb.wandb_run.Run

wandb Run used to log data.

save_modelbool (default=True)

Whether to save a checkpoint of the best model and upload it to your Run on W&B servers.

keys_ignoredstr or list of str (default=None)

Key or list of keys that should not be logged to wandb. Note that in addition to the keys provided by the user, keys such as those starting with 'event_' or ending with '_best' are ignored by default. A short sketch follows the examples below.

Examples

>>> # Install wandb first (from the shell): python -m pip install wandb
>>> import wandb
>>> from skorch.callbacks import WandbLogger
>>> # Create a wandb Run
... wandb_run = wandb.init()
>>> # Alternative: Create a wandb Run without having a W&B account
... wandb_run = wandb.init(anonymous="allow")
>>> # Log hyper-parameters (optional)
... wandb_run.config.update({"learning rate": 1e-3, "batch size": 32})
>>> net = NeuralNet(..., callbacks=[WandbLogger(wandb_run)])
>>> net.fit(X, y)
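
Noisy or redundant history columns can be excluded from logging via keys_ignored, as mentioned above. A minimal sketch, assuming the default 'dur' (epoch duration) column should be skipped:

>>> # Sketch: do not log the epoch duration column to W&B
>>> logger = WandbLogger(wandb_run, keys_ignored=['dur'])
>>> net = NeuralNet(..., callbacks=[logger])
>>> net.fit(X, y)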

Methods

initialize()

(Re-)Set the initial state of the callback.

on_batch_begin(net[, batch, training])

Called at the beginning of each batch.

on_batch_end(net[, batch, training])

Called at the end of each batch.

on_epoch_begin(net[, dataset_train, ...])

Called at the beginning of each epoch.

on_epoch_end(net, **kwargs)

Log values from the last history step and save best model

on_grad_computed(net, named_parameters[, X, ...])

Called once per batch after gradients have been computed but before an update step was performed.

on_train_begin(net, **kwargs)

Log model topology and add a hook for gradients

on_train_end(net[, X, y])

Called at the end of training.

get_params

set_params

initialize()[source]

(Re-)Set the initial state of the callback. Use this e.g. if the callback tracks some state that should be reset when the model is re-initialized.

This method should return self.

on_epoch_end(net, **kwargs)[source]

Log values from the last history step and save best model

on_train_begin(net, **kwargs)[source]

Log model topology and add a hook for gradients

class skorch.callbacks.WarmRestartLR(optimizer, min_lr=1e-06, max_lr=0.05, base_period=10, period_mult=2, last_epoch=-1)[source]

Stochastic Gradient Descent with Warm Restarts (SGDR) scheduler.

This scheduler sets the learning rate of each parameter group according to the stochastic gradient descent with warm restarts (SGDR) policy. This policy simulates periodic warm restarts of SGD, where in each restart the learning rate is initialized to some value and is scheduled to decrease.

Parameters
optimizertorch.optimizer.Optimizer instance.

Optimizer algorithm.

min_lrfloat or list of float (default=1e-6)

Minimum allowed learning rate during each period for all param groups (float) or each group (list).

max_lrfloat or list of float (default=0.05)

Maximum allowed learning rate during each period for all param groups (float) or each group (list).

base_periodint (default=10)

Initial restart period to be multiplied at each restart.

period_multint (default=2)

Multiplicative factor to increase the period between restarts.

last_epochint (default=-1)

The index of the last valid epoch.
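
Within a skorch net, this scheduler is typically attached through the LRScheduler callback, which instantiates the policy with the net’s optimizer. A minimal sketch; the concrete learning-rate bounds are illustrative:

>>> # Sketch: warm restarts after 10 epochs, then after a further 20, 40, ... epochs
>>> from skorch.callbacks import LRScheduler, WarmRestartLR
>>> lr_schedule = LRScheduler(
...     policy=WarmRestartLR,
...     min_lr=1e-6,
...     max_lr=0.05,
...     base_period=10,
...     period_mult=2,
... )
>>> net = NeuralNet(..., callbacks=[lr_schedule])
>>> net.fit(X, y)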

References

1

Ilya Loshchilov and Frank Hutter, 2017, “SGDR: Stochastic Gradient Descent with Warm Restarts”, ICLR. https://arxiv.org/pdf/1608.03983.pdf

Methods

get_last_lr()

Return last computed learning rate by current scheduler.

load_state_dict(state_dict)

Loads the scheduler’s state.

print_lr(is_verbose, group, lr[, epoch])

Display the current learning rate.

state_dict()

Returns the state of the scheduler as a dict.

get_lr

step