skorch documentation

A scikit-learn compatible neural network library that wraps PyTorch.

Introduction

The goal of skorch is to make it possible to use PyTorch with sklearn. This is achieved by providing a wrapper around PyTorch that has an sklearn interface. In that sense, skorch is the spiritual successor to nolearn, but instead of using Lasagne and Theano, it uses PyTorch.

skorch does not re-invent the wheel, instead getting as much out of your way as possible. If you are familiar with sklearn and PyTorch, you don’t have to learn any new concepts, and the syntax should be well known. (If you’re not familiar with those libraries, it is worth getting familiarized.)

Additionally, skorch abstracts away the training loop, making a lot of boilerplate code obsolete. A simple net.fit(X, y) is enough. Out of the box, skorch works with many types of data, be it PyTorch Tensors, NumPy arrays, Python dicts, and so on. However, if you have other data, extending skorch is easy to allow for that.

Overall, skorch aims at being as flexible as PyTorch while having as clean an interface as sklearn.

If you use skorch, please use this BibTeX entry:

@manual{skorch,
  author       = {Marian Tietz and Thomas J. Fan and Daniel Nouri and Benjamin Bossan and {skorch Developers}},
  title        = {skorch: A scikit-learn compatible neural network library that wraps PyTorch},
  month        = jul,
  year         = 2017,
  url          = {https://skorch.readthedocs.io/en/stable/}
}

User’s Guide

Installation

pip installation

To install with pip, run:

pip install -U skorch

We recommend using a virtual environment for this.

From source

If you would like to use the most recent additions to skorch or to help with development, you should install skorch from source.

Using conda

You need a working conda installation. Get the correct Miniconda for your system from the Miniconda website.

If you just want to use skorch, use:

git clone https://github.com/skorch-dev/skorch.git
cd skorch
conda env create
source activate skorch
pip install .

If you want to help developing, run:

git clone https://github.com/skorch-dev/skorch.git
cd skorch
conda env create
source activate skorch
pip install -e .

py.test  # unit tests
pylint skorch  # static code checks

Using pip

If you just want to use skorch, use:

git clone https://github.com/skorch-dev/skorch.git
cd skorch
# create and activate a virtual environment
pip install -r requirements.txt
# install pytorch version for your system (see below)
pip install .

If you want to help developing, run:

git clone https://github.com/skorch-dev/skorch.git
cd skorch
# create and activate a virtual environment
pip install -r requirements.txt
# install pytorch version for your system (see below)
pip install -r requirements-dev.txt
pip install -e .

py.test  # unit tests
pylint skorch  # static code checks

PyTorch

PyTorch is not covered by the dependencies, since the PyTorch version you need is dependent on your OS and device. For installation instructions for PyTorch, visit the PyTorch website. skorch officially supports the last four minor PyTorch versions, which currently are:

  • 1.4.0
  • 1.5.1
  • 1.6.0
  • 1.7.1

However, that doesn’t mean that older versions don’t work, just that they aren’t tested. Since skorch mostly relies on the stable part of the PyTorch API, older PyTorch versions should work fine.

In general, running this to install PyTorch should work (assuming CUDA 10.2):

# using conda:
conda install pytorch cudatoolkit==10.2 -c pytorch
# using pip
pip install torch

Quickstart

Training a model

Below, we define our own PyTorch Module and train it on a toy classification dataset using skorch NeuralNetClassifier:

import numpy as np
from sklearn.datasets import make_classification
from torch import nn
import torch.nn.functional as F

from skorch import NeuralNetClassifier


X, y = make_classification(1000, 20, n_informative=10, random_state=0)
X = X.astype(np.float32)
y = y.astype(np.int64)

class MyModule(nn.Module):
    def __init__(self, num_units=10, nonlin=F.relu):
        super(MyModule, self).__init__()

        self.dense0 = nn.Linear(20, num_units)
        self.nonlin = nonlin
        self.dropout = nn.Dropout(0.5)
        self.dense1 = nn.Linear(num_units, 10)
        self.output = nn.Linear(10, 2)

    def forward(self, X, **kwargs):
        X = self.nonlin(self.dense0(X))
        X = self.dropout(X)
        X = F.relu(self.dense1(X))
        X = F.softmax(self.output(X), dim=-1)
        return X


net = NeuralNetClassifier(
    MyModule,
    max_epochs=10,
    lr=0.1,
    # Shuffle training data on each epoch
    iterator_train__shuffle=True,
)

net.fit(X, y)
y_proba = net.predict_proba(X)

In an sklearn Pipeline

Since NeuralNetClassifier provides an sklearn-compatible interface, it is possible to put it into an sklearn Pipeline:

from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler


pipe = Pipeline([
    ('scale', StandardScaler()),
    ('net', net),
])

pipe.fit(X, y)
y_proba = pipe.predict_proba(X)

What’s next?

Please visit the Tutorials page to explore additional examples on using skorch!

Tutorials

The following are examples and notebooks on how to use skorch.

NeuralNet

Using NeuralNet

NeuralNet and the derived classes are the main touch point for the user. They wrap the PyTorch Module while providing an interface that should be familiar for sklearn users.

Define your Module the same way as you always do. Then pass it to NeuralNet, in conjunction with a PyTorch criterion. Finally, you can call fit() and predict(), as with an sklearn estimator. The finished code could look something like this:

class MyModule(torch.nn.Module):
    ...

net = NeuralNet(
    module=MyModule,
    criterion=torch.nn.NLLLoss,
)
net.fit(X, y)
y_pred = net.predict(X_valid)

Let’s see what skorch did for us here:

  • wraps the PyTorch Module in an sklearn interface
  • converts numpy.ndarrays to PyTorch Tensors
  • abstracts away the fit loop
  • takes care of batching the data

You therefore have a lot less boilerplate code, letting you focus on what matters. At the same time, skorch is very flexible and can be extended with ease, getting out of your way as much as possible.

Initialization

In general, when you instantiate the NeuralNet instance, only the given arguments are stored. They are stored exactly as you pass them to NeuralNet. For instance, the module will remain uninstantiated. This is to make sure that the arguments you pass are not touched afterwards, which, among other things, makes it possible to clone the NeuralNet instance.

Only when the fit() or initialize() method is called are the different attributes of the net, such as the module, initialized. An initialized attribute’s name always ends with an underscore; e.g., the initialized module is called module_. (This is the same nomenclature as sklearn uses.) Therefore, you always know which attributes you set and which ones were created by NeuralNet.

The only exception is the history attribute, which is not set by the user.

Most important arguments and methods

A complete explanation of all arguments and methods of NeuralNet are found in the skorch API documentation. Here we focus on the main ones.

module

This is where you pass your PyTorch Module. Ideally, it should not be instantiated. Instead, the init arguments for your module should be passed to NeuralNet with the module__ prefix. E.g., if your module takes the arguments num_units and dropout, the code would look like this:

class MyModule(torch.nn.Module):
    def __init__(self, num_units, dropout):
        ...

net = NeuralNet(
    module=MyModule,
    module__num_units=100,
    module__dropout=0.5,
    criterion=torch.nn.NLLLoss,
)

It is, however, also possible to pass an instantiated module, e.g. a PyTorch Sequential instance.
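
For instance, passing an already-instantiated Sequential module could look like this (a minimal sketch; the layer sizes are illustrative and match the 20-feature toy data from the quickstart):

from torch import nn

from skorch import NeuralNetClassifier

module = nn.Sequential(
    nn.Linear(20, 10),
    nn.ReLU(),
    nn.Linear(10, 2),
    nn.Softmax(dim=-1),
)

net = NeuralNetClassifier(module)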

Note that skorch does not automatically apply any nonlinearities to the outputs (except internally when determining the PyTorch NLLLoss, see below). That means that if you have a classification task, you should make sure that the final output nonlinearity is a softmax. Otherwise, when you call predict_proba(), you won’t get actual probabilities.

criterion

This should be a PyTorch (-compatible) criterion.

When you use the NeuralNetClassifier, the criterion is set to PyTorch NLLLoss by default. Furthermore, if you don’t change it to another criterion, NeuralNetClassifier assumes that the module returns probabilities and will automatically apply a logarithm to them (which is what NLLLoss expects).

For NeuralNetRegressor, the default criterion is PyTorch MSELoss.

After initializing the NeuralNet, the initialized criterion will be stored in the criterion_ attribute.

optimizer

This should be a PyTorch optimizer, e.g. SGD. After initializing the NeuralNet, the initialized optimizer will be stored in the optimizer_ attribute. During initialization you can define param groups, for example to set different learning rates for certain parameters. The parameters are selected by name with support for wildcards (globbing):

optimizer__param_groups=[
    ('embedding.*', {'lr': 0.0}),
    ('linear0.bias', {'lr': 1}),
]

Your use case may require an optimizer whose signature differs from a default PyTorch optimizer’s signature. In that case, you can define a custom function that reroutes the arguments as needed and pass it to the optimizer parameter:

# custom optimizer to encapsulate Adam
def make_lookahead(parameters, optimizer_cls, k, alpha, **kwargs):
    optimizer = optimizer_cls(parameters, **kwargs)
    return Lookahead(optimizer=optimizer, k=k, alpha=alpha)


net = NeuralNetClassifier(
        ...,
        optimizer=make_lookahead,
        optimizer__optimizer_cls=torch.optim.Adam,
        optimizer__weight_decay=1e-2,
        optimizer__k=5,
        optimizer__alpha=0.5,
        lr=1e-3)

lr

The learning rate. This argument exists for convenience, since it could also be set by optimizer__lr instead. However, it is used so often that we provided this shortcut. If you set both lr and optimizer__lr, the latter has precedence.

max_epochs

The maximum number of epochs to train with each fit() call. When you call fit(), the net will train for this many epochs, except if you interrupt training before the end (e.g. by using an early stopping callback or interrupt manually with ctrl+c).

If you want to change the number of epochs to train, you can either set a different value for max_epochs, or you call fit_loop() instead of fit() and pass the desired number of epochs explicitly:

net.fit_loop(X, y, epochs=20)

batch_size

This argument controls the batch size for iterator_train and iterator_valid at the same time. batch_size=128 is thus a convenient shortcut for explicitly typing iterator_train__batch_size=128 and iterator_valid__batch_size=128. If you set all three arguments, the latter two will have precedence.
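
For illustration, a net could set one batch size for both iterators and override it for validation only (a minimal sketch):

net = NeuralNet(
    module=MyModule,
    batch_size=128,
    # the more specific argument takes precedence for validation batches
    iterator_valid__batch_size=256,
)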

train_split

This determines the NeuralNet’s internal train/validation split. By default, 20% of the incoming data is reserved for validation. If you set this value to None, all the data is used for training.
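
As a small sketch, the split can be disabled or adjusted via this argument (CVSplit is described in more detail in the dataset section):

from skorch.dataset import CVSplit

# use all of the data for training, no internal validation split
net_no_valid = NeuralNet(module=MyModule, train_split=None)

# reserve 25% of the data for validation instead of the default 20%
net_custom_valid = NeuralNet(module=MyModule, train_split=CVSplit(0.25))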

For more details, please look at dataset.

callbacks

For more details on the callback classes, please look at callbacks.

By default, NeuralNet and its subclasses start with a couple of useful callbacks. Those are defined in the get_default_callbacks() method and include, for instance, callbacks for measuring and printing model performance.

In addition to the default callbacks, you may provide your own callbacks. There are a couple of ways to pass callbacks to the NeuralNet instance. The easiest way is to pass a list of all your callbacks to the callbacks argument:

net = NeuralNet(
    module=MyModule,
    callbacks=[
        MyCallback1(...),
        MyCallback2(...),
    ],
)

Inside the NeuralNet instance, each callback will receive a separate name. Since we provide no name in the example above, the class name will be taken, which will lead to a name collision in case of two or more callbacks of the same class. This is why it is better to initialize the callbacks with a list of tuples of name and callback instance, like this:

net = NeuralNet(
    module=MyModule,
    callbacks=[
        ('cb1', MyCallback1(...)),
        ('cb2', MyCallback2(...)),
    ],
)

This approach of passing a list of (name, instance) tuples should be familiar to users of sklearn Pipelines and FeatureUnions.

An additional advantage of this way of passing callbacks is that it allows you to pass arguments to the callbacks by name (using the double-underscore notation):

net.set_params(callbacks__cb1__foo=123, callbacks__cb2__bar=456)

Use this, for instance, when trying out different callback parameters in a grid search.

Note: The user-defined callbacks are always called in the same order as they appeared in the list. If there are dependencies between the callbacks, the user has to make sure that the order respects them. Also note that the user-defined callbacks will be called after the default callbacks so that they can make use of the things provided by the default callbacks. The only exception is the default callback PrintLog, which is always called last.

warm_start

This argument determines whether each fit() call leads to a re-initialization of the NeuralNet or not. By default, when calling fit(), the parameters of the net are initialized, so your previous training progress is lost (consistent with the sklearn fit() calls). In contrast, with warm_start=True, each fit() call will continue from the most recent state.
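
A small sketch of the difference:

net = NeuralNet(module=MyModule, warm_start=True, max_epochs=10)

net.fit(X, y)  # trains for 10 epochs from freshly initialized parameters
net.fit(X, y)  # trains for 10 more epochs, continuing from the current state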

device

As the name suggests, this determines which computation device should be used. If set to cuda, the incoming data will be transferred to CUDA before being passed to the PyTorch Module. The device parameter adheres to the general syntax of the PyTorch device parameter.

initialize()

As mentioned earlier, upon instantiating the NeuralNet instance, the net’s components are not yet initialized. That means, e.g., that the weights and biases of the layers are not yet set. This only happens after the initialize() call. However, when you call fit() and the net is not yet initialized, initialize() is called automatically. You thus rarely need to call it manually.

The initialize() method itself calls a couple of other initialization methods that are specific to each component. E.g., initialize_module() is responsible for initializing the PyTorch module. Therefore, if you have special needs for initializing the module, it is enough to override initialize_module(); you don’t need to override the whole initialize() method.

fit(X, y)

This is one of the main methods you will use. It contains everything required to train the model, be it batching of the data, triggering the callbacks, or handling the internal validation set.

In general, we assume there to be an X and a y. If you have more input data than just one array, it is possible for X to be a list or dictionary of data (see dataset). And if your task does not have an actual y, you may pass y=None.

If you fit with a PyTorch Dataset and don’t explicitly pass y, several components down the line might not work anymore, since sklearn sometimes requires an explicit y (e.g. for scoring). In general, PyTorch Datasets should work, though.

In addition to fit(), there is also the partial_fit() method, known from some sklearn estimators. partial_fit() allows you to continue training from your current status, even if you set warm_start=False. A further use case for partial_fit() is when your data does not fit into memory and you thus need to have several training steps.
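
For instance, if your data arrives in chunks, training could be sketched like this (iter_data_chunks() is a hypothetical helper that yields arrays small enough to fit into memory):

net = NeuralNet(module=MyModule, max_epochs=1)

for X_chunk, y_chunk in iter_data_chunks():  # hypothetical chunked data loader
    net.partial_fit(X_chunk, y_chunk)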

Tip : skorch gracefully catches the KeyboardInterrupt exception. Therefore, during a training run, you can send a KeyboardInterrupt signal without the Python process exiting (typically, KeyboardInterrupt can be triggered by ctrl+c or, in a Jupyter notebook, by clicking Kernel -> Interrupt). This way, when your model has reached a good score before max_epochs have been reached, you can dynamically stop training.

predict(X) and predict_proba(X)

These methods perform an inference step on the input data and return numpy.ndarrays. By default, predict_proba() will return whatever it is that the module’s forward() method returns, cast to a numpy.ndarray. If forward() returns multiple outputs as a tuple, only the first output is used, the rest is discarded.

If the forward() output cannot be cast to a numpy.ndarray, or if you need access to all outputs in the multiple-outputs case, consider using either the forward() or the forward_iter() method to generate outputs from the module. Alternatively, you may directly call net.module_(X).

In case of NeuralNetClassifier, the predict() method tries to return the class labels by applying the argmax over the last axis of the result of predict_proba(). Obviously, this only makes sense if predict_proba() returns class probabilities. If this is not true, you should just use predict_proba().

score(X, y)

This method returns the mean accuracy on the given data and labels for classifiers and the coefficient of determination R^2 of the prediction for regressors. NeuralNet still has no score method. If you need it, you have to implement it yourself.
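
If you do need a score method on a plain NeuralNet, a minimal sketch could look like this (it assumes the module outputs class probabilities):

from sklearn.metrics import accuracy_score

from skorch import NeuralNet

class ScoredNet(NeuralNet):
    def score(self, X, y):
        # assumes predict_proba() returns class probabilities
        y_pred = self.predict_proba(X).argmax(axis=-1)
        return accuracy_score(y, y_pred)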

model persistence

In general there are different ways of saving and loading models, each with their own advantages and disadvantages. More details and usage examples can be found here: Saving and Loading.

If you would like to use pickle (the default way when using scikit-learn models), this is possible with skorch nets. This saves the whole net including hyperparameters etc. The advantage is that you can restore everything to exactly the state it was before. The disadvantage is it’s easier for code changes to break your old saves.

Additionally, it is possible to save and load specific attributes of the net, such as the module, optimizer, or history, by calling save_params() and load_params(). This is useful if you’re only interested in saving a particular part of your model, and is more robust to code changes.

Finally, it is also possible to use callbacks to save and load models, e.g. Checkpoint. Those should be used if you need to have your model saved or loaded at specific times, e.g. at the start or end of the training process.

Input data

Regular data

skorch supports numerous input types for data. Regular input types that should just work are numpy arrays, torch tensors, scipy sparse CSR matrices, and pandas DataFrames (see also DataFrameTransformer).

Typically, your task should involve an X and a y. If you’re dealing with a task that doesn’t require a target (say, training an autoencoder), you can just pass y=None. Make sure your loss function deals with this appropriately.

Datasets

Datasets are also supported, with the requirement that they should return exactly two items (X and y). For more information on that, take a look at the Dataset documentation.

Many PyTorch libraries, like torchvision, implement their own Datasets. These usually work seamlessly with skorch, as long as their __getitem__ methods return two outputs. In case they don’t, consider overriding the __getitem__ method and re-arranging the outputs so that __getitem__ returns exactly two elements. If the original implementation returns more than two elements, take a look at the next section to get an idea of how to deal with that.

Multiple input arguments

In some cases, the input actually consists of multiple inputs. E.g., in a text classification task, you might have an array that contains the integers representing the tokens for each sample, and another array containing the number of tokens of each sample. skorch has you covered here as well.

You could supply a list or tuple with all your inputs (net.fit([tokens, num_tokens], y)), but we actually recommend another approach. The best way is to pass the different arguments as a dictionary. Then the keys of that dictionary have to correspond to the argument names of your module’s forward method. Below is an example:

X_dict = {'tokens': tokens, 'num_tokens': num_tokens}

class MyModule(nn.Module):
    def forward(self, tokens, num_tokens):  # <- same names as in your dict
        ...

net = NeuralNet(MyModule, ...)
net.fit(X_dict, y)

As you can see, the forward method takes arguments with exactly the same name as the keys in the dictionary. This is how the different inputs are matched. To make this work with GridSearchCV, please use SliceDict.

Using a dict should cover most use cases that involve multiple inputs. However, it will fail if your inputs have different sizes. E.g., if your array of tokens has 1000 elements but your array of number of tokens has 2000 elements, this would fail. The main reason for this is batching: How can we know which elements of the two arrays belong in the same batch?

If your input consists of multiple inputs with different sizes, your best bet is to implement your own dataset class. That class should know how it deals with the different inputs, i.e. which elements belong to the same sample. Again, please refer to the Dataset section for more details.

Special arguments

In addition to the arguments explicitly listed for NeuralNet, there are some arguments with special prefixes, as shown below:

class MyModule(torch.nn.Module):
    def __init__(self, num_units, dropout):
        ...

net = NeuralNet(
    module=MyModule,
    module__num_units=100,
    module__dropout=0.5,
    criterion=torch.nn.NLLLoss,
    criterion__weight=weight,
    optimizer=torch.optim.SGD,
    optimizer__momentum=0.9,
)

Those arguments are used to initialize your module, criterion, etc. They are not fixed because we cannot know them in advance; in fact, you can define any parameter for your module or other components.

All special prefixes are stored in the prefixes_ class attribute of NeuralNet. Currently, they are:

  • module
  • iterator_train
  • iterator_valid
  • optimizer
  • criterion
  • callbacks
  • dataset

Subclassing NeuralNet

Apart from the NeuralNet base class, we provide NeuralNetClassifier, NeuralNetBinaryClassifier, and NeuralNetRegressor for typical classification, binary classification, and regression tasks. They should work as drop-in replacements for sklearn classifiers and regressors.

The NeuralNet class is a little less opinionated about the incoming data, e.g. it does not determine a loss function by default. Therefore, if you want to write your own subclass for a special use case, you would typically subclass from NeuralNet.

skorch aims at making subclassing as easy as possible, so that it doesn’t stand in your way. For instance, all components (module, optimizer, etc.) have their own initialization method (initialize_module(), initialize_optimizer(), etc.). That way, if you want to modify the initialization of a component, you can easily do so.

Additionally, NeuralNet has a couple of get_* methods for when a component is retrieved repeatedly. E.g., get_loss() is called when the loss is determined. Below we show an example of overriding get_loss() to add L1 regularization to our total loss:

class RegularizedNet(NeuralNet):
    def __init__(self, *args, lambda1=0.01, **kwargs):
        super().__init__(*args, **kwargs)
        self.lambda1 = lambda1

    def get_loss(self, y_pred, y_true, X=None, training=False):
        loss = super().get_loss(y_pred, y_true, X=X, training=training)
        loss += self.lambda1 * sum([w.abs().sum() for w in self.module_.parameters()])
        return loss

Note

This example also regularizes the biases, which you typically don’t need to do.

It is possible to add your own criterion, module, or optimizer to your customized neural net class. You should follow a few rules when you do so:

  1. Set this attribute inside the corresponding method. E.g., when setting an optimizer, use initialize_optimizer() for that.
  2. Inside the initialization method, use get_params_for() (or, if dealing with an optimizer, get_params_for_optimizer()) to retrieve the arguments for the constructor.
  3. The attribute name should contain the substring "module" if it’s a module, "criterion" if a criterion, and "optimizer" if an optimizer. This way, skorch knows if a change in parameters (say, because set_params() was called) should trigger re-initialization.

When you follow these rules, you will make sure that your added components are amenable to set_params() and hence to things like grid search.

Here is an example of how this could look in practice:

class MyNet(NeuralNet):
    def initialize_criterion(self, *args, **kwargs):
        super().initialize_criterion(*args, **kwargs)

        # add an additional criterion
        params = self.get_params_for('other_criterion')
        self.other_criterion_ = nn.BCELoss(**params)
        return self

    def initialize_module(self, *args, **kwargs):
        super().initialize_module(*args, **kwargs)

        # add an additional module called 'mymodule'
        params = self.get_params_for('mymodule')
        self.mymodule_ = MyModule(**params)
        return self

    def initialize_optimizer(self, *args, **kwargs):
        super().initialize_optimizer(*args, **kwargs)

        # add an additional optimizer called 'optimizer2' that is
        # responsible for 'mymodule'
        named_params = self.mymodule_.named_parameters()
        pgroups, params = self.get_params_for_optimizer('optimizer2', named_params)
        self.optimizer2_ = torch.optim.SGD(*pgroups, **params)
        return self

    ...  # additional changes


net = MyNet(
    ...,
    other_criterion__reduction='sum',
    mymodule__num_units=123,
    optimizer2__lr=0.1,
)
net.fit(X, y)

# set_params works
net.set_params(optimizer2__lr=0.05)
net.partial_fit(X, y)

# grid search et al. works
search = GridSearchCV(net, {'mymodule__num_units': [10, 50, 100]}, ...)
search.fit(X, y)

In this example, a new criterion, a new module, and a new optimizer were added. Of course, additional changes should be made to the net so that those new components are actually used for something, but this example should illustrate how to start. Since the rules outlined above are followed, we can use grid search on our custom components.

Callbacks

Callbacks provide a flexible way to customize the behavior of your NeuralNet training without the need to write subclasses.

You will often find callbacks writing to or reading from the history attribute. Therefore, if you would like to log the net’s behavior or do something based on the past behavior, consider using net.history.

This page will not explain all existing callbacks. For that, please look at skorch.callbacks.

Callback base class

The base class for each callback is Callback. If you would like to write your own callbacks, you should inherit from this class. A guide and practical example on how to write your own callbacks is shown in this notebook. In general, remember this:

  • They should inherit from the base class.
  • They should implement at least one of the on_ methods provided by the parent class (see below).
  • As argument, the methods first get the NeuralNet instance, and, where appropriate, the local data (e.g. the data from the current batch). The method should also have **kwargs in the signature for potentially unused arguments.
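
Putting these points together, a minimal custom callback could look like this (a sketch; the name and the logged value are purely illustrative):

from skorch.callbacks import Callback

class LrLogger(Callback):
    def on_epoch_end(self, net, dataset_train=None, dataset_valid=None, **kwargs):
        # record the learning rate of the first param group to the history
        net.history.record('lr', net.optimizer_.param_groups[0]['lr'])

net = NeuralNet(
    module=MyModule,
    callbacks=[('lr_logger', LrLogger())],
)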

Callback methods to override

The following methods could potentially be overridden when implementing your own callbacks.

initialize()

If you have attributes that should be reset when the model is re-initialized, those attributes should be set in this method.

on_train_begin(net, X, y)

Called once at the start of the training process (e.g. when calling fit).

on_train_end(net, X, y)

Called once at the end of the training process.

on_epoch_begin(net, dataset_train, dataset_valid)

Called once at the start of the epoch, i.e. possibly several times per fit call. Gets training and validation data as additional input.

on_epoch_end(net, dataset_train, dataset_valid)

Called once at the end of the epoch, i.e. possibly several times per fit call. Gets training and validation data as additional input.

on_batch_begin(net, Xi, yi, training)

Called once before each batch of data is processed, i.e. possibly several times per epoch. Gets batch data as additional input. Also includes a bool indicating if this is a training batch or not.

on_batch_end(net, Xi, yi, training, loss, y_pred)

Called once after each batch of data is processed, i.e. possibly several times per epoch. Gets batch data as additional input.

on_grad_computed(net, named_parameters, Xi, yi)

Called once per batch after gradients have been computed but before an update step was performed. Gets the module parameters as additional input as well as the batch data. Useful if you want to tinker with gradients.

Setting callback parameters

You can set specific callback parameters using the usual set_params interface on the network by using the callbacks__ prefix and the callback’s name. For example, to change the scoring order of the train loss, you can write this:

net = NeuralNet()
net.set_params(callbacks__train_loss__lower_is_better=False)

Changes will be applied on initialization and callbacks that are changed using set_params will be re-initialized.

The name you use to address the callback can be chosen during initialization of the network and defaults to the class name. If there is a conflict, the conflicting names will be made unique by appending a count suffix starting at 1, e.g. EpochScoring_1, EpochScoring_2, etc.

Deactivating callbacks

If you would like to (temporarily) deactivate a callback, you can do so by setting its parameter to None. E.g., if you have a callback called ‘my_callback’, you can deactivate it like this:

net = NeuralNet(
    module=MyModule,
    callbacks=[('my_callback', MyCallback())],
)
# now deactivate 'my_callback':
net.set_params(callbacks__my_callback=None)

This also works with default callbacks.

Deactivating callbacks can be especially useful when you do a parameter search (say with sklearn GridSearchCV). If, for instance, you use a callback for learning rate scheduling (e.g. via LRScheduler) and want to test its usefulness, you can compare the performance once with and once without the callback.

To completely disable all callbacks, including default callbacks, set callbacks="disable".

Scoring

skorch provides two callbacks that calculate scores by default, EpochScoring and BatchScoring. They work basically in the same way, except that EpochScoring calculates scores after each epoch and BatchScoring after each batch. Use the former if averaging of batch-wise scores is imprecise (say for AUC score) and the latter if you are very tight for memory.

In general, these scoring callbacks are useful when the default scores determined by the NeuralNet are not enough. They allow you to easily add new metrics to be logged during training. For an example of how to add a new score to your model, look at this notebook.

The first argument to both callbacks is name and should be a string. This determines the column name of the score shown by the PrintLog after each epoch.

Next comes the scoring parameter. For eager sklearn users, this should be familiar, since it works exactly the same as in sklearn GridSearchCV, RandomizedSearchCV, cross_val_score(), etc. For those who are unfamiliar, here is a short explanation:

  • If you pass a string, sklearn makes a look-up for a score with that name. Examples would be 'f1' and 'roc_auc'.
  • If you pass None, the model’s score method is used. By default, NeuralNet and its subclasses don’t provide a score method, but you can easily implement your own. If you do, it should take X and y (the target) as input and return a scalar as output.
  • Finally, you can pass a function/callable. In that case, this function should have the signature func(net, X, y) and return a scalar.
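
For example, a custom callable with the func(net, X, y) signature could be attached via EpochScoring like this (a sketch; balanced accuracy is just an illustration):

from sklearn.metrics import balanced_accuracy_score

from skorch.callbacks import EpochScoring

def balanced_accuracy(net, X, y):
    y_pred = net.predict(X)
    return balanced_accuracy_score(y, y_pred)

net = NeuralNetClassifier(
    MyModule,
    callbacks=[
        ('balanced_acc', EpochScoring(
            scoring=balanced_accuracy,
            name='valid_bal_acc',
            lower_is_better=False,
        )),
    ],
)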

More on sklearn’s model evaluation can be found in this notebook.

The lower_is_better parameter determines whether lower scores should be considered as better (e.g. log loss) or worse (e.g. accuracy). This information is used to write a <name>_best value to the net’s history. E.g., if your score is the f1 score and is called 'f1', you should set lower_is_better=False. The history will then contain an entry for 'f1', which is the score itself, and an entry for 'f1_best', which says whether this is the best f1 score so far.

on_train is used to indicate whether training or validation data should be used to determine the score. By default, it is set to False, meaning the validation data is used.

Finally, you may have to provide your own target_extractor. This should be a function or callable that is applied to the target before it is passed to the scoring function. The main reason why we need this is that sometimes, the target is not of a form expected by sklearn and we need to process it before passing it on.

On top of the two described scoring callbacks, skorch also provides PassthroughScoring. This callback does not actually calculate any new scores. Instead it uses an existing score that is calculated for each batch (the train loss, for example) and determines the average of this score, which is then written to the epoch level of the net’s history. This is very useful if the score was already calculated and logged on the batch level and you’re only interested in seeing the averaged score on the epoch level.

For this callback, you only need to provide the name of the score in the history. Moreover, you may again specify if lower_is_better and if the score should be calculated on_train or not.
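
As a sketch, averaging an existing batch-level score (here the train loss) onto the epoch level could look like this:

from skorch.callbacks import PassthroughScoring

net = NeuralNet(
    module=MyModule,
    callbacks=[
        ('train_loss_epoch', PassthroughScoring(
            name='train_loss',
            on_train=True,
        )),
    ],
)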

Note

Both BatchScoring and PassthroughScoring honor the batch size when calculating the average. This can make a difference when not all batch sizes are equal, which is typically the case because the last batch of an epoch contains fewer samples than the rest.

Checkpoint

The Checkpoint callback creates a checkpoint of your model after each epoch that meets certain criteria. By default, the condition is that the validation loss has improved; however, you may change this by specifying the monitor parameter. It can take three types of arguments:

  • None: The model is saved after each epoch;
  • string: The model checks whether the last entry in the model history for that key is truthy. This is useful in conjunction with scores determined by a scoring callback. They write a <score>_best entry to the history, which can be used for checkpointing. By default, the Checkpoint callback looks at 'valid_loss_best';
  • function or callable: In that case, the function should take the NeuralNet instance as sole input and return a bool as output.

To specify where and how your model is saved, change the arguments starting with f_:

  • f_params: to save model parameters
  • f_optimizer: to save optimizer state
  • f_history: to save training history
  • f_pickle: to pickle the entire model object.
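
Putting this together, a checkpoint that saves parameters, optimizer, and history whenever the validation accuracy improves could be sketched as follows (it assumes a scoring callback writes a 'valid_acc_best' entry to the history, as the default valid_acc scoring of NeuralNetClassifier does):

from skorch.callbacks import Checkpoint

cp = Checkpoint(
    monitor='valid_acc_best',  # assumes this key exists in the history
    dirname='my_exp',
    f_params='params.pt',
    f_optimizer='optimizer.pt',
    f_history='history.json',
)

net = NeuralNetClassifier(MyModule, callbacks=[cp])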

Please refer to Saving and Loading for more information about restoring your network from a checkpoint.

Learning rate schedulers

The LRScheduler callback allows the use of the various learning rate schedulers defined in torch.optim.lr_scheduler, such as ReduceLROnPlateau, which allows dynamic learning rate reducing based on a given value to monitor, or CyclicLR, which cycles the learning rate between two boundaries with a constant frequency.

Here’s a network that uses a callback to set a cyclic learning rate:

from skorch.callbacks import LRScheduler
from torch.optim.lr_scheduler import CyclicLR

net = NeuralNet(
    module=MyModule,
    callbacks=[
        ('lr_scheduler',
         LRScheduler(policy=CyclicLR,
                     base_lr=0.001,
                     max_lr=0.01)),
    ],
)

As with other callbacks, you can use set_params to set parameters, and thus search learning rate scheduler parameters using GridSearchCV or similar. An example:

from sklearn.model_selection import GridSearchCV

search = GridSearchCV(
    net,
    param_grid={'callbacks__lr_scheduler__max_lr': [0.01, 0.1, 1.0]},
)

Dataset

This module contains classes and functions related to data handling.

CVSplit

This class is responsible for performing the NeuralNet’s internal cross validation. For this, it sticks closely to the sklearn standards. For more information on how sklearn handles cross validation, consult the sklearn documentation.

The first argument that CVSplit takes is cv. It works analogously to the cv argument from sklearn GridSearchCV, cross_val_score(), etc. For those not familiar, here is a short explanation of what you may pass:

  • None: Use the default 3-fold cross validation.
  • integer: Specifies the number of folds in a (Stratified)KFold.
  • float: Represents the proportion of the dataset to include in the validation split (e.g. 0.2 for 20%).
  • An object to be used as a cross-validation generator.
  • An iterable yielding train, validation splits.

Furthermore, CVSplit takes a stratified argument that determines whether a stratified split should be made (only makes sense for discrete targets), and a random_state argument, which is used in case the cross validation split has a random component.

One difference to sklearn’s cross validation is that skorch makes only a single split. In sklearn, you would expect that in a 5-fold cross validation, the model is trained 5 times on the different combinations of folds. This is often not desirable for neural networks, since training takes a lot of time. Therefore, skorch only ever makes one split.

If you would like to have all splits, you can still use skorch in conjunction with the sklearn functions, as you would do with any other sklearn-compatible estimator. Just remember to set train_split=None, so that the whole dataset is used for training. Below is an example of making out-of-fold predictions with skorch and sklearn:

net = NeuralNetClassifier(
    module=MyModule,
    train_split=None,
)

from sklearn.model_selection import cross_val_predict

y_pred = cross_val_predict(net, X, y, cv=5)

Dataset

In PyTorch, we have the concept of a Dataset and a DataLoader. The former is purely the container of the data and only needs to implement __len__() and __getitem__(<int>). The latter does the heavy lifting, such as sampling, shuffling, and distributed processing.

skorch uses the PyTorch DataLoaders by default. skorch supports PyTorch’s Dataset when calling fit() or partial_fit(). Details on how to use PyTorch’s Dataset with skorch can be found in How do I use a PyTorch Dataset with skorch?. In order to support other data formats, we provide our own Dataset class that is compatible with:

  • numpy arrays
  • PyTorch Tensors
  • pandas DataFrames
  • scipy sparse CSR matrices

Note that currently, sparse matrices are cast to dense arrays during batching, given that PyTorch support for sparse matrices is still very incomplete. If you would like to prevent that, you need to override the transform method of Dataset.

In addition to the types above, you can pass dictionaries or lists of one of those data types, e.g. a dictionary of numpy.ndarrays. When you pass dictionaries, the keys of the dictionaries are used as the argument name for the forward() method of the net’s module. Similarly, the column names of pandas DataFrames are used as argument names. The example below should illustrate how to use this feature:

import numpy as np
import torch
import torch.nn.functional as F

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

        self.dense_a = torch.nn.Linear(10, 100)
        self.dense_b = torch.nn.Linear(20, 100)
        self.output = torch.nn.Linear(200, 2)

    def forward(self, key_a, key_b):
        hid_a = F.relu(self.dense_a(key_a))
        hid_b = F.relu(self.dense_b(key_b))
        concat = torch.cat((hid_a, hid_b), dim=1)
        out = F.softmax(self.output(concat), dim=-1)
        return out

net = NeuralNetClassifier(MyModule)

X = {
    'key_a': np.random.random((1000, 10)).astype(np.float32),
    'key_b': np.random.random((1000, 20)).astype(np.float32),
}
y = np.random.randint(0, 2, size=1000)

net.fit(X, y)

Note that the keys in the dictionary X exactly match the argument names in the forward() method. This way, you can easily work with several different types of input features.

The Dataset from skorch makes the assumption that you always have an X and a y, where X represents the input data and y the target. However, you may leave y=None, in which case Dataset returns a dummy variable.

Dataset applies a final transform to the data before passing it on to the PyTorch DataLoader. By default, it replaces y by a dummy variable in case it is None. If you would like to apply your own transformation on the data, you should subclass Dataset and override the transform() method, then pass your custom class to NeuralNet as the dataset argument.
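
A minimal sketch of such a subclass, which simply casts each sample of X to float32 in transform():

import numpy as np

from skorch.dataset import Dataset

class Float32Dataset(Dataset):
    def transform(self, X, y):
        # apply the default transform first (e.g. replacing y=None by a dummy)
        X, y = super().transform(X, y)
        return np.asarray(X, dtype=np.float32), y

net = NeuralNet(
    module=MyModule,
    dataset=Float32Dataset,
)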

Saving and Loading

General approach

skorch provides several ways to persist your model. First, it is possible to store the model using Python’s pickle module. This saves the whole model, including hyperparameters. This is useful when you don’t want to initialize your model before loading its parameters, or when your NeuralNet is part of an sklearn Pipeline:

net = NeuralNet(
    module=MyModule,
    criterion=torch.nn.NLLLoss,
)

model = Pipeline([
    ('my-features', get_features()),
    ('net', net),
])
model.fit(X, y)

# saving
with open('some-file.pkl', 'wb') as f:
    pickle.dump(model, f)

# loading
with open('some-file.pkl', 'rb') as f:
    model = pickle.load(f)

The disadvantage of pickling is that if your underlying code changes, unpickling might raise errors. Also, some Python code (e.g. lambda functions) cannot be pickled.

For this reason, we provide a second method for persisting your model. To use it, call the save_params() and load_params() method on NeuralNet. Under the hood, this saves the module’s state_dict, i.e. only the weights and biases of the module. This is more robust to changes in the code but requires you to initialize a NeuralNet to load the parameters again:

net = NeuralNet(
    module=MyModule,
    criterion=torch.nn.NLLLoss,
)

model = Pipeline([
    ('my-features', get_features()),
    ('net', net),
])
model.fit(X, y)

net.save_params(f_params='some-file.pkl')

new_net = NeuralNet(
    module=MyModule,
    criterion=torch.nn.NLLLoss,
)
new_net.initialize()  # This is important!
new_net.load_params(f_params='some-file.pkl')

In addition to saving the model parameters, the history and optimizer state can be saved by including the f_history and f_optimizer keywords to save_params() and load_params() on NeuralNet. This feature can be used to continue training:

net = NeuralNet(
    module=MyModule,
    criterion=torch.nn.NLLLoss,
)

net.fit(X, y, epochs=2) # Train for 2 epochs

net.save_params(
    f_params='model.pkl', f_optimizer='opt.pkl', f_history='history.json')

new_net = NeuralNet(
    module=MyModule,
    criterion=torch.nn.NLLLoss,
)
new_net.initialize() # This is important!
new_net.load_params(
    f_params='model.pkl', f_optimizer='opt.pkl', f_history='history.json')

new_net.fit(X, y, epochs=2) # Train for another 2 epochs

Note

In order to use this feature, the history must only contain JSON encodable Python data structures. Numpy and PyTorch types should not be in the history.

Note

save_params() does not store learned attributes on the net. E.g., skorch.classifier.NeuralNetClassifier remembers the classes it encountered during training in the classes_ attribute. This attribute will be missing after load_params(). Therefore, if you need it, you should pickle.dump() the whole net.

Using callbacks

skorch provides Checkpoint, TrainEndCheckpoint, and LoadInitState callbacks to handle saving and loading models during training. To demonstrate these features, we generate a dataset and create a simple module:

import numpy as np
from sklearn.datasets import make_classification
from torch import nn

X, y = make_classification(1000, 10, n_informative=5, random_state=0)
X = X.astype(np.float32)
y = y.astype(np.int64)

class MyModule(nn.Sequential):
    def __init__(self, num_units=10):
        super().__init__(
            nn.Linear(10, num_units),
            nn.ReLU(inplace=True),
            nn.Dropout(0.2),
            nn.Linear(num_units, 10),
            nn.Linear(10, 2),
            nn.Softmax(dim=-1)
        )

Then we create two different checkpoint callbacks and configure them to save the model parameters, optimizer, and history into a directory named 'exp1':

# First run

from skorch.callbacks import Checkpoint, TrainEndCheckpoint
from skorch import NeuralNetClassifier

cp = Checkpoint(dirname='exp1')
train_end_cp = TrainEndCheckpoint(dirname='exp1')
net = NeuralNetClassifier(
    MyModule, lr=0.5, callbacks=[cp, train_end_cp]
)

_ = net.fit(X, y)

# prints
  epoch    train_loss    valid_acc    valid_loss    cp     dur
-------  ------------  -----------  ------------  ----  ------
      1        0.6200       0.8209        0.4765     +  0.0232
      2        0.3644       0.8557        0.3474     +  0.0238
      3        0.2875       0.8806        0.3201     +  0.0214
      4        0.2514       0.8905        0.3080     +  0.0237
      5        0.2333       0.9154        0.2844     +  0.0203
      6        0.2177       0.9403        0.2164     +  0.0215
      7        0.2194       0.9403        0.2159     +  0.0220
      8        0.2027       0.9403        0.2299        0.0202
      9        0.1864       0.9254        0.2313        0.0196
     10        0.2024       0.9353        0.2333        0.0221

By default, Checkpoint observes the valid_loss metric and saves the model when the metric improves. This is indicated by the + mark in the cp column of the logs.

On our first run, the validation loss did not improve after the 7th epoch. We can lower the learning rate and continue training from this checkpoint by using LoadInitState:

from skorch.callbacks import LoadInitState

cp = Checkpoint(dirname='exp1')
load_state = LoadInitState(cp)
net = NeuralNetClassifier(
    MyModule, lr=0.1, callbacks=[cp, load_state]
)

_ = net.fit(X, y)

# prints

  epoch    train_loss    valid_acc    valid_loss    cp     dur
-------  ------------  -----------  ------------  ----  ------
      8        0.1939       0.9055        0.2626     +  0.0238
      9        0.2055       0.9353        0.2031     +  0.0239
     10        0.1992       0.9453        0.2101        0.0182
     11        0.2033       0.9453        0.1947     +  0.0211
     12        0.1825       0.9104        0.2515        0.0185
     13        0.2010       0.9453        0.1927     +  0.0187
     14        0.1508       0.9453        0.1952        0.0198
     15        0.1679       0.9502        0.1905     +  0.0181
     16        0.1516       0.9453        0.1864     +  0.0192
     17        0.1576       0.9453        0.1804     +  0.0184

The LoadInitState callback is executed once in the beginning of the training procedure and initializes model, history, and optimizer parameters from a specified checkpoint (if it exists). In our case, the previous checkpoint was created at the end of epoch 7, so the second run resumes from epoch 8. With a lower learning rate, the validation loss was able to improve!

Notice that in the first run we included a TrainEndCheckpoint in the list of callbacks. As its name suggests, this callback creates a checkpoint at the end of training. As before, we can pass it to LoadInitState to continue training:

cp_from_final = Checkpoint(dirname='exp1', fn_prefix='from_train_end_')
load_state = LoadInitState(train_end_cp)
net = NeuralNetClassifier(
    MyModule, lr=0.1, callbacks=[cp_from_final, load_state]
)

_ = net.fit(X, y)

# prints

  epoch    train_loss    valid_acc    valid_loss    cp     dur
-------  ------------  -----------  ------------  ----  ------
     11        0.1663       0.9453        0.2166     +  0.0282
     12        0.1880       0.9403        0.2237        0.0178
     13        0.1813       0.9353        0.1993     +  0.0161
     14        0.1744       0.9353        0.1955     +  0.0150
     15        0.1538       0.9303        0.2053        0.0077
     16        0.1473       0.9403        0.1947     +  0.0078
     17        0.1563       0.9254        0.1989        0.0074
     18        0.1558       0.9403        0.1877     +  0.0075
     19        0.1534       0.9254        0.2318        0.0074
     20        0.1779       0.9453        0.1814     +  0.0074

In this run, training started at epoch 11, continuing from the end of the first run, which ended at epoch 10. We created a new Checkpoint callback with fn_prefix set to 'from_train_end_' so that the saved filenames are prefixed accordingly and this checkpoint does not overwrite the checkpoint from the previous run.

Since our MyModule class allows num_units to be adjusted, we can start a new experiment by changing the dirname:

cp = Checkpoint(dirname='exp2')
load_state = LoadInitState(cp)
net = NeuralNetClassifier(
    MyModule, lr=0.5,
    callbacks=[cp, load_state],
    module__num_units=20,
)

_ = net.fit(X, y)

# prints

  epoch    train_loss    valid_acc    valid_loss    cp     dur
-------  ------------  -----------  ------------  ----  ------
      1        0.5256       0.8856        0.3624     +  0.0181
      2        0.2956       0.8756        0.3416     +  0.0222
      3        0.2280       0.9453        0.2299     +  0.0211
      4        0.1948       0.9303        0.2136     +  0.0232
      5        0.1800       0.9055        0.2696        0.0223
      6        0.1605       0.9403        0.1906     +  0.0190
      7        0.1594       0.9403        0.2027        0.0184
      8        0.1319       0.9303        0.1910        0.0220
      9        0.1558       0.9254        0.1923        0.0189
     10        0.1432       0.9303        0.2219        0.0192

This stores the model into the 'exp2' directory. Since this is the first run, the LoadInitState callback does not do anything. If we were to run the above script again, the LoadInitState callback would load the model from the checkpoint.

In the run above, the last checkpoint was created at epoch 6; we can load this checkpoint and predict with it:

net = NeuralNetClassifier(
    MyModule, lr=0.5, module__num_units=20,
)
net.initialize()
net.load_params(checkpoint=cp)

y_pred = net.predict(X)

In this case, it is important to initialize the neural net before running NeuralNet.load_params().

History

A NeuralNet object logs training progress internally using a History object, stored in the history attribute. Among other use cases, history is used to print the training progress after each epoch:

net.fit(X, y)

# prints
  epoch    train_loss    valid_acc    valid_loss     dur
-------  ------------  -----------  ------------  ------
      1        0.7111       0.5100        0.6894  0.1345
      2        0.6928       0.5500        0.6803  0.0608
      3        0.6833       0.5650        0.6741  0.0620
      4        0.6763       0.5850        0.6674  0.0594

All this information (and more) is stored in and can be accessed through net.history. It is thus best practice to make use of history for storing training-related data.

In general, History works like a list of dictionaries, where each item in the list corresponds to one epoch, and each key of the dictionary to one column. Thus, if you would like to access the 'train_loss' of the last epoch, you can call net.history[-1]['train_loss']. To make the history more accessible, though, it is possible to just pass the indices separated by a comma: net.history[-1, 'train_loss'].

Moreover, History stores the results from each individual batch under the batches key during each epoch. So to get the train loss of the 3rd batch of the 7th epoch, use net.history[6, 'batches', 2, 'train_loss'] (both indices are zero-based, as in a regular list).

Here are some examples showing how to index history:

# history of a fitted neural net
history = net.history
# get current epoch, a dict
history[-1]
# get train losses from all epochs, a list of floats
history[:, 'train_loss']
# get train and valid losses from all epochs, a list of tuples
history[:, ('train_loss', 'valid_loss')]
# get current batches, a list of dicts
history[-1, 'batches']
# get latest batch, a dict
history[-1, 'batches', -1]
# get train losses from current batch, a list of floats
history[-1, 'batches', :, 'train_loss']
# get train and valid losses from current batch, a list of tuples
history[-1, 'batches', :, ('train_loss', 'valid_loss')]

As History is essentially a list of dictionaries, you can also write to it as if it were a list of dictionaries. Here too, skorch provides some convenience functions to make life easier. First, there is new_epoch(), which will add a new epoch dictionary to the end of the list. Also, there is new_batch() for adding new batches to the current epoch.

To add a new item to the current epoch, use history.record('foo', 123). This will set the value 123 for the key foo of the current epoch. To write a value to the current batch, use history.record_batch('bar', 456). Below are some more examples:

# history of a fitted neural net
history = net.history
# add new epoch row
history.new_epoch()
# add an entry to current epoch
history.record('my-score', 123)
# add a batch row to the current epoch
history.new_batch()
# add an entry to the current batch
history.record_batch('my-batch-score', 456)
# overwrite entry of current batch
history.record_batch('my-batch-score', 789)

Toy

This module contains helper functions and classes that allow you to prototype quickly or that can be used for writing tests.

MLPModule

MLPModule is a simple PyTorch Module that implements a multi-layer perceptron. It allows you to set the number of input, hidden, and output units, as well as the non-linearity and the use of dropout. You can use this module directly in conjunction with NeuralNet.

Additionally, the functions make_classifier(), make_binary_classifier(), and make_regressor() can be used to return an MLPModule with defaults adjusted for use in multi-class classification, binary classification, and regression, respectively.

Helper

This module provides helper functions and classes for the user. They make working with skorch easier but are not used by skorch itself.

SliceDict

A SliceDict is a wrapper for Python dictionaries that makes them behave a little bit like numpy.ndarrays. That way, you can slice your dictionary across values, len() will show the length of the arrays and not the number of keys, and you get a shape attribute. This is useful because if your data is in a dict, you would normally not be able to use sklearn GridSearchCV and similar things; with SliceDict, this works.
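
A small sketch, reusing the dict of arrays and the net from the NeuralNet section (X, y, and net are assumed to be defined as there):

from sklearn.model_selection import GridSearchCV

from skorch.helper import SliceDict

X_sl = SliceDict(
    key_a=X['key_a'],
    key_b=X['key_b'],
)

search = GridSearchCV(net, {'lr': [0.01, 0.1]})
search.fit(X_sl, y)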

SliceDataset

A SliceDataset is a wrapper for PyTorch Datasets that makes them behave a little bit like numpy.ndarrays. That way, you can slice your dataset with lists and arrays, and you get a shape attribute. These properties are useful because if your data is in a dataset, you would normally not be able to use sklearn GridSearchCV and similar things; with SliceDataset, this works.

Note that SliceDataset can only ever return one of the values returned by the dataset. Typically, this will be either the X or the y value. Therefore, if you want to wrap both X and y, you should create two instances of SliceDataset, one for X (with argument idx=0, the default) and one for y (with argument idx=1):

ds = MyCustomDataset()
X_sl = SliceDataset(ds, idx=0)  # idx=0 is the default
y_sl = SliceDataset(ds, idx=1)
gs.fit(X_sl, y_sl)

Command line interface helpers

Often you want to wrap up your experiments by writing a small script that allows others to reproduce your work. With the help of skorch and the fire library, it becomes very easy to write command line interfaces without boilerplate. All arguments pertaining to skorch or its PyTorch module are immediately available as command line arguments, without the need to write a custom parser. If docstrings in the numpydoc specification are available, there is also comprehensive help for the user. Overall, this allows you to make your work reproducible without the usual hassle.

There is an example in the skorch repository that shows how to use the CLI tools. Below is a snippet that shows the output created by the help function without writing a single line of argument parsing:

$ python examples/cli/train.py pipeline --help

<SelectKBest> options:
   --select__score_func : callable
     Function taking two arrays X and y, and returning a pair of arrays
     (scores, pvalues) or a single array with scores.
     Default is f_classif (see below "See also"). The default function only
     works with classification tasks.
   --select__k : int or "all", optional, default=10
     Number of top features to select.
     The "all" option bypasses selection, for use in a parameter search.

...

<NeuralNetClassifier> options:
   --net__module : torch module (class or instance)
     A PyTorch :class:`~torch.nn.Module`. In general, the
     uninstantiated class should be passed, although instantiated
     modules will also work.
   --net__criterion : torch criterion (class, default=torch.nn.NLLLoss)
     Negative log likelihood loss. Note that the module should return
     probabilities, the log is applied during ``get_loss``.
   --net__optimizer : torch optim (class, default=torch.optim.SGD)
     The uninitialized optimizer (update rule) used to optimize the
     module
   --net__lr : float (default=0.01)
     Learning rate passed to the optimizer. You may use ``lr`` instead
     of using ``optimizer__lr``, which would result in the same outcome.
   --net__max_epochs : int (default=10)
     The number of epochs to train for each ``fit`` call. Note that you
     may keyboard-interrupt training at any time.
   --net__batch_size : int (default=128)
     ...
   --net__verbose : int (default=1)
     Control the verbosity level.
   --net__device : str, torch.device (default='cpu')
     The compute device to be used. If set to 'cuda', data in torch
     tensors will be pushed to cuda tensors before being sent to the
     module.

<MLPClassifier> options:
   --net__module__hidden_units : int (default=10)
     Number of units in hidden layers.
   --net__module__num_hidden : int (default=1)
     Number of hidden layers.
   --net__module__nonlin : torch.nn.Module instance (default=torch.nn.ReLU())
     Non-linearity to apply after hidden layers.
   --net__module__dropout : float (default=0)
     Dropout rate. Dropout is applied between layers.
Installation

To use this functionality, you need some further libraries that are not part of skorch, namely fire and numpydoc. You can install them like this:

pip install fire numpydoc
Usage

When you write your own script, only the following bits need to be added:

import fire
from skorch.helper import parse_args

# your model definition and data fetching code below
...

def main(**kwargs):
    X, y = get_data()
    my_model = get_model()

    # important: wrap the model with the parsed arguments
    parsed = parse_args(kwargs)
    my_model = parsed(my_model)

    my_model.fit(X, y)


if __name__ == '__main__':
    fire.Fire(main)

This even works if your neural net is part of an sklearn pipeline, in which case the help extends to all other estimators of your pipeline.

In case you would like to change some defaults for the net (e.g. using a batch_size of 256 instead of 128), this is also possible. You should have a dictionary containing your new defaults and pass it as an additional argument to parse_args:

my_defaults = {'batch_size': 256, 'module__hidden_units': 30}

def main(**kwargs):
    ...
    parsed = parse_args(kwargs, defaults=my_defaults)
    my_model = parsed(my_model)

This will update the displayed help to your new defaults, as well as set the parameters on the net or pipeline for you. However, the arguments passed via the command line take precedence. Thus, if you additionally pass --batch_size 512 to the script, the batch size will be 512.

Restrictions

Almost all arguments should work out of the box. Therefore, you get command line arguments for the number of epochs, learning rate, batch size, etc. for free. Moreover, you can access the module parameters with the double-underscore notation as usual with skorch (e.g. --module__num_units 100). This should cover almost all common cases.

Parsing command line arguments that are non-primitive Python objects is more difficult, though. skorch’s custom parsing should support normal Python types and simple custom objects, e.g. this works: --module__nonlin 'torch.nn.RReLU(0.1, upper=0.4)'. More complex parsing might not work. E.g., it is currently not possible to add new callbacks through the command line (but you can modify existing ones as usual).

REST Service

In this section we’ll take the RNN sentiment classifier from the example Predicting sentiment on the IMDB dataset and use it to demonstrate how to easily expose your PyTorch module on the web using skorch and another library called Palladium.

With Palladium, you define the dataset and the model, and Palladium provides the framework to fit, test, and serve your model on the web. Palladium comes with its own documentation and a tutorial, which you may want to check out to learn more about what you can do with it.

The way to make the dataset and model known to Palladium is through its configuration file. Here’s the part of the configuration that defines the dataset and model:

{
    'dataset_loader_train': {
        '__factory__': 'model.DatasetLoader',
        'path': 'aclImdb/train/',
    },

    'dataset_loader_test': {
        '__factory__': 'model.DatasetLoader',
        'path': 'aclImdb/test/',
    },

    'model': {
        '__factory__': 'model.create_pipeline',
        'use_cuda': True,
    },

    'model_persister': {
        '__factory__': 'palladium.persistence.File',
        'path': 'rnn-model-{version}',
    },

    'scoring': 'accuracy',
}

You can save this configuration as palladium-config.py.

The dataset_loader_train and dataset_loader_test entries define where the data comes from. They refer to a Python class defined inside the model module. Let’s create a file called model.py and put it in the same directory as the configuration file. We’ll start off with defining the dataset loader:

import os
from urllib.request import urlretrieve
import tarfile

import numpy as np
from palladium.interfaces import DatasetLoader as IDatasetLoader
from sklearn.datasets import load_files

DATA_URL = 'http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz'
DATA_FN = DATA_URL.rsplit('/', 1)[1]


def download():
    if not os.path.exists('aclImdb'):
        # download and extract the data if it does not exist yet
        if not os.path.exists(DATA_FN):
            urlretrieve(DATA_URL, DATA_FN)
        with tarfile.open(DATA_FN, 'r:gz') as f:
            f.extractall()


class DatasetLoader(IDatasetLoader):
    def __init__(self, path='aclImdb/train/'):
        self.path = path

    def __call__(self):
        download()
        dataset = load_files(self.path, categories=['pos', 'neg'])
        X, y = dataset['data'], dataset['target']
        X = np.asarray([x.decode() for x in X])  # decode from bytes
        return X, y

The most interesting bit here is that our Palladium DatasetLoader defines a __call__ method that will return the data and the target (X and y). Easy. Note that in the configuration file, we refer to our DatasetLoader twice, once for the training set and once for the test set.

Our configuration also refers to a function create_pipeline which we’ll create next:

from dstoolbox.transformers import Padder2d
from dstoolbox.transformers import TextFeaturizer
from sklearn.pipeline import Pipeline
from skorch import NeuralNetClassifier
import torch


def create_pipeline(
    vocab_size=1000,
    max_len=50,
    use_cuda=False,
    **kwargs
):
    return Pipeline([
        ('to_idx', TextFeaturizer(max_features=vocab_size)),
        ('pad', Padder2d(max_len=max_len, pad_value=vocab_size, dtype=int)),
        ('net', NeuralNetClassifier(
            RNNClassifier,
            device=('cuda' if use_cuda else 'cpu'),
            max_epochs=5,
            lr=0.01,
            optimizer=torch.optim.RMSprop,
            module__vocab_size=vocab_size,
            **kwargs,
        ))
    ])

You’ve noticed that this function’s job is to create the model and return it. Here, we’re defining a pipeline that wraps skorch’s NeuralNetClassifier, which in turn is a wrapper around our PyTorch module, as it’s defined in the predicting sentiment tutorial. We’ll also add the RNNClassifier to model.py:

from torch import nn
F = nn.functional


class RNNClassifier(nn.Module):
    def __init__(
        self,
        embedding_dim=128,
        rec_layer_type='lstm',
        num_units=128,
        num_layers=2,
        dropout=0,
        vocab_size=1000,
    ):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.rec_layer_type = rec_layer_type.lower()
        self.num_units = num_units
        self.num_layers = num_layers
        self.dropout = dropout

        self.emb = nn.Embedding(
            vocab_size + 1, embedding_dim=self.embedding_dim)

        rec_layer = {'lstm': nn.LSTM, 'gru': nn.GRU}[self.rec_layer_type]
        # We have to make sure that the recurrent layer is batch_first,
        # since sklearn assumes the batch dimension to be the first
        self.rec = rec_layer(
            self.embedding_dim, self.num_units,
            num_layers=num_layers, batch_first=True,
            )

        self.output = nn.Linear(self.num_units, 2)

    def forward(self, X):
        embeddings = self.emb(X)
        # from the recurrent layer, only take the activities from the
        # last sequence step
        if self.rec_layer_type == 'gru':
            _, rec_out = self.rec(embeddings)
        else:
            _, (rec_out, _) = self.rec(embeddings)
        rec_out = rec_out[-1]  # take output of last RNN layer
        drop = F.dropout(rec_out, p=self.dropout)
        # Remember that the final non-linearity should be softmax, so
        # that our predict_proba method outputs actual probabilities!
        out = F.softmax(self.output(drop), dim=-1)
        return out

You can find the full contents of the model.py file in the skorch/examples/rnn_classifer folder of skorch’s source code.

Now with dataset and model in place, it’s time to try Palladium out. You can install Palladium and another dependency we use with pip install palladium dstoolbox.

From within the directory that contains model.py and palladium-config.py now run the following command:

PALLADIUM_CONFIG=palladium-config.py pld-fit --evaluate

You should see output similar to this:

INFO:palladium:Loading data...
INFO:palladium:Loading data done in 0.607 sec.
INFO:palladium:Fitting model...
  epoch    train_loss    valid_acc    valid_loss     dur
-------  ------------  -----------  ------------  ------
      1        0.7679       0.5008        0.7617  3.1300
      2        0.6385       0.7100        0.5840  3.1247
      3        0.5430       0.7438        0.5518  3.1317
      4        0.4736       0.7480        0.5424  3.1373
      5        0.4253       0.7448        0.5832  3.1433
INFO:palladium:Fitting model done in 29.060 sec.
DEBUG:palladium:Evaluating model on train set...
INFO:palladium:Train score: 0.83068
DEBUG:palladium:Evaluating model on train set done in 6.743 sec.
DEBUG:palladium:Evaluating model on test set...
INFO:palladium:Test score:  0.75428
DEBUG:palladium:Evaluating model on test set done in 6.476 sec.
INFO:palladium:Writing model...
INFO:palladium:Writing model done in 0.694 sec.
INFO:palladium:Wrote model with version 1.

Congratulations, you’ve trained your first model with Palladium! Note that in the output you see a train score (accuracy) of 0.83 and a test score of about 0.75. These refer to how well your model did on the training set (defined by dataset_loader_train in the configuration) and on the test set (dataset_loader_test).

You’re now ready to serve the model on the web. Add this piece of configuration to the palladium-config.py configuration file (and make sure it lives within the outermost braces):

{
    # ...

    'predict_service': {
        '__factory__': 'palladium.server.PredictService',
        'mapping': [
            ('text', 'str'),
        ],
        'predict_proba': True,
        'unwrap_sample': True,
    },

    # ...
}

With this piece of information inside the configuration, we’re ready to launch the web server using:

PALLADIUM_CONFIG=palladium-config.py pld-devserver

You can now try out the web service at this address: http://localhost:5000/predict?text=this+movie+was+brilliant

You should see a JSON string returned that looks something like this:

{
    "metadata": {"error_code": 0, "status": "OK"},
    "result": [0.326442807912827, 0.673557221889496],
}

The result entry contains the probabilities. Our model assigns a 67% probability that the sentence “this movie was brilliant” is positive. By the way, the skorch tutorial itself has tips on how to improve this model.

The takeaway is that Palladium helps you reduce the boilerplate code needed to get your machine learning project started. Palladium has routines to fit, test, and serve models, so you don’t have to worry about those and can concentrate on the actual machine learning part. Configuration and code are separated with Palladium, which helps you organize your experiments and work on ideas in parallel. Check out the Palladium documentation for more.

Parallelism

skorch supports distributing work among a cluster of workers via dask.distributed. In this section we’ll describe how to use Dask to efficiently distribute a grid search or a randomized search over hyperparameters across multiple GPUs and potentially multiple hosts.

Let’s assume that you have two GPUs on which you want to run a hyperparameter search.

The key here is using the CUDA environment variable CUDA_VISIBLE_DEVICES to limit which devices are visible to our CUDA application. We’ll set up Dask workers that, using this environment variable, each see one GPU only. On the PyTorch side, we’ll have to make sure to set the device to cuda when we initialize the NeuralNet class.
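On the skorch side, this simply means passing device='cuda' when constructing the net (a sketch; MyModule stands in for your own module):

from skorch import NeuralNetClassifier

net = NeuralNetClassifier(
    MyModule,
    device='cuda',  # each Dask worker only sees the GPU allowed by CUDA_VISIBLE_DEVICES
)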

Let’s run through the steps. First, install Dask and dask.distributed:

pip install dask distributed

Next, assuming you have two GPUs on your machine, let’s start up a Dask scheduler and two Dask workers. Make sure the Dask workers are started up in the right environment, that is, with access to all packages required to do the work:

dask-scheduler
CUDA_VISIBLE_DEVICES=0 dask-worker 127.0.0.1:8786 --nthreads 1
CUDA_VISIBLE_DEVICES=1 dask-worker 127.0.0.1:8786 --nthreads 1

In your code, use joblib’s parallel_backend() context manager to activate the Dask backend when you run grid searches and the like. Also instantiate a dask.distributed.Client pointing to the Dask scheduler that you want to use. Let’s see what this could look like:

from dask.distributed import Client
from joblib import parallel_backend
from sklearn.model_selection import GridSearchCV

client = Client('127.0.0.1:8786')

X, y = load_my_data()
net = get_that_net()

gs = GridSearchCV(
    net,
    param_grid={'lr': [0.01, 0.03]},
    scoring='accuracy',
    )
with parallel_backend('dask'):
    gs.fit(X, y)
print(gs.cv_results_)

You can also use Palladium to do the job. An example is included in the source in the examples/rnn_classifier folder. Change into that directory and run the following command after having set up your Dask workers:

PALLADIUM_CONFIG=palladium-config.py,dask-config.py pld-grid-search

Performance

Since skorch provides extra functionality on top of pure PyTorch training code, it is expected to add some overhead to the total runtime. For typical workloads, this overhead should be unnoticeable.

In a few situations, skorch’s extra functionality may add significant overhead. This is especially the case when the amount of data and the neural net are relatively small. The reason is that typically, most time is spent on the forward, backward, and parameter update calls. When those are really fast, the skorch overhead becomes noticeable.

There are, however, a few things that can be done to reduce the skorch overhead. We will focus on accelerating the training process, where the overhead should be largest. Below, some mitigations are described, including the potential downsides.

First, make sure that there is a significant slowdown

Neural nets are notoriously slow to train. Therefore, if your training takes a lot of time, that doesn’t automatically mean that the skorch overhead is at fault; maybe the training would take just as long without skorch. Before trying to optimize with the mitigations described below, use measurements of training the same model without skorch to make sure that skorch really is the culprit. If it turns out skorch is not the culprit, look into optimizing the performance of PyTorch code in general.

Many people use skorch for hyperparameter search. Remember that this implies fitting the model repeatedly, so a long run time is expected. E.g. if you run a grid search on two hyperparameters, each with 10 variants, and 5 splits, there will actually be 10 x 10 x 5 fit calls, so expect the process to take approximately 500 times as long as a single model fit. Increase the verbosity on the grid search to get a better idea of the progress (e.g. GridSearchCV(..., verbose=3)).

Turning off verbosity

By default, skorch produces a print log of the training progress. This is useful for checking the training progress, monitoring overfitting, etc. If you don’t need these diagnostics, you can turn them off via the verbose parameter. This way, printing is deactivated, saving time on I/O. You can still access the diagnostics through the history attribute after training has finished.

net = NeuralNet(..., verbose=0)  # turn off verbosity
net.fit(X, y)
train_loss = net.history[..., 'train_loss']  # access history as usual

Disabling callbacks altogether

If you don’t need any callbacks at all, turning them off can be a potential time saver. Callbacks present the most significant “extra” that skorch provides over pure PyTorch, hence they might add a lot of overhead for small workloads. By turning them off, you lose their functionality, though. It’s up to you to determine if that’s a worthwhile trade-off or not. For instance, in contrast to just turning down verbosity, you will no longer have access to useful diagnostics in the history attribute.

# skorch version 0.10 or later:
net = NeuralNet(..., callbacks='disable')
net.fit(X, y)
print(net.history)  # no longer contains useful diagnostics

# skorch version 0.9 or earlier
net = NeuralNet(...)
net.initialize()
net.callbacks_ = []  # manually remove all callbacks
net.fit(X, y)

Instead of turning off all callbacks, you can also turn off specific callbacks, including default callbacks. This way, you can decide which ones to keep and which ones to get rid of. Typically, callbacks that calculate some kind of metric tend to be slow.

# deactivate callbacks that determine train and valid loss after each epoch
net = NeuralNet(..., callbacks__train_loss=None, callbacks__valid_loss=None)
net.fit(X, y)
print(net.history)  # no longer contains 'train_loss' and 'valid_loss' entries

Prepare the Dataset

skorch can deal with a number of different input data types. This is very convenient, as it removes the necessity for the user to deal with them, but it also adds a small overhead. Therefore, if you can prepare your data so that it’s already contained in an appropriate torch.utils.data.Dataset, this check can be skipped.

X, y = ...  # let's assume that X and y are numpy arrays
net = NeuralNet(...)

# normal way: let skorch figure out how to create the Dataset
net.fit(X, y)

# faster way: prepare Dataset yourself
import torch
from torch.utils.data import TensorDataset
Xt = torch.from_numpy(X)
yt = torch.from_numpy(y)
tensor_ds = TensorDataset(Xt, yt)
net.fit(tensor_ds, None)

Still too slow

You find that your skorch code is still slow despite trying all of these tips, and you have made sure that the slowdown is indeed caused by skorch. What can you do now? In this case, please search our issue tracker for solutions or open a new issue. Provide as much context as possible and, if available, a minimal code example. We will try to help you figure out what the problem is.

FAQ

How do I apply L2 regularization?

To apply L2 regularization (aka weight decay), PyTorch supplies the weight_decay parameter, which must be supplied to the optimizer. To pass this variable in skorch, use the double-underscore notation for the optimizer:

net = NeuralNet(
    ...,
    optimizer__weight_decay=0.01,
)

How can I continue training my model?

By default, when you call fit() more than once, the training starts from scratch instead of continuing from where it left off. This is in line with sklearn’s behavior but not always desired. If you would like to continue training, use partial_fit() instead of fit(). Alternatively, there is the warm_start argument, which is False by default. Set it to True and you should be fine.
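For instance (a sketch; the data and remaining net arguments are placeholders):

# option 1: continue training from the current state of the net
net.partial_fit(X, y)

# option 2: make fit() itself keep the learned parameters between calls
net = NeuralNet(..., warm_start=True)
net.fit(X, y)  # first call initializes and trains
net.fit(X, y)  # second call continues training instead of restarting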

How do I shuffle my train batches?

skorch uses DataLoader from PyTorch under the hood. This class takes a couple of arguments, for instance shuffle. We therefore need to pass the shuffle argument to DataLoader, which we achieve by using the double-underscore notation (as known from sklearn):

net = NeuralNet(
    ...,
    iterator_train__shuffle=True,
)

Note that we have an iterator_train for the training data and an iterator_valid for validation and test data. In general, you only want to shuffle the train data, which is what the code above does.

How do I use sklearn GridSearchCV when my data is in a dictionary?

skorch supports dicts as input but sklearn doesn’t. To get around that, try to wrap your dictionary into a SliceDict. This is a data container that partly behaves like a dict and partly like an ndarray. For more details on how to do this, have a look at the corresponding data section in the notebook.

How do I use sklearn GridSearchCV when my data is in a dataset?

skorch supports datasets as input but sklearn doesn’t. If it’s possible, you should provide your data in a non-dataset format, e.g. as a numpy array or torch tensor, extracted from your original dataset.

Sometimes, this is not possible, e.g. when your data doesn’t fit into memory. To get around that, try to wrap your dataset into a SliceDataset. This is a data container that partly behaves like a dataset, partly like an ndarray. Further information can be found here: SliceDataset.

I want to use sample_weight, how can I do this?

Some scikit-learn models support passing a sample_weight argument to fit calls as part of the fit_params. This allows you to give different samples different weights in the final loss calculation.

In general, skorch supports fit_params, but unfortunately just calling net.fit(X, y, sample_weight=sample_weight) is not enough, because the fit_params are not split into train and valid, and are not batched, resulting in a mismatch with the training batches.

Fortunately, skorch supports passing dictionaries as arguments, which are actually split into train and valid and then batched. Therefore, the best solution is to pass the sample_weight with X as a dictionary. Below, there is example code on how to achieve this:

import skorch
from skorch import NeuralNet
from torch import nn

X, y = get_data()
# put your X into a dict if not already a dict
X = {'data': X}
# add sample_weight to the X dict
X['sample_weight'] = sample_weight

class MyModule(nn.Module):
    ...
    def forward(self, data, sample_weight):
        # when X is a dict, its keys are passed as kwargs to forward, thus
        # our forward has to have the arguments 'data' and 'sample_weight';
        # usually, sample_weight can be ignored here
        ...

class MyNet(NeuralNet):
    def __init__(self, *args, criterion__reduce=False, **kwargs):
        # make sure to set reduce=False in your criterion, since we need the loss
        # for each sample so that it can be weighted
        super().__init__(*args, criterion__reduce=criterion__reduce, **kwargs)

    def get_loss(self, y_pred, y_true, X, *args, **kwargs):
        # override get_loss to use the sample_weight from X
        loss_unreduced = super().get_loss(y_pred, y_true, X, *args, **kwargs)
        sample_weight = skorch.utils.to_tensor(X['sample_weight'], device=self.device)
        loss_reduced = (sample_weight * loss_unreduced).mean()
        return loss_reduced

net = MyNet(MyModule, ...)
net.fit(X, y)

I already split my data into training and validation sets, how can I use them?

If you have predefined training and validation datasets that are subclasses of PyTorch Dataset, you can use predefined_split() to wrap your validation dataset and pass it to NeuralNet’s train_split parameter:

from skorch.helper import predefined_split

net = NeuralNet(
    ...,
    train_split=predefined_split(valid_ds)
)
net.fit(train_ds)

If you split your data by using train_test_split(), you can create your own skorch Dataset, and then pass it to predefined_split():

from sklearn.model_selection import train_test_split
from skorch.helper import predefined_split
from skorch.dataset import Dataset

X_train, X_test, y_train, y_test = train_test_split(X, y)

valid_ds = Dataset(X_test, y_test)

net = NeuralNet(
    ...,
    train_split=predefined_split(valid_ds)
)

net.fit(X_train, y_train)

What happens when NeuralNet is passed an initialized PyTorch module?

When NeuralNet is passed an initialized PyTorch module, skorch will usually leave the module alone. In the following example, the resulting module will be trained for 20 epochs:

class MyModule(nn.Module):
    def __init__(self, hidden=10):
        ...

module = MyModule()
net1 = NeuralNet(module, max_epochs=10, ...)
net1.fit(X, y)

net2 = NeuralNet(module, max_epochs=10, ...)
net2.fit(X, y)

When the module is passed to the second NeuralNet, it will not be re-initialized and will keep its parameters from the first 10 epochs.

When module parameters are set through keyword arguments, NeuralNet will re-initialize the module:

net = NeuralNet(module, module__hidden=10, ...)
net.fit(X, y)

Although it is possible to pass an initialized PyTorch module to NeuralNet, it is recommended to pass the module class instead:

net = NeuralNet(MyModule, ...)
net.fit(X, y)

In this case, fit() will always re-initialize the module, whereas partial_fit() won’t once the network has been initialized.

How do I use a PyTorch Dataset with skorch?

skorch supports PyTorch’s Dataset as arguments to fit() or partial_fit(). We create a dataset by subclassing PyTorch’s Dataset:

import torch.utils.data

class RandomDataset(torch.utils.data.Dataset):
    def __init__(self):
        self.X = torch.randn(128, 10)
        self.Y = torch.randn(128, 10)

    def __getitem__(self, idx):
        return self.X[idx], self.Y[idx]

    def __len__(self):
        return 128

skorch expects the output of __getitem__ to be a tuple of two values. The RandomDataset can be passed directly to fit():

import torch
import torch.nn as nn

from skorch import NeuralNet

train_ds = RandomDataset()

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(10, 10)

    def forward(self, X):
        return self.layer(X)

net = NeuralNet(MyModule, criterion=torch.nn.MSELoss)
net.fit(train_ds)

How can I deal with multiple return values from forward?

skorch supports modules that return multiple values. To do this, simply return a tuple of all values that you want to return from the forward method. However, this tuple will also be passed to the criterion. If the criterion cannot deal with multiple values, this will result in an error.

To remedy this, you need to either implement your own criterion that can deal with the output or you need to override get_loss() and handle the unpacking of the tuple.

To inspect all output values, you can use either the forward() method (eager) or the forward_iter() method (lazy).

For an example of how this works, have a look at this notebook.
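Below is a rough sketch of the get_loss() approach (the module and class names are made up for illustration); only the first element of the tuple is handed on to the criterion:

from torch import nn
from skorch import NeuralNet


class TwoOutputModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.dense = nn.Linear(10, 2)

    def forward(self, X):
        hidden = self.dense(X)
        # return the prediction plus an extra value, e.g. the raw activations
        return nn.functional.softmax(hidden, dim=-1), hidden


class MyNet(NeuralNet):
    def get_loss(self, y_pred, y_true, *args, **kwargs):
        # y_pred is the tuple returned by forward; unpack it and only
        # pass the probabilities on to the criterion
        y_proba, _ = y_pred
        return super().get_loss(y_proba, y_true, *args, **kwargs)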

How can I perform gradient accumulation with skorch?

There is no direct option to turn on gradient accumulation (at least for now). However, with a few modifications, you can implement gradient accumulation yourself:

from skorch import NeuralNetClassifier

ACC_STEPS = 2  # number of steps to accumulate before updating weights

class GradAccNet(NeuralNetClassifier):
    """Net that accumulates gradients"""
    def __init__(self, *args, acc_steps=ACC_STEPS, **kwargs):
        super().__init__(*args, **kwargs)
        self.acc_steps = acc_steps

    def get_loss(self, *args, **kwargs):
        loss = super().get_loss(*args, **kwargs)
        return loss / self.acc_steps  # normalize loss

    def train_step(self, Xi, yi, **fit_params):
        """Perform gradient accumulation

        Only optimize every nth batch.

        """
        # note that n_train_batches starts at 1 for each epoch
        n_train_batches = len(self.history[-1, 'batches'])
        step = self.train_step_single(Xi, yi, **fit_params)

        if n_train_batches % self.acc_steps == 0:
            self.optimizer_.step()
            self.optimizer_.zero_grad()
        return step

This is not a complete recipe. For example, if you optimize every 2nd step and the number of training batches is uneven, you should make sure that there is an optimization step after the last batch of each epoch. However, this example can serve as a starting point for implementing your own version of gradient accumulation.

How can I dynamically set the input size of the PyTorch module based on the data?

Typically, it’s up to the user to determine the shape of the input data when defining the PyTorch module. This can sometimes be inconvenient, e.g. when the shape is only known at runtime. E.g., when using sklearn.feature_selection.VarianceThreshold, you cannot know the number of features in advance. The best solution would be to set the input size dynamically.

In most circumstances, this can be achieved with a few lines of code in skorch. Here is an example:

import skorch

class InputShapeSetter(skorch.callbacks.Callback):
    def on_train_begin(self, net, X, y):
        net.set_params(module__input_dim=X.shape[1])


net = skorch.NeuralNetClassifier(
    ClassifierModule,
    callbacks=[InputShapeSetter()],
)

This assumes that your module accepts an argument called input_dim, which determines the number of units of the input layer, and that the number of features can be determined by X.shape[1]. If those assumptions are not true for your case, adjust the code accordingly. A fully working example can be found on stackoverflow.
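For reference, a matching module could be sketched like this (the class name ClassifierModule and the layer sizes are illustrative):

from torch import nn


class ClassifierModule(nn.Module):
    def __init__(self, input_dim=20, hidden_dim=50, num_classes=2):
        super().__init__()
        # input_dim is overwritten by InputShapeSetter based on X.shape[1]
        self.hidden = nn.Linear(input_dim, hidden_dim)
        self.output = nn.Linear(hidden_dim, num_classes)

    def forward(self, X):
        X = nn.functional.relu(self.hidden(X))
        return nn.functional.softmax(self.output(X), dim=-1)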

How do I implement a score method on the net that returns the loss?

Sometimes, it is useful to be able to compute the loss of a net from within skorch (e.g. when a net is part of an sklearn pipeline). The function skorch.scoring.loss_scoring() achieves this. Two examples are provided below. The first demonstrates how to use skorch.scoring.loss_scoring() as a function on a trained net object.

import numpy as np
import skorch
from skorch.scoring import loss_scoring
from torch import nn

X = np.random.randn(250, 25).astype('float32')
y = (X.dot(np.ones(25)) > 0).astype(int)

module = nn.Sequential(
    nn.Linear(25, 25),
    nn.ReLU(),
    nn.Linear(25, 2),
    nn.Softmax(dim=1)
)
net = skorch.NeuralNetClassifier(module).fit(X, y)
print(loss_scoring(net, X, y))

The second example shows how to sub-class skorch.classifier.NeuralNetClassifier to implement a score method. In this example, the score method returns the negative of the loss value, because sklearn.model_selection.GridSearchCV searches for the run with the greatest score, and we want it to pick the run with the least loss.

class ScoredNet(skorch.NeuralNetClassifier):
    def score(self, X, y=None):
        loss_value = loss_scoring(self, X, y)
        return -loss_value

from sklearn.model_selection import GridSearchCV

net = ScoredNet(module)
grid_searcher = GridSearchCV(
    net, {'lr': [1e-2, 1e-3], 'batch_size': [8, 16]},
)
grid_searcher.fit(X, y)
best_net = grid_searcher.best_estimator_
print(best_net.score(X, y))

API Reference

If you are looking for information on a specific function, class or method, this part of the documentation is for you.

skorch

skorch.callbacks

This module serves to elevate callbacks in submodules to the skorch.callbacks namespace. Remember to define __all__ in each submodule.

class skorch.callbacks.BatchScoring(scoring, lower_is_better=True, on_train=False, name=None, target_extractor=<function to_numpy>, use_caching=True)[source]

Callback that performs generic scoring on batches.

This callback determines the score after each batch and stores it in the net’s history in the column given by name. At the end of the epoch, the average of the scores is determined and also stored in the history. Furthermore, it is determined whether this average score is the best score yet, and that information is also stored in the history.

In contrast to EpochScoring, this callback determines the score for each batch and then averages the scores at the end of the epoch. This can be disadvantageous for some scores if the batch size is small – e.g. area under the ROC will return incorrect scores in this case. Therefore, it is recommended to use EpochScoring unless you really need the scores for each batch.

If y is None, the scoring function with signature (model, X, y) must be able to handle X as a Tensor and y=None.

Parameters:
scoring : None, str, or callable

If None, use the score method of the model. If str, it should be a valid sklearn metric (e.g. “f1_score”, “accuracy_score”). If a callable, it should have the signature (model, X, y), and it should return a scalar. This works analogously to the scoring parameter in sklearn’s GridSearchCV et al.

lower_is_better : bool (default=True)

Whether lower (e.g. log loss) or higher (e.g. accuracy) scores are better.

on_train : bool (default=False)

Whether this should be called during train or validation.

name : str or None (default=None)

If not an explicit string, tries to infer the name from the scoring argument.

target_extractor : callable (default=to_numpy)

This is called on y before it is passed to scoring.

use_caching : bool (default=True)

Re-use the model’s prediction for computing the loss to calculate the score. Turning this off will result in an additional inference step for each batch.
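Example (a sketch using a callable scoring function with the (model, X, y) signature described above; the function name is made up):

>>> def my_accuracy(model, X, y):
...     # return a scalar score for the current batch
...     return (model.predict(X) == y).mean()
>>> net = NeuralNet(..., callbacks=[
...     BatchScoring(my_accuracy, lower_is_better=False, name='train_acc', on_train=True)])
>>> net.fit(X, y)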

Methods

initialize() (Re-)Set the initial state of the callback.
on_batch_begin(net[, X, y, training]) Called at the beginning of each batch.
on_batch_end(net, X, y, training, **kwargs) Called at the end of each batch.
on_epoch_begin(net[, dataset_train, …]) Called at the beginning of each epoch.
on_epoch_end(net, **kwargs) Called at the end of each epoch.
on_grad_computed(net, named_parameters[, X, …]) Called once per batch after gradients have been computed but before an update step was performed.
on_train_begin(net, X, y, **kwargs) Called at the beginning of training.
on_train_end(net[, X, y]) Called at the end of training.
get_avg_score  
get_params  
set_params  
on_batch_end(net, X, y, training, **kwargs)[source]

Called at the end of each batch.

on_epoch_end(net, **kwargs)[source]

Called at the end of each epoch.

class skorch.callbacks.Callback[source]

Base class for callbacks.

All custom callbacks should inherit from this class. The subclass may override any of the on_... methods. It is, however, not necessary to override all of them, since it’s okay if they don’t have any effect.

Classes that inherit from this also gain the get_params and set_params method.

Methods

initialize() (Re-)Set the initial state of the callback.
on_batch_begin(net[, X, y, training]) Called at the beginning of each batch.
on_batch_end(net[, X, y, training]) Called at the end of each batch.
on_epoch_begin(net[, dataset_train, …]) Called at the beginning of each epoch.
on_epoch_end(net[, dataset_train, dataset_valid]) Called at the end of each epoch.
on_grad_computed(net, named_parameters[, X, …]) Called once per batch after gradients have been computed but before an update step was performed.
on_train_begin(net[, X, y]) Called at the beginning of training.
on_train_end(net[, X, y]) Called at the end of training.
get_params  
set_params  
initialize()[source]

(Re-)Set the initial state of the callback. Use this e.g. if the callback tracks some state that should be reset when the model is re-initialized.

This method should return self.

on_batch_begin(net, X=None, y=None, training=None, **kwargs)[source]

Called at the beginning of each batch.

on_batch_end(net, X=None, y=None, training=None, **kwargs)[source]

Called at the end of each batch.

on_epoch_begin(net, dataset_train=None, dataset_valid=None, **kwargs)[source]

Called at the beginning of each epoch.

on_epoch_end(net, dataset_train=None, dataset_valid=None, **kwargs)[source]

Called at the end of each epoch.

on_grad_computed(net, named_parameters, X=None, y=None, training=None, **kwargs)[source]

Called once per batch after gradients have been computed but before an update step was performed.

on_train_begin(net, X=None, y=None, **kwargs)[source]

Called at the beginning of training.

on_train_end(net, X=None, y=None, **kwargs)[source]

Called at the end of training.

class skorch.callbacks.Checkpoint(monitor='valid_loss_best', f_params='params.pt', f_optimizer='optimizer.pt', f_criterion='criterion.pt', f_history='history.json', f_pickle=None, fn_prefix='', dirname='', event_name='event_cp', sink=<function noop>, **kwargs)[source]

Save the model during training if the given metric improved.

This callback works by default in conjunction with the validation scoring callback since it creates a valid_loss_best value in the history which the callback uses to determine if this epoch is save-worthy.

You can also specify your own metric to monitor or supply a callback that dynamically evaluates whether the model should be saved in this epoch.

Some or all of the following can be saved:

  • model parameters (see f_params parameter);
  • optimizer state (see f_optimizer parameter);
  • criterion state (see f_criterion parameter);
  • training history (see f_history parameter);
  • entire model object (see f_pickle parameter).

If you’ve created a custom module, e.g. net.mymodule_, you can save that as well by passing f_mymodule.

You can implement your own save protocol by subclassing Checkpoint and overriding save_model().

This callback writes a bool flag to the history column event_cp indicating whether a checkpoint was created or not.

Example:

>>> net = MyNet(callbacks=[Checkpoint()])
>>> net.fit(X, y)

Example using a custom monitor where models are saved only in epochs where the validation and the train losses are best:

>>> monitor = lambda net: all(net.history[-1, (
...     'train_loss_best', 'valid_loss_best')])
>>> net = MyNet(callbacks=[Checkpoint(monitor=monitor)])
>>> net.fit(X, y)
Parameters:
monitor : str, function, None

Value of the history to monitor or callback that determines whether this epoch should lead to a checkpoint. The callback takes the network instance as parameter.

In case monitor is set to None, the callback will save the network at every epoch.

Note: If you supply a lambda expression as monitor, you cannot pickle the wrapper anymore as lambdas cannot be pickled. You can mitigate this problem by using importable functions instead.

f_params : file-like object, str, None (default=’params.pt’)

File path to the file or file-like object where the model parameters should be saved. Pass None to disable saving model parameters.

If the value is a string you can also use format specifiers to, for example, indicate the current epoch. Accessible format values are net, last_epoch and last_batch. Example to include last epoch number in file name:

>>> cb = Checkpoint(f_params="params_{last_epoch[epoch]}.pt")
f_optimizer : file-like object, str, None (default=’optimizer.pt’)

File path to the file or file-like object where the optimizer state should be saved. Pass None to disable saving the optimizer state.

Supports the same format specifiers as f_params.

f_criterion : file-like object, str, None (default=’criterion.pt’)

File path to the file or file-like object where the criterion state should be saved. Pass None to disable saving the criterion state.

Supports the same format specifiers as f_params.

f_history : file-like object, str, None (default=’history.json’)

File path to the file or file-like object where the model training history should be saved. Pass None to disable saving history.

f_pickle : file-like object, str, None (default=None)

File path to the file or file-like object where the entire model object should be pickled. Pass None to disable pickling.

Supports the same format specifiers as f_params.

fn_prefix: str (default=’’)

Prefix for filenames. If f_params, f_optimizer, f_history, or f_pickle are strings, they will be prefixed by fn_prefix.

dirname: str (default=’’)

Directory where files are stored.

event_name: str, (default=’event_cp’)

Name of event to be placed in history when checkpoint is triggered. Pass None to disable placing events in history.

sink : callable (default=noop)

The target that the information about created checkpoints is sent to. This can be a logger or print function (to send to stdout). By default the output is discarded.

Attributes:
f_history_

Methods

get_formatted_files(net) Returns a dictionary of formatted filenames
initialize() (Re-)Set the initial state of the callback.
on_batch_begin(net[, X, y, training]) Called at the beginning of each batch.
on_batch_end(net[, X, y, training]) Called at the end of each batch.
on_epoch_begin(net[, dataset_train, …]) Called at the beginning of each epoch.
on_epoch_end(net, **kwargs) Called at the end of each epoch.
on_grad_computed(net, named_parameters[, X, …]) Called once per batch after gradients have been computed but before an update step was performed.
on_train_begin(net[, X, y]) Called at the beginning of training.
on_train_end(net[, X, y]) Called at the end of training.
save_model(net) Save the model.
get_params  
set_params  
get_formatted_files(net)[source]

Returns a dictionary of formatted filenames

initialize()[source]

(Re-)Set the initial state of the callback. Use this e.g. if the callback tracks some state that should be reset when the model is re-initialized.

This method should return self.

on_epoch_end(net, **kwargs)[source]

Called at the end of each epoch.

save_model(net)[source]

Save the model.

This function saves some or all of the following:

  • model parameters;
  • optimizer state;
  • criterion state;
  • training history;
  • custom modules;
  • entire model object.
class skorch.callbacks.EarlyStopping(monitor='valid_loss', patience=5, threshold=0.0001, threshold_mode='rel', lower_is_better=True, sink=<built-in function print>)[source]

Callback for stopping training when scores don’t improve.

Stop training early if a specified monitor metric did not improve in patience number of epochs by at least threshold.

Parameters:
monitor : str (default=’valid_loss’)

Value of the history to monitor to decide whether to stop training or not. The value is expected to be double and is commonly provided by scoring callbacks such as skorch.callbacks.EpochScoring.

lower_is_better : bool (default=True)

Whether lower scores should be considered better or worse.

patience : int (default=5)

Number of epochs to wait for improvement of the monitor value until the training process is stopped.

threshold : float (default=1e-4)

Ignore score improvements smaller than threshold.

threshold_mode : str (default=’rel’)

One of rel, abs. Decides whether the threshold value is interpreted in absolute terms or as a fraction of the best score so far (relative).

sink : callable (default=print)

The target that the information about early stopping is sent to. By default, the output is printed to stdout, but the sink could also be a logger or noop().
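Example (a sketch; the remaining net arguments are omitted):

>>> from skorch.callbacks import EarlyStopping
>>> early_stopping = EarlyStopping(monitor='valid_loss', patience=5)
>>> net = NeuralNet(..., callbacks=[early_stopping])
>>> net.fit(X, y)  # stops early once valid_loss hasn't improved for 5 epochs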

Methods

initialize() (Re-)Set the initial state of the callback.
on_batch_begin(net[, X, y, training]) Called at the beginning of each batch.
on_batch_end(net[, X, y, training]) Called at the end of each batch.
on_epoch_begin(net[, dataset_train, …]) Called at the beginning of each epoch.
on_epoch_end(net, **kwargs) Called at the end of each epoch.
on_grad_computed(net, named_parameters[, X, …]) Called once per batch after gradients have been computed but before an update step was performed.
on_train_begin(net, **kwargs) Called at the beginning of training.
on_train_end(net[, X, y]) Called at the end of training.
get_params  
set_params  
on_epoch_end(net, **kwargs)[source]

Called at the end of each epoch.

on_train_begin(net, **kwargs)[source]

Called at the beginning of training.

class skorch.callbacks.EpochScoring(scoring, lower_is_better=True, on_train=False, name=None, target_extractor=<function to_numpy>, use_caching=True)[source]

Callback that performs generic scoring on predictions.

At the end of each epoch, this callback makes a prediction on train or validation data, determines the score for that prediction and whether it is the best yet, and stores the result in the net’s history.

In case you already computed a score value for each batch, you can omit the score computation step by returning the value from the history. For example:

>>> def my_score(net, X=None, y=None):
...     losses = net.history[-1, 'batches', :, 'my_score']
...     batch_sizes = net.history[-1, 'batches', :, 'valid_batch_size']
...     return np.average(losses, weights=batch_sizes)
>>> net = MyNet(callbacks=[
...     ('my_score', Scoring(my_score, name='my_score'))])

If you fit with a custom dataset, this callback should work as expected as long as use_caching=True which enables the collection of y values from the dataset. If you decide to disable the caching of predictions and y values, you need to write your own scoring function that is able to deal with the dataset and returns a scalar, for example:

>>> def ds_accuracy(net, ds, y=None):
...     # assume ds yields (X, y), e.g. torchvision.datasets.MNIST
...     y_true = [y for _, y in ds]
...     y_pred = net.predict(ds)
...     return sklearn.metrics.accuracy_score(y_true, y_pred)
>>> net = MyNet(callbacks=[
...     EpochScoring(ds_accuracy, use_caching=False)])
>>> ds = torchvision.datasets.MNIST(root=mnist_path)
>>> net.fit(ds)
Parameters:
scoring : None, str, or callable (default=None)

If None, use the score method of the model. If str, it should be a valid sklearn scorer (e.g. “f1”, “accuracy”). If a callable, it should have the signature (model, X, y), and it should return a scalar. This works analogously to the scoring parameter in sklearn’s GridSearchCV et al.

lower_is_better : bool (default=True)

Whether lower scores should be considered better or worse.

on_train : bool (default=False)

Whether this should be called during train or validation.

name : str or None (default=None)

If not an explicit string, tries to infer the name from the scoring argument.

target_extractor : callable (default=to_numpy)

This is called on y before it is passed to scoring.

use_caching : bool (default=True)

Collect labels and predictions (y_true and y_pred) over the course of one epoch and use the cached values for computing the score. The cached values are shared between all EpochScoring instances. Disabling this will result in an additional inference step for each epoch and an inability to use arbitrary datasets as input (since we don’t know how to extract y_true from an arbitrary dataset).

Methods

get_test_data(dataset_train, dataset_valid) Return data needed to perform scoring.
initialize() (Re-)Set the initial state of the callback.
on_batch_begin(net[, X, y, training]) Called at the beginning of each batch.
on_batch_end(net, y, y_pred, training, **kwargs) Called at the end of each batch.
on_epoch_begin(net, dataset_train, …) Called at the beginning of each epoch.
on_epoch_end(net, dataset_train, …) Called at the end of each epoch.
on_grad_computed(net, named_parameters[, X, …]) Called once per batch after gradients have been computed but before an update step was performed.
on_train_begin(net, X, y, **kwargs) Called at the beginning of training.
on_train_end(*args, **kwargs) Called at the end of training.
get_params  
set_params  
get_test_data(dataset_train, dataset_valid)[source]

Return data needed to perform scoring.

This is a convenience method that handles picking of train/valid, different types of input data, use of cache, etc. for you.

Parameters:
dataset_train

Incoming training data or dataset.

dataset_valid

Incoming validation data or dataset.

Returns:
X_test

Input data used for making the prediction.

y_test

Target ground truth. If caching was enabled, return cached y_test.

y_pred : list

The predicted targets. If caching was disabled, the list is empty. If caching was enabled, the list contains the batches of the predictions. It may thus be necessary to concatenate the output before working with it: y_pred = np.concatenate(y_pred)

initialize()[source]

(Re-)Set the initial state of the callback. Use this e.g. if the callback tracks some state that should be reset when the model is re-initialized.

This method should return self.

on_batch_end(net, y, y_pred, training, **kwargs)[source]

Called at the end of each batch.

on_epoch_begin(net, dataset_train, dataset_valid, **kwargs)[source]

Called at the beginning of each epoch.

on_epoch_end(net, dataset_train, dataset_valid, **kwargs)[source]

Called at the end of each epoch.

on_train_end(*args, **kwargs)[source]

Called at the end of training.

class skorch.callbacks.EpochTimer(**kwargs)[source]

Measures the duration of each epoch and writes it to the history with the name dur.

Methods

initialize() (Re-)Set the initial state of the callback.
on_batch_begin(net[, X, y, training]) Called at the beginning of each batch.
on_batch_end(net[, X, y, training]) Called at the end of each batch.
on_epoch_begin(net, **kwargs) Called at the beginning of each epoch.
on_epoch_end(net, **kwargs) Called at the end of each epoch.
on_grad_computed(net, named_parameters[, X, …]) Called once per batch after gradients have been computed but before an update step was performed.
on_train_begin(net[, X, y]) Called at the beginning of training.
on_train_end(net[, X, y]) Called at the end of training.
get_params  
set_params  
on_epoch_begin(net, **kwargs)[source]

Called at the beginning of each epoch.

on_epoch_end(net, **kwargs)[source]

Called at the end of each epoch.

class skorch.callbacks.Freezer(*args, **kwargs)[source]

Freeze matching parameters at the start of the first epoch. Using the at parameter, you may specify a specific point in time (either an epoch number or a callable) at which the parameters are frozen.

See ParamMapper for details.
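Example (a sketch; the parameter pattern 'embedding.*' is illustrative and depends on the parameter names of your module):

>>> from skorch.callbacks import Freezer
>>> net = NeuralNet(..., callbacks=[Freezer('embedding.*')])
>>> net.fit(X, y)  # parameters matching the pattern are excluded from training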

Methods

initialize() (Re-)Set the initial state of the callback.
on_batch_begin(net[, X, y, training]) Called at the beginning of each batch.
on_batch_end(net[, X, y, training]) Called at the end of each batch.
on_epoch_begin(net, **kwargs) Called at the beginning of each epoch.
on_epoch_end(net[, dataset_train, dataset_valid]) Called at the end of each epoch.
on_grad_computed(net, named_parameters[, X, …]) Called once per batch after gradients have been computed but before an update step was performed.
on_train_begin(net[, X, y]) Called at the beginning of training.
on_train_end(net[, X, y]) Called at the end of training.
filter_parameters  
get_params  
named_parameters  
set_params  
class skorch.callbacks.GradientNormClipping(gradient_clip_value=None, gradient_clip_norm_type=2)[source]

Clips gradient norm of a module’s parameters.

The norm is computed over all gradients together, as if they were concatenated into a single vector. Gradients are modified in-place.

See torch.nn.utils.clip_grad_norm_() for more information.

Parameters:
gradient_clip_value : float (default=None)

If not None, clip the norm of all model parameter gradients to this value. The type of the norm is determined by the gradient_clip_norm_type parameter and defaults to L2.

gradient_clip_norm_type : float (default=2)

Norm to use when gradient clipping is active. The default is to use L2-norm. Can be ‘inf’ for infinity norm.
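Example (a sketch; the clipping value of 1.0 is arbitrary):

>>> from skorch.callbacks import GradientNormClipping
>>> net = NeuralNet(..., callbacks=[GradientNormClipping(gradient_clip_value=1.0)])
>>> net.fit(X, y)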

Methods

initialize() (Re-)Set the initial state of the callback.
on_batch_begin(net[, X, y, training]) Called at the beginning of each batch.
on_batch_end(net[, X, y, training]) Called at the end of each batch.
on_epoch_begin(net[, dataset_train, …]) Called at the beginning of each epoch.
on_epoch_end(net[, dataset_train, dataset_valid]) Called at the end of each epoch.
on_grad_computed(_, named_parameters, **kwargs) Called once per batch after gradients have been computed but before an update step was performed.
on_train_begin(net[, X, y]) Called at the beginning of training.
on_train_end(net[, X, y]) Called at the end of training.
get_params  
set_params  
on_grad_computed(_, named_parameters, **kwargs)[source]

Called once per batch after gradients have been computed but before an update step was performed.

class skorch.callbacks.Initializer(*args, **kwargs)[source]

Apply any function on matching parameters in the first epoch.

Examples

Use Initializer to initialize all dense layer weights with values sampled from a uniform distribution at the beginning of the first epoch:

>>> init_fn = partial(torch.nn.init.uniform_, a=-1e-3, b=1e-3)
>>> cb = Initializer('dense*.weight', fn=init_fn)
>>> net = Net(myModule, callbacks=[cb])

Methods

initialize() (Re-)Set the initial state of the callback.
on_batch_begin(net[, X, y, training]) Called at the beginning of each batch.
on_batch_end(net[, X, y, training]) Called at the end of each batch.
on_epoch_begin(net, **kwargs) Called at the beginning of each epoch.
on_epoch_end(net[, dataset_train, dataset_valid]) Called at the end of each epoch.
on_grad_computed(net, named_parameters[, X, …]) Called once per batch after gradients have been computed but before an update step was performed.
on_train_begin(net[, X, y]) Called at the beginning of training.
on_train_end(net[, X, y]) Called at the end of training.
filter_parameters  
get_params  
named_parameters  
set_params  
class skorch.callbacks.LRScheduler(policy='WarmRestartLR', monitor='train_loss', event_name='event_lr', step_every='epoch', **kwargs)[source]

Callback that sets the learning rate of each parameter group according to some policy.

Parameters:
policy : str or _LRScheduler class (default=’WarmRestartLR’)

Learning rate policy name or scheduler to be used.

monitor : str or callable (default=’train_loss’)

Value of the history to monitor or function/callable. In the latter case, the callable receives the net instance as argument and is expected to return the score (float) used to determine the learning rate adjustment.

event_name: str, (default=’event_lr’)

Name of event to be placed in history when the scheduler takes a step. Pass None to disable placing events in history. Note: This feature works only for pytorch version >=1.4

step_every: str, (default=’epoch’)

Value for when to apply the learning rate scheduler step. Can be either ‘epoch’ or ‘batch’.

kwargs

Additional arguments passed to the lr scheduler.

Attributes:
kwargs

Methods

initialize() (Re-)Set the initial state of the callback.
on_batch_begin(net[, X, y, training]) Called at the beginning of each batch.
on_batch_end(net, training, **kwargs) Called at the end of each batch.
on_epoch_begin(net[, dataset_train, …]) Called at the beginning of each epoch.
on_epoch_end(net, **kwargs) Called at the end of each epoch.
on_grad_computed(net, named_parameters[, X, …]) Called once per batch after gradients have been computed but before an update step was performed.
on_train_begin(net, **kwargs) Called at the beginning of training.
on_train_end(net[, X, y]) Called at the end of training.
simulate(steps, initial_lr) Simulates the learning rate scheduler.
get_params  
set_params  
initialize()[source]

(Re-)Set the initial state of the callback. Use this e.g. if the callback tracks some state that should be reset when the model is re-initialized.

This method should return self.

on_batch_end(net, training, **kwargs)[source]

Called at the end of each batch.

on_epoch_end(net, **kwargs)[source]

Called at the end of each epoch.

on_train_begin(net, **kwargs)[source]

Called at the beginning of training.

simulate(steps, initial_lr)[source]

Simulates the learning rate scheduler.

Parameters:
steps: int

Number of steps to simulate

initial_lr: float

Initial learning rate

Returns:
lrs: numpy ndarray

Simulated learning rates

class skorch.callbacks.LoadInitState(checkpoint)[source]

Loads the model, optimizer, and history from a checkpoint into a NeuralNet when training begins.

Parameters:
checkpoint: :class:`.Checkpoint`

Checkpoint to get filenames from.

Examples

Consider running the following example multiple times:

>>> cp = Checkpoint(monitor='valid_loss_best')
>>> load_state = LoadInitState(cp)
>>> net = NeuralNet(..., callbacks=[cp, load_state])
>>> net.fit(X, y)

On the first run, the Checkpoint saves the model, optimizer, and history when the validation loss is minimized. During the first run, there are no files on disk, thus LoadInitState will not load anything. When running the example a second time, LoadInitState will load the best model from the first run and continue training from there.

Methods

initialize() (Re-)Set the initial state of the callback.
on_batch_begin(net[, X, y, training]) Called at the beginning of each batch.
on_batch_end(net[, X, y, training]) Called at the end of each batch.
on_epoch_begin(net[, dataset_train, …]) Called at the beginning of each epoch.
on_epoch_end(net[, dataset_train, dataset_valid]) Called at the end of each epoch.
on_grad_computed(net, named_parameters[, X, …]) Called once per batch after gradients have been computed but before an update step was performed.
on_train_begin(net[, X, y]) Called at the beginning of training.
on_train_end(net[, X, y]) Called at the end of training.
get_params  
set_params  
initialize()[source]

(Re-)Set the initial state of the callback. Use this e.g. if the callback tracks some state that should be reset when the model is re-initialized.

This method should return self.

on_train_begin(net, X=None, y=None, **kwargs)[source]

Called at the beginning of training.

class skorch.callbacks.NeptuneLogger(experiment, log_on_batch_end=False, close_after_train=True, keys_ignored=None)[source]

Logs results from history to Neptune

Neptune is a lightweight experiment tracking tool. You can read more about it here: https://neptune.ai

Use this callback to automatically log all interesting values from your net’s history to Neptune.

The best way to log additional information is to log directly to the experiment object or subclass the on_* methods.

To monitor resource consumption, install psutil:

$ pip install psutil

You can view example experiment logs here: https://ui.neptune.ai/o/shared/org/skorch-integration/e/SKOR-13/charts

Parameters:
experiment : neptune.experiments.Experiment

Instantiated Experiment class.

log_on_batch_end : bool (default=False)

Whether to log loss and other metrics on batch level.

close_after_train : bool (default=True)

Whether to close the Experiment object once training finishes. Set this parameter to False if you want to continue logging to the same Experiment or if you use it as a context manager.

keys_ignored : str or list of str (default=None)

Key or list of keys that should not be logged to Neptune. Note that in addition to the keys provided by the user, keys such as those starting with ‘event_’ or ending on ‘_best’ are ignored by default.

Examples

$ pip install neptune-client

>>> # Create a neptune experiment object
>>> import neptune
...
... # We are using api token for an anonymous user.
... # For your projects use the token associated with your neptune.ai account
>>> neptune.init(api_token='ANONYMOUS',
...              project_qualified_name='shared/skorch-integration')
...
... experiment = neptune.create_experiment(
...                        name='skorch-basic-example',
...                        params={'max_epochs': 20,
...                                'lr': 0.01},
...                        upload_source_files=['skorch_example.py'])
>>> # Create a neptune_logger callback
>>> neptune_logger = NeptuneLogger(experiment, close_after_train=False)
>>> # Pass a logger to net callbacks argument
>>> net = NeuralNetClassifier(
...           ClassifierModule,
...           max_epochs=20,
...           lr=0.01,
...           callbacks=[neptune_logger])
>>> # Log additional metrics after training has finished
>>> from sklearn.metrics import roc_auc_score
... y_pred = net.predict_proba(X)
... auc = roc_auc_score(y, y_pred[:, 1])
...
... neptune_logger.experiment.log_metric('roc_auc_score', auc)
>>> # log charts like ROC curve
... from scikitplot.metrics import plot_roc
... import matplotlib.pyplot as plt
...
... fig, ax = plt.subplots(figsize=(16, 12))
... plot_roc(y, y_pred, ax=ax)
... neptune_logger.experiment.log_image('roc_curve', fig)
>>> # log net object after training
... net.save_params(f_params='basic_model.pkl')
... neptune_logger.experiment.log_artifact('basic_model.pkl')
>>> # close experiment
... neptune_logger.experiment.stop()
Attributes:
first_batch_ : bool

Helper attribute that is set to True at initialization and changes to False on first batch end. Can be used when we want to log things exactly once.


Methods

initialize() (Re-)Set the initial state of the callback.
on_batch_begin(net[, X, y, training]) Called at the beginning of each batch.
on_batch_end(net, **kwargs) Called at the end of each batch.
on_epoch_begin(net[, dataset_train, …]) Called at the beginning of each epoch.
on_epoch_end(net, **kwargs) Automatically log values from the last history step.
on_grad_computed(net, named_parameters[, X, …]) Called once per batch after gradients have been computed but before an update step was performed.
on_train_begin(net[, X, y]) Called at the beginning of training.
on_train_end(net, **kwargs) Called at the end of training.
get_params  
set_params  
initialize()[source]

(Re-)Set the initial state of the callback. Use this e.g. if the callback tracks some state that should be reset when the model is re-initialized.

This method should return self.

on_batch_end(net, **kwargs)[source]

Called at the end of each batch.

on_epoch_end(net, **kwargs)[source]

Automatically log values from the last history step.

on_train_end(net, **kwargs)[source]

Called at the end of training.

class skorch.callbacks.ParamMapper(patterns, fn=<function noop>, at=1, schedule=None)[source]

Map arbitrary functions over module parameters filtered by pattern matching.

In the simplest case, the function is only applied once at the beginning of a given epoch (at on_epoch_begin), but more complex execution schemes (e.g. periodic application) are possible using at and schedule.

Parameters:
patterns : str or callable or list

The pattern(s) to match parameter names against. Patterns are UNIX globbing patterns as understood by fnmatch(). Patterns can also be callables which will get called with the parameter name and are regarded as a match when the callable returns a truthy value.

This parameter also supports lists of str or callables so that one ParamMapper can match a group of parameters.

Example: 'linear*.weight' or ['linear0.*', 'linear1.bias'] or lambda name: name.startswith('linear').

fn : function

The function to apply to each parameter separately.

at : int or callable

If you specify an integer, it represents the epoch number at which the function fn is applied to the parameters. If at is a callable, it receives the net as its argument, and fn is applied to the parameters once at returns True.

schedule : callable or None

If specified this callable supersedes the static at/fn combination by dynamically returning the function that is applied on the matched parameters. This way you can, for example, create a schedule that periodically freezes and unfreezes layers.

The callable’s signature is schedule(net: NeuralNet) -> callable.

Notes

When starting the training process after saving and loading a model, ParamMapper might re-initialize parts of your model when the history is not saved along with the model. To avoid this, in case you use ParamMapper (or subclasses, e.g. Initializer) and want to save your model make sure to either (a) use pickle, (b) save and load the history or (c) remove the parameter mapper callbacks before continuing training.

Examples

Initialize a layer on first epoch before the first training step:

>>> init = partial(torch.nn.init.uniform_, a=0, b=1)
>>> cb = ParamMapper('linear*.weight', at=1, fn=init)
>>> net = Net(myModule, callbacks=[cb])

Reset layer initialization if train loss reaches a certain value (e.g. re-initialize on overfit):

>>> at = lambda net: net.history[-1, 'train_loss'] < 0.1
>>> init = partial(torch.nn.init.uniform_, a=0, b=1)
>>> cb = ParamMapper('linear0.weight', at=at, fn=init)
>>> net = Net(myModule, callbacks=[cb])

Periodically freeze and unfreeze all embedding layers:

>>> def my_sched(net):
...    if len(net.history) % 2 == 0:
...        return skorch.utils.freeze_parameter
...    else:
...        return skorch.utils.unfreeze_parameter
>>> cb = ParamMapper('embedding*.weight', schedule=my_sched)
>>> net = Net(myModule, callbacks=[cb])

Methods

initialize() (Re-)Set the initial state of the callback.
on_batch_begin(net[, X, y, training]) Called at the beginning of each batch.
on_batch_end(net[, X, y, training]) Called at the end of each batch.
on_epoch_begin(net, **kwargs) Called at the beginning of each epoch.
on_epoch_end(net[, dataset_train, dataset_valid]) Called at the end of each epoch.
on_grad_computed(net, named_parameters[, X, …]) Called once per batch after gradients have been computed but before an update step was performed.
on_train_begin(net[, X, y]) Called at the beginning of training.
on_train_end(net[, X, y]) Called at the end of training.
filter_parameters  
get_params  
named_parameters  
set_params  
initialize()[source]

(Re-)Set the initial state of the callback. Use this e.g. if the callback tracks some state that should be reset when the model is re-initialized.

This method should return self.

on_epoch_begin(net, **kwargs)[source]

Called at the beginning of each epoch.

class skorch.callbacks.PassthroughScoring(name, lower_is_better=True, on_train=False)[source]

Creates scores on epoch level based on batch level scores

This callback doesn’t calculate any new scores but instead passes through a score that was already created on the batch level. Based on that score, an average across the batches of the epoch is computed (weighted by batch size) and recorded in the history for the given epoch.

Use this callback when there already is a score calculated on the batch level. If that score has yet to be calculated, use BatchScoring instead.

Parameters:
name : str

Name of the score recorded on a batch level in the history.

lower_is_better : bool (default=True)

Whether lower (e.g. log loss) or higher (e.g. accuracy) scores are better.

on_train : bool (default=False)

Whether this should be called during train or validation.
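
A minimal sketch, assuming a hypothetical key 'my_batch_score' is already written to the history on batch level by some other callback:

>>> from skorch.callbacks import PassthroughScoring
>>> # aggregate the batch-level 'my_batch_score' into an epoch-level score
>>> cb = PassthroughScoring(name='my_batch_score', lower_is_better=True)
>>> net = NeuralNet(..., callbacks=[cb])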

Methods

initialize() (Re-)Set the initial state of the callback.
on_batch_begin(net[, X, y, training]) Called at the beginning of each batch.
on_batch_end(net[, X, y, training]) Called at the end of each batch.
on_epoch_begin(net[, dataset_train, …]) Called at the beginning of each epoch.
on_epoch_end(net, **kwargs) Called at the end of each epoch.
on_grad_computed(net, named_parameters[, X, …]) Called once per batch after gradients have been computed but before an update step was performed.
on_train_begin(net[, X, y]) Called at the beginning of training.
on_train_end(net[, X, y]) Called at the end of training.
get_avg_score  
get_params  
set_params  
initialize()[source]

(Re-)Set the initial state of the callback. Use this e.g. if the callback tracks some state that should be reset when the model is re-initialized.

This method should return self.

on_epoch_end(net, **kwargs)[source]

Called at the end of each epoch.

class skorch.callbacks.PrintLog(keys_ignored=None, sink=<built-in function print>, tablefmt='simple', floatfmt='.4f', stralign='right')[source]

Print useful information from the model’s history as a table.

By default, PrintLog prints everything from the history except for 'batches'.

To determine the best loss, PrintLog looks for keys that end on '_best' and associates them with the corresponding loss. E.g., 'train_loss_best' will be matched with 'train_loss'. The skorch.callbacks.EpochScoring callback takes care of creating those entries, which is why PrintLog works best in conjunction with that callback.

PrintLog treats keys with the 'event_' prefix in a special way. They are assumed to contain information about occasionally occurring events. The False or None entries (indicating that an event did not occur) are not printed, resulting in empty cells in the table, and True entries are printed with a '+' symbol. PrintLog groups all event columns together and pushes them to the right, just before the 'dur' column.

Note: PrintLog will not result in good outputs if the number of columns varies between epochs, e.g. if the valid loss is only present on every other epoch.

Parameters:
keys_ignored : str or list of str (default=None)

Key or list of keys that should not be part of the printed table. Note that in addition to the keys provided by the user, keys such as those starting with ‘event_’ or ending on ‘_best’ are ignored by default.

sink : callable (default=print)

The target that the output string is sent to. By default, the output is printed to stdout, but the sink could also be a logger, etc.

tablefmt : str (default=’simple’)

The format of the table. See the documentation of the tabulate package for more detail. Can be ‘plain’, ‘grid’, ‘pipe’, ‘html’, ‘latex’, among others.

floatfmt : str (default=’.4f’)

The number formatting. See the documentation of the tabulate package for more details.

stralign : str (default=’right’)

The alignment of columns with strings. Can be ‘left’, ‘center’, ‘right’, or None (disable alignment). Default is ‘right’ (to be consistent with numerical columns).
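
As a sketch, the default PrintLog instance (named 'print_log') could be re-routed to a logger instead of stdout; the logger name here is just an example:

>>> import logging
>>> logger = logging.getLogger('training')
>>> net = NeuralNet(...)  # PrintLog is installed by default
>>> net.set_params(callbacks__print_log__sink=logger.info)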

Methods

format_row(row, key, color) For a given row from the table, format it (i.e.
initialize() (Re-)Set the initial state of the callback.
on_batch_begin(net[, X, y, training]) Called at the beginning of each batch.
on_batch_end(net[, X, y, training]) Called at the end of each batch.
on_epoch_begin(net[, dataset_train, …]) Called at the beginning of each epoch.
on_epoch_end(net, **kwargs) Called at the end of each epoch.
on_grad_computed(net, named_parameters[, X, …]) Called once per batch after gradients have been computed but before an update step was performed.
on_train_begin(net[, X, y]) Called at the beginning of training.
on_train_end(net[, X, y]) Called at the end of training.
get_params  
set_params  
table  
format_row(row, key, color)[source]

For a given row from the table, format it (i.e. floating points and color if applicable).

initialize()[source]

(Re-)Set the initial state of the callback. Use this e.g. if the callback tracks some state that should be reset when the model is re-initialized.

This method should return self.

on_epoch_end(net, **kwargs)[source]

Called at the end of each epoch.

class skorch.callbacks.ProgressBar(batches_per_epoch='auto', detect_notebook=True, postfix_keys=None)[source]

Display a progress bar for each epoch.

The progress bar includes elapsed and estimated remaining time for the current epoch, the number of batches processed, and other user-defined metrics. The progress bar is erased once the epoch is completed.

ProgressBar needs to know the total number of batches per epoch in order to display a meaningful progress bar. By default, this number is determined automatically using the dataset length and the batch size. If this heuristic does not work for some reason, you may either specify the number of batches explicitly or let the ProgressBar count the actual number of batches in the previous epoch.

In Jupyter notebooks, a non-ASCII progress bar can be printed instead. To use this feature, you need to have ipywidgets installed.

Parameters:
batches_per_epoch : int, str (default=’auto’)

Either a concrete number or a string specifying the method used to determine the number of batches per epoch automatically. 'auto' means that the number is computed from the length of the dataset and the batch size. 'count' means that the number is determined by counting the batches in the previous epoch. Note that this will leave you without a progress bar at the first epoch.

detect_notebook : bool (default=True)

If enabled, the progress bar determines if its current environment is a jupyter notebook and switches to a non-ASCII progress bar.

postfix_keys : list of str (default=[‘train_loss’, ‘valid_loss’])

You can use this list to specify additional info displayed in the progress bar such as metrics and losses. A prerequisite to this is that these values are residing in the history on batch level already, i.e. they must be accessible via

>>> net.history[-1, 'batches', -1, key]
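
A minimal usage sketch that counts batches instead of estimating them and shows the batch-level train loss:

>>> from skorch.callbacks import ProgressBar
>>> pbar = ProgressBar(batches_per_epoch='count', postfix_keys=['train_loss'])
>>> net = NeuralNet(..., callbacks=[pbar])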

Methods

initialize() (Re-)Set the initial state of the callback.
on_batch_begin(net[, X, y, training]) Called at the beginning of each batch.
on_batch_end(net, **kwargs) Called at the end of each batch.
on_epoch_begin(net[, dataset_train, …]) Called at the beginning of each epoch.
on_epoch_end(net, **kwargs) Called at the end of each epoch.
on_grad_computed(net, named_parameters[, X, …]) Called once per batch after gradients have been computed but before an update step was performed.
on_train_begin(net[, X, y]) Called at the beginning of training.
on_train_end(net[, X, y]) Called at the end of training.
get_params  
in_ipynb  
set_params  
on_batch_end(net, **kwargs)[source]

Called at the end of each batch.

on_epoch_begin(net, dataset_train=None, dataset_valid=None, **kwargs)[source]

Called at the beginning of each epoch.

on_epoch_end(net, **kwargs)[source]

Called at the end of each epoch.

class skorch.callbacks.TrainEndCheckpoint(f_params='params.pt', f_optimizer='optimizer.pt', f_criterion='criterion.pt', f_history='history.json', f_pickle=None, fn_prefix='train_end_', dirname='', sink=<function noop>, **kwargs)[source]

Saves the model parameters, optimizer state, and history at the end of training. The default fn_prefix is ‘train_end_’.

Parameters:
f_params : file-like object, str, None (default=’params.pt’)

File path to the file or file-like object where the model parameters should be saved. Pass None to disable saving model parameters.

If the value is a string you can also use format specifiers to, for example, indicate the current epoch. Accessible format values are net, last_epoch and last_batch. Example to include last epoch number in file name:

>>> cb = Checkpoint(f_params="params_{last_epoch[epoch]}.pt")
f_optimizer : file-like object, str, None (default=’optimizer.pt’)

File path to the file or file-like object where the optimizer state should be saved. Pass None to disable saving the optimizer state.

Supports the same format specifiers as f_params.

f_criterion : file-like object, str, None (default=’criterion.pt’)

File path to the file or file-like object where the criterion state should be saved. Pass None to disable saving the criterion state.

Supports the same format specifiers as f_params.

f_history : file-like object, str, None (default=’history.json’)

File path to the file or file-like object where the model training history should be saved. Pass None to disable saving history.

f_pickle : file-like object, str, None (default=None)

File path to the file or file-like object where the entire model object should be pickled. Pass None to disable pickling.

Supports the same format specifiers as f_params.

fn_prefix: str (default=’train_end_’)

Prefix for filenames. If f_params, f_optimizer, f_history, or f_pickle are strings, they will be prefixed by fn_prefix.

dirname: str (default=’’)

Directory where files are stored.

sink : callable (default=noop)

The target that the information about created checkpoints is sent to. This can be a logger or print function (to send to stdout). By default the output is discarded.

Examples

Consider running the following example multiple times:

>>> train_end_cp = TrainEndCheckpoint(dirname='exp1')
>>> load_state = LoadInitState(train_end_cp)
>>> net = NeuralNet(..., callbacks=[train_end_cp, load_state])
>>> net.fit(X, y)

After the first run, model parameters, optimizer state, and history are saved into a directory named exp1. On the next run, LoadInitState will load the state from the first run and continue training.

Methods

initialize() (Re-)Set the initial state of the callback.
on_batch_begin(net[, X, y, training]) Called at the beginning of each batch.
on_batch_end(net[, X, y, training]) Called at the end of each batch.
on_epoch_begin(net[, dataset_train, …]) Called at the beginning of each epoch.
on_epoch_end(net[, dataset_train, dataset_valid]) Called at the end of each epoch.
on_grad_computed(net, named_parameters[, X, …]) Called once per batch after gradients have been computed but before an update step was performed.
on_train_begin(net[, X, y]) Called at the beginning of training.
on_train_end(net, **kwargs) Called at the end of training.
get_params  
set_params  
initialize()[source]

(Re-)Set the initial state of the callback. Use this e.g. if the callback tracks some state that should be reset when the model is re-initialized.

This method should return self.

on_train_end(net, **kwargs)[source]

Called at the end of training.

class skorch.callbacks.TensorBoard(writer, close_after_train=True, keys_ignored=None, key_mapper=<function rename_tensorboard_key>)[source]

Logs results from history to TensorBoard

“TensorBoard provides the visualization and tooling needed for machine learning experimentation” (https://www.tensorflow.org/tensorboard/)

Use this callback to automatically log all interesting values from your net’s history to tensorboard after each epoch.

The best way to log additional information is to subclass this callback and add your code to one of the on_* methods.

Parameters:
writer : torch.utils.tensorboard.writer.SummaryWriter

Instantiated SummaryWriter class.

close_after_train : bool (default=True)

Whether to close the SummaryWriter object once training finishes. Set this parameter to False if you want to continue logging with the same writer or if you use it as a context manager.

keys_ignored : str or list of str (default=None)

Key or list of keys that should not be logged to tensorboard. Note that in addition to the keys provided by the user, keys such as those starting with ‘event_’ or ending on ‘_best’ are ignored by default.

key_mapper : callable or function (default=rename_tensorboard_key)

This function maps a key name from the history to a tag in tensorboard. This is useful because tensorboard can automatically group similar tags if their names start with the same prefix, followed by a forward slash. By default, this callback will prefix all keys that start with “train” or “valid” with the “Loss/” prefix.


Examples

>>> # Example to log the bias parameter as a histogram
>>> def extract_bias(module):
...     return module.hidden.bias
>>> class MyTensorBoard(TensorBoard):
...     def on_epoch_end(self, net, **kwargs):
...         bias = extract_bias(net.module_)
...         epoch = net.history[-1, 'epoch']
...         self.writer.add_histogram('bias', bias, global_step=epoch)
...         super().on_epoch_end(net, **kwargs)  # call super last
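
For plain logging without subclassing, a minimal sketch could look like this (the log directory name is arbitrary):

>>> from torch.utils.tensorboard import SummaryWriter
>>> from skorch.callbacks import TensorBoard
>>> writer = SummaryWriter('runs/exp1')
>>> net = NeuralNet(..., callbacks=[TensorBoard(writer)])
>>> net.fit(X, y)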

Methods

add_scalar_maybe(history, key, tag[, …]) Add a scalar value from the history to TensorBoard
initialize() (Re-)Set the initial state of the callback.
on_batch_begin(net[, X, y, training]) Called at the beginning of each batch.
on_batch_end(net, **kwargs) Called at the end of each batch.
on_epoch_begin(net[, dataset_train, …]) Called at the beginning of each epoch.
on_epoch_end(net, **kwargs) Automatically log values from the last history step.
on_grad_computed(net, named_parameters[, X, …]) Called once per batch after gradients have been computed but before an update step was performed.
on_train_begin(net[, X, y]) Called at the beginning of training.
on_train_end(net, **kwargs) Called at the end of training.
get_params  
set_params  
add_scalar_maybe(history, key, tag, global_step=None)[source]

Add a scalar value from the history to TensorBoard

Will catch errors like missing keys or wrong value types.

Parameters:
history : skorch.History

History object saved as attribute on the neural net.

key : str

Key of the desired value in the history.

tag : str

Name of the tag used in TensorBoard.

global_step : int or None

Global step value to record.

initialize()[source]

(Re-)Set the initial state of the callback. Use this e.g. if the callback tracks some state that should be reset when the model is re-initialized.

This method should return self.

on_batch_end(net, **kwargs)[source]

Called at the end of each batch.

on_epoch_end(net, **kwargs)[source]

Automatically log values from the last history step.

on_train_end(net, **kwargs)[source]

Called at the end of training.

class skorch.callbacks.SacredLogger(experiment, log_on_batch_end=False, log_on_epoch_end=True, batch_suffix=None, epoch_suffix=None, keys_ignored=None)[source]

Logs results from history to Sacred.

Sacred is a tool to help you configure, organize, log and reproduce experiments. Developed at IDSIA. See https://github.com/IDSIA/sacred.

Use this callback to automatically log all interesting values from your net’s history to Sacred.

If you want to log additional information, you can simply add it to History. See the documentation on Callbacks, and Scoring for more information. Alternatively you can subclass this callback and extend the on_* methods.

To use this logger, you first have to install Sacred:

$ pip install sacred

You might also want to install pymongo to use a MongoDB backend; see the upstream documentation (https://github.com/IDSIA/sacred#installing) for more details. Once you have installed it, you can set up a simple experiment and pass this logger as a callback to your skorch estimator:

# contents of sacred-experiment.py
>>> import numpy as np
>>> from sacred import Experiment
>>> from sklearn.datasets import make_classification
>>> from skorch.callbacks.logging import SacredLogger
>>> from skorch.callbacks.scoring import EpochScoring
>>> from skorch import NeuralNetClassifier
>>> from skorch.toy import make_classifier

>>> ex = Experiment()
>>> @ex.config
... def my_config():
...     max_epochs = 20
...     lr = 0.01
>>> X, y = make_classification()
>>> X, y = X.astype(np.float32), y.astype(np.int64)
>>> @ex.automain
... def main(_run, max_epochs, lr):
...     # Take care to add additional scoring callbacks *before* the logger.
...     net = NeuralNetClassifier(
...         make_classifier(),
...         max_epochs=max_epochs,
...         lr=lr,
...         callbacks=[EpochScoring("f1"), SacredLogger(_run)]
...     )
...     # now fit your estimator to your data
...     net.fit(X, y)

Then call this from the command line, e.g. like this: python sacred-experiment.py with max_epochs=15

You can also change other options on the command line and optionally specify a backend.

Parameters:
experiment : sacred.Experiment

Instantiated Experiment class.

log_on_batch_end : bool (default=False)

Whether to log loss and other metrics on batch level.

log_on_epoch_end : bool (default=True)

Whether to log loss and other metrics on epoch level.

batch_suffix : str (default=None)

A string that will be appended to all keys logged on batch level. By default (None), “_batch” is used if batch and epoch logging are both enabled, and no suffix is used otherwise.

epoch_suffix : str (default=None)

A string that will be appended to all keys logged on epoch level. By default (None), “_epoch” is used if batch and epoch logging are both enabled, and no suffix is used otherwise.

keys_ignored : str or list of str (default=None)

Key or list of keys that should not be logged to Sacred. Note that in addition to the keys provided by the user, keys such as those starting with ‘event_’ or ending on ‘_best’ are ignored by default.


Methods

initialize() (Re-)Set the initial state of the callback.
on_batch_begin(net[, X, y, training]) Called at the beginning of each batch.
on_batch_end(net, **kwargs) Called at the end of each batch.
on_epoch_begin(net[, dataset_train, …]) Called at the beginning of each epoch.
on_epoch_end(net, **kwargs) Automatically log values from the last history step.
on_grad_computed(net, named_parameters[, X, …]) Called once per batch after gradients have been computed but before an update step was performed.
on_train_begin(net[, X, y]) Called at the beginning of training.
on_train_end(net[, X, y]) Called at the end of training.
get_params  
set_params  
initialize()[source]

(Re-)Set the initial state of the callback. Use this e.g. if the callback tracks some state that should be reset when the model is re-initialized.

This method should return self.

on_batch_end(net, **kwargs)[source]

Called at the end of each batch.

on_epoch_end(net, **kwargs)[source]

Automatically log values from the last history step.

class skorch.callbacks.Unfreezer(*args, **kwargs)[source]

Inverse operation of Freezer.
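
A minimal sketch that pairs Unfreezer with Freezer (the parameter pattern and epoch are made up):

>>> from skorch.callbacks import Freezer, Unfreezer
>>> # freeze the embedding weights at the start of training,
>>> # then make them trainable again from epoch 5 onwards
>>> callbacks = [
...     Freezer('embedding*.weight'),
...     Unfreezer('embedding*.weight', at=5),
... ]
>>> net = Net(myModule, callbacks=callbacks)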

Methods

initialize() (Re-)Set the initial state of the callback.
on_batch_begin(net[, X, y, training]) Called at the beginning of each batch.
on_batch_end(net[, X, y, training]) Called at the end of each batch.
on_epoch_begin(net, **kwargs) Called at the beginning of each epoch.
on_epoch_end(net[, dataset_train, dataset_valid]) Called at the end of each epoch.
on_grad_computed(net, named_parameters[, X, …]) Called once per batch after gradients have been computed but before an update step was performed.
on_train_begin(net[, X, y]) Called at the beginning of training.
on_train_end(net[, X, y]) Called at the end of training.
filter_parameters  
get_params  
named_parameters  
set_params  
class skorch.callbacks.WandbLogger(wandb_run, save_model=True, keys_ignored=None)[source]

Logs best model and metrics to Weights & Biases

Use this callback to automatically log best trained model, all metrics from your net’s history, model topology and computer resources to Weights & Biases after each epoch.

Every file saved in wandb_run.dir is automatically logged to W&B servers.

See example run

Parameters:
wandb_run : wandb.wandb_run.Run

wandb Run used to log data.

save_model : bool (default=True)

Whether to save a checkpoint of the best model and upload it to your Run on W&B servers.

keys_ignored : str or list of str (default=None)

Key or list of keys that should not be logged to W&B. Note that in addition to the keys provided by the user, keys such as those starting with ‘event_’ or ending on ‘_best’ are ignored by default.

Examples

$ pip install wandb

>>> import wandb
>>> from skorch.callbacks import WandbLogger
>>> # Create a wandb Run
... wandb_run = wandb.init()
>>> # Alternative: Create a wandb Run without having a W&B account
... wandb_run = wandb.init(anonymous="allow")
>>> # Log hyper-parameters (optional)
... wandb_run.config.update({"learning rate": 1e-3, "batch size": 32})
>>> net = NeuralNet(..., callbacks=[WandbLogger(wandb_run)])
>>> net.fit(X, y)

Methods

initialize() (Re-)Set the initial state of the callback.
on_batch_begin(net[, X, y, training]) Called at the beginning of each batch.
on_batch_end(net[, X, y, training]) Called at the end of each batch.
on_epoch_begin(net[, dataset_train, …]) Called at the beginning of each epoch.
on_epoch_end(net, **kwargs) Log values from the last history step and save best model
on_grad_computed(net, named_parameters[, X, …]) Called once per batch after gradients have been computed but before an update step was performed.
on_train_begin(net, **kwargs) Log model topology and add a hook for gradients
on_train_end(net[, X, y]) Called at the end of training.
get_params  
set_params  
initialize()[source]

(Re-)Set the initial state of the callback. Use this e.g. if the callback tracks some state that should be reset when the model is re-initialized.

This method should return self.

on_epoch_end(net, **kwargs)[source]

Log values from the last history step and save best model

on_train_begin(net, **kwargs)[source]

Log model topology and add a hook for gradients

class skorch.callbacks.WarmRestartLR(optimizer, min_lr=1e-06, max_lr=0.05, base_period=10, period_mult=2, last_epoch=-1)[source]

Stochastic Gradient Descent with Warm Restarts (SGDR) scheduler.

This scheduler sets the learning rate of each parameter group according to the stochastic gradient descent with warm restarts (SGDR) policy. This policy simulates periodic warm restarts of SGD, where in each restart the learning rate is initialized to some value and is scheduled to decrease.

Parameters:
optimizer : torch.optim.Optimizer instance.

Optimizer algorithm.

min_lr : float or list of float (default=1e-6)

Minimum allowed learning rate during each period for all param groups (float) or each group (list).

max_lr : float or list of float (default=0.05)

Maximum allowed learning rate during each period for all param groups (float) or each group (list).

base_period : int (default=10)

Initial restart period to be multiplied at each restart.

period_mult : int (default=2)

Multiplicative factor to increase the period between restarts.

last_epoch : int (default=-1)

The index of the last valid epoch.

References

[1] Ilya Loshchilov and Frank Hutter (2017). “SGDR: Stochastic Gradient Descent with Warm Restarts”. ICLR. https://arxiv.org/pdf/1608.03983.pdf
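
This scheduler is typically used through the LRScheduler callback rather than directly; a minimal sketch (spelling out the default values):

>>> from skorch.callbacks import LRScheduler, WarmRestartLR
>>> sched = LRScheduler(policy=WarmRestartLR, min_lr=1e-6, max_lr=0.05,
...                     base_period=10, period_mult=2)
>>> net = NeuralNet(..., callbacks=[sched])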

Methods

get_last_lr() Return last computed learning rate by current scheduler.
load_state_dict(state_dict) Loads the schedulers state.
print_lr(is_verbose, group, lr[, epoch]) Display the current learning rate.
state_dict() Returns the state of the scheduler as a dict.
get_lr  
step  

skorch.classifier

NeuralNet subclasses for classification tasks.

class skorch.classifier.NeuralNetBinaryClassifier(module, *args, criterion=<class 'torch.nn.modules.loss.BCEWithLogitsLoss'>, train_split=<skorch.dataset.CVSplit object>, threshold=0.5, **kwargs)[source]

NeuralNet for binary classification tasks

Use this specifically if you have a binary classification task, with input data X and target y. y must be 1d.

In addition to the parameters listed below, there are parameters with specific prefixes that are handled separately. To illustrate this, here is an example:

>>> net = NeuralNet(
...    ...,
...    optimizer=torch.optim.SGD,
...    optimizer__momentum=0.95,
...)

This way, when optimizer is initialized, NeuralNet will take care of setting the momentum parameter to 0.95.

(Note that the double underscore notation in optimizer__momentum means that the parameter momentum should be set on the object optimizer. This is the same semantic as used by sklearn.)

Furthermore, this allows you to change those parameters later:

net.set_params(optimizer__momentum=0.99)

This can be useful when you want to change certain parameters using a callback, when using the net in an sklearn grid search, etc.

By default an EpochTimer, BatchScoring (for both training and validation datasets), and PrintLog callbacks are installed for the user’s convenience.
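
As an end-to-end sketch (the module and data are made up for illustration): the module returns one logit per sample with shape (batch_size,), as required by the default BCEWithLogitsLoss criterion described below, and y is cast to float32:

>>> import numpy as np
>>> from torch import nn
>>> from skorch import NeuralNetBinaryClassifier
>>>
>>> class MyBinaryModule(nn.Module):
...     def __init__(self):
...         super().__init__()
...         self.dense = nn.Linear(20, 1)
...     def forward(self, X):
...         # one logit per sample, shape (batch_size,)
...         return self.dense(X).squeeze(-1)
>>>
>>> X = np.random.randn(100, 20).astype(np.float32)
>>> y = (X[:, 0] > 0).astype(np.float32)
>>> net = NeuralNetBinaryClassifier(MyBinaryModule, max_epochs=5, lr=0.1)
>>> net.fit(X, y)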

Parameters:
module : torch module (class or instance)

A PyTorch Module. In general, the uninstantiated class should be passed, although instantiated modules will also work.

criterion : torch criterion (class, default=torch.nn.BCEWithLogitsLoss)

Binary cross-entropy loss with logits. Note that the module should return logits (not probabilities) with shape (batch_size, ).

threshold : float (default=0.5)

Probabilities above this threshold are classified as 1. threshold is used by predict and predict_proba for classification.

optimizer : torch optim (class, default=torch.optim.SGD)

The uninitialized optimizer (update rule) used to optimize the module

lr : float (default=0.01)

Learning rate passed to the optimizer. You may use lr instead of using optimizer__lr, which would result in the same outcome.

max_epochs : int (default=10)

The number of epochs to train for each fit call. Note that you may keyboard-interrupt training at any time.

batch_size : int (default=128)

Mini-batch size. Use this instead of setting iterator_train__batch_size and iterator_test__batch_size, which would result in the same outcome. If batch_size is -1, a single batch with all the data will be used during training and validation.

iterator_train : torch DataLoader

The default PyTorch DataLoader used for training data.

iterator_valid : torch DataLoader

The default PyTorch DataLoader used for validation and test data, i.e. during inference.

dataset : torch Dataset (default=skorch.dataset.Dataset)

The dataset is necessary for the incoming data to work with pytorch’s DataLoader. It has to implement the __len__ and __getitem__ methods. The provided dataset should be capable of dealing with a lot of data types out of the box, so only change this if your data is not supported. You should generally pass the uninitialized Dataset class and define additional arguments to X and y by prefixing them with dataset__. It is also possible to pass an initialized Dataset, in which case no additional arguments may be passed.

train_split : None or callable (default=skorch.dataset.CVSplit(5))

If None, there is no train/validation split. Else, train_split should be a function or callable that is called with X and y data and should return the tuple dataset_train, dataset_valid. The validation data may be None.

callbacks : None, “disable”, or list of Callback instances (default=None)

Which callbacks to enable. There are three possible values:

If callbacks=None, only use default callbacks, those returned by get_default_callbacks.

If callbacks="disable", disable all callbacks, i.e. do not run any of the callbacks.

If callbacks is a list of callbacks, use those callbacks in addition to the default callbacks. Each callback should be an instance of Callback.

Callback names are inferred from the class name. Name conflicts are resolved by appending a count suffix starting with 1, e.g. EpochScoring_1. Alternatively, a tuple (name, callback) can be passed, where name should be unique. Callbacks may or may not be instantiated. The callback name can be used to set parameters on specific callbacks (e.g., for the callback with name 'print_log', use net.set_params(callbacks__print_log__keys_ignored=['epoch', 'train_loss'])).

predict_nonlinearity : callable, None, or ‘auto’ (default=’auto’)

The nonlinearity to be applied to the prediction. When set to ‘auto’, infers the correct nonlinearity based on the criterion (softmax for CrossEntropyLoss and sigmoid for BCEWithLogitsLoss). If it cannot be inferred or if the parameter is None, just use the identity function. Don’t pass a lambda function if you want the net to be pickleable.

In case a callable is passed, it should accept the output of the module (the first output if there is more than one), which is a PyTorch tensor, and return the transformed PyTorch tensor.

This can be useful, e.g., when predict_proba() should return probabilities but a criterion is used that does not expect probabilities. In that case, the module can return whatever is required by the criterion and the predict_nonlinearity transforms this output into probabilities.

The nonlinearity is applied only when calling predict() or predict_proba() but not anywhere else – notably, the loss is unaffected by this nonlinearity.

warm_start : bool (default=False)

Whether each fit call should lead to a re-initialization of the module (cold start) or whether the module should be trained further (warm start).

verbose : int (default=1)

This parameter controls how much print output is generated by the net and its callbacks. By setting this value to 0, e.g. the summary scores at the end of each epoch are no longer printed. This can be useful when running a hyperparameter search. The summary scores are always logged in the history attribute, regardless of the verbose setting.

device : str, torch.device (default=’cpu’)

The compute device to be used. If set to ‘cuda’, data in torch tensors will be pushed to cuda tensors before being sent to the module. If set to None, then all compute devices will be left unmodified.

Attributes:
prefixes_ : list of str

Contains the prefixes to special parameters. E.g., since there is the 'module' prefix, it is possible to set parameters like so: NeuralNet(..., optimizer__momentum=0.95).

cuda_dependent_attributes_ : list of str

Contains a list of all attribute prefixes whose values depend on a CUDA device. If a NeuralNet trained with a CUDA-enabled device is unpickled on a machine without CUDA or with CUDA disabled, the listed attributes are mapped to CPU. Expand this list if you want to add other cuda-dependent attributes.

initialized_ : bool

Whether the NeuralNet was initialized.

module_ : torch module (instance)

The instantiated module.

criterion_ : torch criterion (instance)

The instantiated criterion.

callbacks_ : list of tuples

The complete (i.e. default and other), initialized callbacks, in a tuple with unique names.

Methods

check_data(X, y)
check_is_fitted([attributes]) Checks whether the net is initialized
evaluation_step(Xi[, training]) Perform a forward step to produce the output used for prediction and scoring.
fit(X, y, **fit_params) See NeuralNet.fit.
fit_loop(X[, y, epochs]) The proper fit loop.
forward(X[, training, device]) Gather and concatenate the output from forward call with input data.
forward_iter(X[, training, device]) Yield outputs of module forward calls on each batch of data.
get_dataset(X[, y]) Get a dataset that contains the input data and is passed to the iterator.
get_iterator(dataset[, training]) Get an iterator that allows to loop over the batches of the given data.
get_loss(y_pred, y_true[, X, training]) Return the loss for this batch.
get_params_for(prefix) Collect and return init parameters for an attribute.
get_params_for_optimizer(prefix, …) Collect and return init parameters for an optimizer.
get_split_datasets(X[, y]) Get internal train and validation datasets.
get_train_step_accumulator() Return the train step accumulator.
infer(x, **fit_params) Perform an inference step
initialize() Initializes all components of the NeuralNet and returns self.
initialize_callbacks() Initializes all callbacks and save the result in the callbacks_ attribute.
initialize_criterion() Initializes the criterion.
initialize_history() Initializes the history.
initialize_module() Initializes the module.
initialize_optimizer([triggered_directly]) Initialize the model optimizer.
load_params([f_params, f_optimizer, …]) Loads the module’s parameters, history, and optimizer, not the whole object.
notify(method_name, **cb_kwargs) Call the callback method specified in method_name with parameters specified in cb_kwargs.
on_batch_begin(net[, Xi, yi, training])
on_epoch_begin(net[, dataset_train, …])
on_epoch_end(net[, dataset_train, dataset_valid])
on_train_begin(net[, X, y])
on_train_end(net[, X, y])
partial_fit(X[, y, classes]) Fit the module.
predict(X) Where applicable, return class labels for samples in X.
predict_proba(X) Return the output of the module’s forward method as a numpy array.
run_single_epoch(dataset, training, prefix, …) Compute a single epoch of train or validation.
save_params([f_params, f_optimizer, …]) Saves the module’s parameters, history, and optimizer, not the whole object.
score(X, y[, sample_weight]) Return the mean accuracy on the given test data and labels.
set_params(**kwargs) Set the parameters of this class.
train_step(Xi, yi, **fit_params) Prepares a loss function callable and pass it to the optimizer, hence performing one optimization step.
train_step_single(Xi, yi, **fit_params) Compute y_pred, loss value, and update net’s gradients.
validation_step(Xi, yi, **fit_params) Perform a forward step using batched data and return the resulting loss.
get_default_callbacks  
get_params  
initialize_virtual_params  
on_batch_end  
on_grad_computed  
fit(X, y, **fit_params)[source]

See NeuralNet.fit.

In contrast to NeuralNet.fit, y is non-optional to avoid mistakenly forgetting about y. However, y can be set to None in case it is derived dynamically from X.

infer(x, **fit_params)[source]

Perform an inference step

The first output of the module must be a single array that has either shape (n,) or shape (n, 1). In the latter case, the output will be reshaped to become 1-dim.

predict(X)[source]

Where applicable, return class labels for samples in X.

If the module’s forward method returns multiple outputs as a tuple, it is assumed that the first output contains the relevant information and the other values are ignored. If all values are relevant, consider using forward() instead.

Parameters:
X : input data, compatible with skorch.dataset.Dataset

By default, you should be able to pass:

  • numpy arrays
  • torch tensors
  • pandas DataFrame or Series
  • scipy sparse CSR matrices
  • a dictionary of the former three
  • a list/tuple of the former three
  • a Dataset

If this doesn’t work with your data, you have to pass a Dataset that can deal with the data.

Returns:
y_pred : numpy ndarray
class skorch.classifier.NeuralNetClassifier(module, *args, criterion=<class 'torch.nn.modules.loss.NLLLoss'>, train_split=<skorch.dataset.CVSplit object>, classes=None, **kwargs)[source]

NeuralNet for classification tasks

Use this specifically if you have a standard classification task, with input data X and target y.

In addition to the parameters listed below, there are parameters with specific prefixes that are handled separately. To illustrate this, here is an example:

>>> net = NeuralNet(
...    ...,
...    optimizer=torch.optim.SGD,
...    optimizer__momentum=0.95,
...)

This way, when optimizer is initialized, NeuralNet will take care of setting the momentum parameter to 0.95.

(Note that the double underscore notation in optimizer__momentum means that the parameter momentum should be set on the object optimizer. This is the same semantic as used by sklearn.)

Furthermore, this allows you to change those parameters later:

net.set_params(optimizer__momentum=0.99)

This can be useful when you want to change certain parameters using a callback, when using the net in an sklearn grid search, etc.

By default an EpochTimer, BatchScoring (for both training and validation datasets), and PrintLog callbacks are installed for the user’s convenience.
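
As an end-to-end sketch (the module and data are made up for illustration): with the default NLLLoss criterion described below, the module is expected to return probabilities, e.g. by ending in a softmax:

>>> import numpy as np
>>> from torch import nn
>>> from skorch import NeuralNetClassifier
>>>
>>> class MyClassifierModule(nn.Module):
...     def __init__(self):
...         super().__init__()
...         self.dense = nn.Linear(20, 2)
...         self.softmax = nn.Softmax(dim=-1)
...     def forward(self, X):
...         # probabilities; the log is applied in get_loss
...         return self.softmax(self.dense(X))
>>>
>>> X = np.random.randn(100, 20).astype(np.float32)
>>> y = np.random.randint(0, 2, size=100).astype(np.int64)
>>> net = NeuralNetClassifier(MyClassifierModule, max_epochs=5, lr=0.1)
>>> net.fit(X, y)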

Parameters:
module : torch module (class or instance)

A PyTorch Module. In general, the uninstantiated class should be passed, although instantiated modules will also work.

criterion : torch criterion (class, default=torch.nn.NLLLoss)

Negative log likelihood loss. Note that the module should return probabilities; the log is applied during get_loss.

classes : None or list (default=None)

If None, the classes_ attribute will be inferred from the y data passed to fit. If a non-empty list is passed, that list will be returned as classes_. If the initial skorch behavior should be restored, i.e. raising an AttributeError, pass an empty list.

optimizer : torch optim (class, default=torch.optim.SGD)

The uninitialized optimizer (update rule) used to optimize the module

lr : float (default=0.01)

Learning rate passed to the optimizer. You may use lr instead of using optimizer__lr, which would result in the same outcome.

max_epochs : int (default=10)

The number of epochs to train for each fit call. Note that you may keyboard-interrupt training at any time.

batch_size : int (default=128)

Mini-batch size. Use this instead of setting iterator_train__batch_size and iterator_test__batch_size, which would result in the same outcome. If batch_size is -1, a single batch with all the data will be used during training and validation.

iterator_train : torch DataLoader

The default PyTorch DataLoader used for training data.

iterator_valid : torch DataLoader

The default PyTorch DataLoader used for validation and test data, i.e. during inference.

dataset : torch Dataset (default=skorch.dataset.Dataset)

The dataset is necessary for the incoming data to work with pytorch’s DataLoader. It has to implement the __len__ and __getitem__ methods. The provided dataset should be capable of dealing with a lot of data types out of the box, so only change this if your data is not supported. You should generally pass the uninitialized Dataset class and define additional arguments to X and y by prefixing them with dataset__. It is also possible to pass an initialized Dataset, in which case no additional arguments may be passed.

train_split : None or callable (default=skorch.dataset.CVSplit(5))

If None, there is no train/validation split. Else, train_split should be a function or callable that is called with X and y data and should return the tuple dataset_train, dataset_valid. The validation data may be None.

callbacks : None, “disable”, or list of Callback instances (default=None)

Which callbacks to enable. There are three possible values:

If callbacks=None, only use default callbacks, those returned by get_default_callbacks.

If callbacks="disable", disable all callbacks, i.e. do not run any of the callbacks.

If callbacks is a list of callbacks, use those callbacks in addition to the default callbacks. Each callback should be an instance of Callback.

Callback names are inferred from the class name. Name conflicts are resolved by appending a count suffix starting with 1, e.g. EpochScoring_1. Alternatively, a tuple (name, callback) can be passed, where name should be unique. Callbacks may or may not be instantiated. The callback name can be used to set parameters on specific callbacks (e.g., for the callback with name 'print_log', use net.set_params(callbacks__print_log__keys_ignored=['epoch', 'train_loss'])).

predict_nonlinearity : callable, None, or ‘auto’ (default=’auto’)

The nonlinearity to be applied to the prediction. When set to ‘auto’, infers the correct nonlinearity based on the criterion (softmax for CrossEntropyLoss and sigmoid for BCEWithLogitsLoss). If it cannot be inferred or if the parameter is None, just use the identity function. Don’t pass a lambda function if you want the net to be pickleable.

In case a callable is passed, it should accept the output of the module (the first output if there is more than one), which is a PyTorch tensor, and return the transformed PyTorch tensor.

This can be useful, e.g., when predict_proba() should return probabilities but a criterion is used that does not expect probabilities. In that case, the module can return whatever is required by the criterion and the predict_nonlinearity transforms this output into probabilities.

The nonlinearity is applied only when calling predict() or predict_proba() but not anywhere else – notably, the loss is unaffected by this nonlinearity.

warm_start : bool (default=False)

Whether each fit call should lead to a re-initialization of the module (cold start) or whether the module should be trained further (warm start).

verbose : int (default=1)

This parameter controls how much print output is generated by the net and its callbacks. By setting this value to 0, e.g. the summary scores at the end of each epoch are no longer printed. This can be useful when running a hyperparameter search. The summary scores are always logged in the history attribute, regardless of the verbose setting.

device : str, torch.device (default=’cpu’)

The compute device to be used. If set to ‘cuda’, data in torch tensors will be pushed to cuda tensors before being sent to the module. If set to None, then all compute devices will be left unmodified.

Attributes:
prefixes_ : list of str

Contains the prefixes to special parameters. E.g., since there is the 'module' prefix, it is possible to set parameters like so: NeuralNet(..., optimizer__momentum=0.95).

cuda_dependent_attributes_ : list of str

Contains a list of all attribute prefixes whose values depend on a CUDA device. If a NeuralNet trained with a CUDA-enabled device is unpickled on a machine without CUDA or with CUDA disabled, the listed attributes are mapped to CPU. Expand this list if you want to add other cuda-dependent attributes.

initialized_ : bool

Whether the NeuralNet was initialized.

module_ : torch module (instance)

The instantiated module.

criterion_ : torch criterion (instance)

The instantiated criterion.

callbacks_ : list of tuples

The complete (i.e. default and other), initialized callbacks, in a tuple with unique names.

classes_ : array, shape (n_classes, )

A list of class labels known to the classifier.

Methods

check_data(X, y)
check_is_fitted([attributes]) Checks whether the net is initialized
evaluation_step(Xi[, training]) Perform a forward step to produce the output used for prediction and scoring.
fit(X, y, **fit_params) See NeuralNet.fit.
fit_loop(X[, y, epochs]) The proper fit loop.
forward(X[, training, device]) Gather and concatenate the output from forward call with input data.
forward_iter(X[, training, device]) Yield outputs of module forward calls on each batch of data.
get_dataset(X[, y]) Get a dataset that contains the input data and is passed to the iterator.
get_iterator(dataset[, training]) Get an iterator that allows to loop over the batches of the given data.
get_loss(y_pred, y_true, *args, **kwargs) Return the loss for this batch.
get_params_for(prefix) Collect and return init parameters for an attribute.
get_params_for_optimizer(prefix, …) Collect and return init parameters for an optimizer.
get_split_datasets(X[, y]) Get internal train and validation datasets.
get_train_step_accumulator() Return the train step accumulator.
infer(x, **fit_params) Perform a single inference step on a batch of data.
initialize() Initializes all components of the NeuralNet and returns self.
initialize_callbacks() Initializes all callbacks and save the result in the callbacks_ attribute.
initialize_criterion() Initializes the criterion.
initialize_history() Initializes the history.
initialize_module() Initializes the module.
initialize_optimizer([triggered_directly]) Initialize the model optimizer.
load_params([f_params, f_optimizer, …]) Loads the module’s parameters, history, and optimizer, not the whole object.
notify(method_name, **cb_kwargs) Call the callback method specified in method_name with parameters specified in cb_kwargs.
on_batch_begin(net[, Xi, yi, training])
on_epoch_begin(net[, dataset_train, …])
on_epoch_end(net[, dataset_train, dataset_valid])
on_train_begin(net[, X, y])
on_train_end(net[, X, y])
partial_fit(X[, y, classes]) Fit the module.
predict(X) Where applicable, return class labels for samples in X.
predict_proba(X) Where applicable, return probability estimates for samples.
run_single_epoch(dataset, training, prefix, …) Compute a single epoch of train or validation.
save_params([f_params, f_optimizer, …]) Saves the module’s parameters, history, and optimizer, not the whole object.
score(X, y[, sample_weight]) Return the mean accuracy on the given test data and labels.
set_params(**kwargs) Set the parameters of this class.
train_step(Xi, yi, **fit_params) Prepares a loss function callable and pass it to the optimizer, hence performing one optimization step.
train_step_single(Xi, yi, **fit_params) Compute y_pred, loss value, and update net’s gradients.
validation_step(Xi, yi, **fit_params) Perform a forward step using batched data and return the resulting loss.
get_default_callbacks  
get_params  
initialize_virtual_params  
on_batch_end  
on_grad_computed  
fit(X, y, **fit_params)[source]

See NeuralNet.fit.

In contrast to NeuralNet.fit, y is non-optional to avoid mistakenly forgetting about y. However, y can be set to None in case it is derived dynamically from X.

get_loss(y_pred, y_true, *args, **kwargs)[source]

Return the loss for this batch.

Parameters:
y_pred : torch tensor

Predicted target values

y_true : torch tensor

True target values.

X : input data, compatible with skorch.dataset.Dataset

By default, you should be able to pass:

  • numpy arrays
  • torch tensors
  • pandas DataFrame or Series
  • scipy sparse CSR matrices
  • a dictionary of the former three
  • a list/tuple of the former three
  • a Dataset

If this doesn’t work with your data, you have to pass a Dataset that can deal with the data.

training : bool (default=False)

Whether train mode should be used or not.

predict(X)[source]

Where applicable, return class labels for samples in X.

If the module’s forward method returns multiple outputs as a tuple, it is assumed that the first output contains the relevant information and the other values are ignored. If all values are relevant, consider using forward() instead.

Parameters:
X : input data, compatible with skorch.dataset.Dataset

By default, you should be able to pass:

  • numpy arrays
  • torch tensors
  • pandas DataFrame or Series
  • scipy sparse CSR matrices
  • a dictionary of the former three
  • a list/tuple of the former three
  • a Dataset

If this doesn’t work with your data, you have to pass a Dataset that can deal with the data.

Returns:
y_pred : numpy ndarray
predict_proba(X)[source]

Where applicable, return probability estimates for samples.

If the module’s forward method returns multiple outputs as a tuple, it is assumed that the first output contains the relevant information and the other values are ignored. If all values are relevant, consider using forward() instead.

Parameters:
X : input data, compatible with skorch.dataset.Dataset

By default, you should be able to pass:

  • numpy arrays
  • torch tensors
  • pandas DataFrame or Series
  • scipy sparse CSR matrices
  • a dictionary of the former three
  • a list/tuple of the former three
  • a Dataset

If this doesn’t work with your data, you have to pass a Dataset that can deal with the data.

Returns:
y_proba : numpy ndarray

skorch.dataset

Contains custom skorch Dataset and CVSplit.

class skorch.dataset.CVSplit(cv=5, stratified=False, random_state=None)[source]

Class that performs the internal train/valid split on a dataset.

The cv argument here works similarly to the regular sklearn cv parameter in, e.g., GridSearchCV. However, instead of cycling through all splits, only one fixed split (the first one) is used. To get a full cycle through the splits, don’t use NeuralNet’s internal validation but instead the corresponding sklearn functions (e.g. cross_val_score).

We additionally support a float, similar to sklearn’s train_test_split.

Parameters:
cv : int, float, cross-validation generator or an iterable, optional

(Refer to sklearn’s User Guide on cross-validation for the various cross-validation strategies that can be used here.)

Determines the cross-validation splitting strategy. Possible inputs for cv are:

  • None, to use the default 3-fold cross validation,
  • integer, to specify the number of folds in a (Stratified)KFold,
  • float, to represent the proportion of the dataset to include in the validation split.
  • An object to be used as a cross-validation generator.
  • An iterable yielding train, validation splits.
stratified : bool (default=False)

Whether the split should be stratified. Only works if the target y is suited for binary or multiclass classification.

random_state : int, RandomState instance, or None (default=None)

Control the random state in case that (Stratified)ShuffleSplit is used (which is when a float is passed to cv). For more information, look at the sklearn documentation of (Stratified)ShuffleSplit.
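
Examples

For illustration, here is a minimal sketch of using CVSplit as the train_split argument of a NeuralNet; MyModule, X, and y are assumed to be defined elsewhere:

>>> from skorch import NeuralNetClassifier
>>> from skorch.dataset import CVSplit
>>> # hold out a stratified 20% validation split instead of the default 5-fold split
>>> net = NeuralNetClassifier(
...     MyModule,
...     train_split=CVSplit(0.2, stratified=True, random_state=0),
... )
>>> net.fit(X, y)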

Methods

__call__(dataset[, y, groups]) Call self as a function.
check_cv(y) Resolve which cross validation strategy is used.
check_cv(y)[source]

Resolve which cross validation strategy is used.

class skorch.dataset.Dataset(X, y=None, length=None)[source]

General dataset wrapper that can be used in conjunction with PyTorch DataLoader.

The dataset will always yield a tuple of two values, the first from the data (X) and the second from the target (y). However, the target is allowed to be None. In that case, Dataset will currently return a dummy tensor, since DataLoader does not work with Nones.

Dataset currently works with the following data types:

  • numpy arrays
  • PyTorch Tensors
  • scipy sparse CSR matrices
  • pandas NDFrame
  • a dictionary of the former three
  • a list/tuple of the former three
Parameters:
X : see above

Everything pertaining to the input data.

y : see above or None (default=None)

Everything pertaining to the target, if there is anything.

length : int or None (default=None)

If not None, determines the length (len) of the data. Should usually be left at None, in which case the length is determined by the data itself.
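
Examples

A minimal usage sketch (X and y stand for any of the supported data types listed above):

>>> from torch.utils.data import DataLoader
>>> from skorch.dataset import Dataset
>>> ds = Dataset(X, y)
>>> len(ds)           # length inferred from the data
>>> Xi, yi = ds[0]    # a single (input, target) pair
>>> loader = DataLoader(ds, batch_size=32)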

Methods

transform(X, y) Additional transformations on X and y.
transform(X, y)[source]

Additional transformations on X and y.

By default, they are cast to PyTorch Tensors. Override this if you want a different behavior.

Note: If you use this in conjunction with PyTorch DataLoader, the latter will call the dataset for each row separately, which means that the incoming X and y each are single rows.
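
As a hedged sketch, overriding transform in a subclass could look like this (the scaling shown is purely illustrative):

>>> class ScaledDataset(Dataset):
...     def transform(self, X, y):
...         # custom preprocessing before the default casting to tensors
...         X = X / 255.0
...         return super().transform(X, y)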

skorch.dataset.unpack_data(data)[source]

Unpack data returned by the net’s iterator into a 2-tuple.

If the wrong number of items is returned, raise a helpful error message.

skorch.dataset.uses_placeholder_y(ds)[source]

If ds is a skorch.dataset.Dataset or a skorch.dataset.Dataset nested inside a torch.utils.data.Subset and uses y as a placeholder, return True.

skorch.exceptions

Contains skorch-specific exceptions and warnings.

exception skorch.exceptions.DeviceWarning[source]

A problem with a device (e.g. CUDA) was detected.

exception skorch.exceptions.NotInitializedError[source]

Module is not initialized, please call the .initialize method or train the model by calling .fit(...).

exception skorch.exceptions.SkorchException[source]

Base skorch exception.

exception skorch.exceptions.SkorchWarning[source]

Base skorch warning.

skorch.helper

Helper functions and classes for users.

They should not be used in skorch directly.

class skorch.helper.DataFrameTransformer(treat_int_as_categorical=False, float_dtype=<class 'numpy.float32'>, int_dtype=<class 'numpy.int64'>)[source]

Transform a DataFrame into a dict useful for working with skorch.

Transforms cardinal data to floats and categorical data to vectors of ints so that they can be embedded.

Although skorch can deal with pandas DataFrames, the default behavior is often not very useful. Use this transformer to transform the DataFrame into a dict with all float columns concatenated using the key “X” and all categorical values encoded as integers, using their respective column names as keys.

Your module must have a matching signature for this to work. It must accept an argument X for all cardinal values. Additionally, for all categorical values, it must accept an argument with the same name as the corresponding column (see example below). If you need help with the required signature, use the describe_signature method of this class and pass it your data.

You can choose whether you want to treat int columns the same as float columns (default) or as categorical values.

To one-hot encode categorical features, initialize their corresponding embedding layers using the identity matrix.

Parameters:
treat_int_as_categorical : bool (default=False)

Whether to treat integers as categorical values or as cardinal values, i.e. the same as floats.

float_dtype : numpy dtype or None (default=np.float32)

The dtype to cast the cardinal values to. If None, don’t change them.

int_dtype : numpy dtype or None (default=np.int64)

The dtype to cast the categorical values to. If None, don’t change them. If you do this, it can happen that the categorical values will have different dtypes, reflecting the number of unique categories.

Notes

The value of X will always be 2-dimensional, even if it only contains 1 column.

Examples

>>> df = pd.DataFrame({
...     'col_floats': np.linspace(0, 1, 12),
...     'col_ints': [11, 11, 10] * 4,
...     'col_cats': ['a', 'b', 'a'] * 4,
... })
>>> # cast to category dtype to later learn embeddings
>>> df['col_cats'] = df['col_cats'].astype('category')
>>> y = np.asarray([0, 1, 0] * 4)
>>> class MyModule(nn.Module):
...     def __init__(self):
...         super().__init__()
...         self.reset_params()
...     def reset_params(self):
...         self.embedding = nn.Embedding(2, 10)
...         self.linear = nn.Linear(2, 10)
...         self.out = nn.Linear(20, 2)
...         self.nonlin = nn.Softmax(dim=-1)
...     def forward(self, X, col_cats):
...         # "X" contains the values from col_floats and col_ints
...         # "col_cats" contains the values from "col_cats"
...         X_lin = self.linear(X)
...         X_cat = self.embedding(col_cats)
...         X_concat = torch.cat((X_lin, X_cat), dim=1)
...         return self.nonlin(self.out(X_concat))
>>> net = NeuralNetClassifier(MyModule)
>>> pipe = Pipeline([
...     ('transform', DataFrameTransformer()),
...     ('net', net),
... ])
>>> pipe.fit(df, y)

Methods

describe_signature(df) Describe the signature required for the given data.
fit(df[, y])
fit_transform(X[, y]) Fit to data, then transform it.
get_params([deep]) Get parameters for this estimator.
set_params(**params) Set the parameters of this estimator.
transform(df) Transform DataFrame to become a dict that works well with skorch.
describe_signature(df)[source]

Describe the signature required for the given data.

Pass the DataFrame to receive a description of the signature required for the module’s forward method. The description consists of three parts:

  1. The names of the arguments that the forward method needs.
  2. The dtypes of the torch tensors passed to forward.
  3. The number of input units that are required for the corresponding argument. For the float parameter, this is just the number of dimensions of the tensor. For categorical parameters, it is the number of unique elements.

Returns:
signature : dict

Returns a dict with each key corresponding to one key required for the forward method. The values are dictionaries of two elements. The key “dtype” describes the torch dtype of the resulting tensor, the key “input_units” describes the required number of input units.
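
For instance, a minimal sketch using the DataFrame df from the example above:

>>> sig = DataFrameTransformer().describe_signature(df)
>>> # sig maps each forward argument (e.g. "X") to its "dtype" and "input_units"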

transform(df)[source]

Transform DataFrame to become a dict that works well with skorch.

Parameters:
df : pd.DataFrame

Incoming DataFrame.

Returns:
X_dict: dict

Dictionary with all floats concatenated using the key “X” and all categorical values encoded as integers, using their respective column names as keys.

class skorch.helper.SliceDataset(dataset, idx=0, indices=None)[source]

Helper class that wraps a torch dataset to make it work with sklearn.

Sometimes, sklearn will touch the input data, e.g. when splitting the data for a grid search. This will fail when the input data is a torch dataset. To prevent this, use this wrapper class for your dataset.

Note: This class will only return the X value by default (i.e. the first value returned by indexing the original dataset). Sklearn, and hence skorch, always require 2 values, X and y. Therefore, you still need to provide the y data separately.

Note: This class behaves similarly to a PyTorch Subset when it is indexed by a slice or numpy array: It will return another SliceDataset that references the subset instead of the actual values. Only when it is indexed by an int does it return the actual values. The reason for this is to avoid loading all data into memory when sklearn, for instance, creates a train/validation split on the dataset. Data will only be loaded in batches during the fit loop.

Parameters:
dataset : torch.utils.data.Dataset

A valid torch dataset.

idx : int (default=0)

Indicates which element of the dataset should be returned. Typically, the dataset returns both X and y values. SliceDataset can only return 1 value. If you want to get X, choose idx=0 (default), if you want y, choose idx=1.

indices : list, np.ndarray, or None (default=None)

If you only want to return a subset of the dataset, indicate which subset that is by passing this argument. Typically, this can be left to be None, which returns all the data. See also Subset.

Examples

>>> X = MyCustomDataset()
>>> search = GridSearchCV(net, params, ...)
>>> search.fit(X, y)  # raises error
>>> ds = SliceDataset(X)
>>> search.fit(ds, y)  # works
Attributes:
shape

Methods

count(value)
index(value, [start, [stop]]) Raises ValueError if the value is not present.
transform(data) Additional transformations on data.
transform(data)[source]

Additional transformations on data.

Note: If you use this in conjunction with PyTorch DataLoader, the latter will call the dataset for each row separately, which means that the incoming data is a single row.

class skorch.helper.SliceDict(**kwargs)[source]

Wrapper for Python dict that makes it sliceable across values.

Use this if your input data is a dictionary and you have problems with sklearn not being able to slice it. Wrap your dict with SliceDict and it should usually work.

Note:

  • SliceDict cannot be indexed by integers; if you want one row, say row 3, use [3:4].
  • SliceDict accepts numpy arrays and torch tensors as values.

Examples

>>> X = {'key0': val0, 'key1': val1}
>>> search = GridSearchCV(net, params, ...)
>>> search.fit(X, y)  # raises error
>>> Xs = SliceDict(key0=val0, key1=val1)  # or Xs = SliceDict(**X)
>>> search.fit(Xs, y)  # works
Attributes:
shape

Methods

clear()
copy()
fromkeys(*args, **kwargs) fromkeys method makes no sense with SliceDict and is thus not supported.
get(key[, default]) Return the value for key if key is in the dictionary, else default.
items()
keys()
pop(k[, d]) If key is not found, d is returned if given, otherwise KeyError is raised.
popitem() Remove and return a (key, value) pair as a 2-tuple; raise KeyError if D is empty.
setdefault(key[, default]) Insert key with a value of default if key is not in the dictionary.
update([E, ]**F) If E is present and has a .keys() method, then does: for k in E: D[k] = E[k]. If E is present and lacks a .keys() method, then does: for k, v in E: D[k] = v. In either case, this is followed by: for k in F: D[k] = F[k].
values()
copy() → a shallow copy of D[source]
fromkeys(*args, **kwargs)[source]

fromkeys method makes no sense with SliceDict and is thus not supported.

update([E, ]**F) → None. Update D from dict/iterable E and F.[source]

If E is present and has a .keys() method, then does: for k in E: D[k] = E[k] If E is present and lacks a .keys() method, then does: for k, v in E: D[k] = v In either case, this is followed by: for k in F: D[k] = F[k]

skorch.helper.predefined_split(dataset)[source]

Uses dataset for validation in NeuralNet.

Parameters:
dataset: torch Dataset

Validation dataset.

Examples

>>> valid_ds = skorch.Dataset(X, y)
>>> net = NeuralNet(..., train_split=predefined_split(valid_ds))

skorch.history

Contains history class and helper functions.

class skorch.history.History[source]

History contains the information about the training history of a NeuralNet, facilitating some of the more common tasks that occur during training.

When you want to log certain information during training (say, a particular score or the norm of the gradients), you should write them to the net’s history object.

It is basically a list of dicts, one for each epoch, where each epoch dict in turn contains a list of dicts, one for each batch. For convenience, it has enhanced slicing notation and some methods to write new items.

To access items from history, you may pass a tuple of up to four items:

  1. Slices along the epochs.
  2. Selects columns from history epochs, may be a single one or a tuple of column names.
  3. Slices along the batches.
  4. Selects columns from history batches, may be a single one or a tuple of column names.

You may use a combination of the four items.

If you select columns that are not present in all epochs/batches, only those epochs/batches are chosen that contain said columns. If this set is empty, a KeyError is raised.

Examples

>>> # ACCESSING ITEMS
>>> # history of a fitted neural net
>>> history = net.history
>>> # get current epoch, a dict
>>> history[-1]
>>> # get train losses from all epochs, a list of floats
>>> history[:, 'train_loss']
>>> # get train and valid losses from all epochs, a list of tuples
>>> history[:, ('train_loss', 'valid_loss')]
>>> # get current batches, a list of dicts
>>> history[-1, 'batches']
>>> # get latest batch, a dict
>>> history[-1, 'batches', -1]
>>> # get train losses from current batch, a list of floats
>>> history[-1, 'batches', :, 'train_loss']
>>> # get train and valid losses from current batch, a list of tuples
>>> history[-1, 'batches', :, ('train_loss', 'valid_loss')]
>>> # WRITING ITEMS
>>> # add new epoch row
>>> history.new_epoch()
>>> # add an entry to current epoch
>>> history.record('my-score', 123)
>>> # add a batch row to the current epoch
>>> history.new_batch()
>>> # add an entry to the current batch
>>> history.record_batch('my-batch-score', 456)
>>> # overwrite entry of current batch
>>> history.record_batch('my-batch-score', 789)

Methods

append(object) Append object to the end of the list.
clear() Remove all items from list.
copy() Return a shallow copy of the list.
count(value) Return number of occurrences of value.
extend(iterable) Extend list by appending elements from the iterable.
from_file(f) Load the history of a NeuralNet from a json file.
index(value[, start[, stop]]) Return first index of value.
insert(index, object) Insert object before index.
new_batch() Register a new batch row for the current epoch.
new_epoch() Register a new epoch row.
pop([index]) Remove and return item at index (default last).
record(attr, value) Add a new value to the given column for the current epoch.
record_batch(attr, value) Add a new value to the given column for the current batch.
remove(value) Remove first occurrence of value.
reverse() Reverse IN PLACE.
sort(*[, key, reverse]) Stable sort IN PLACE.
to_file(f) Saves the history as a json file.
to_list() Return history object as a list.
classmethod from_file(f)[source]

Load the history of a NeuralNet from a json file.

Parameters:
f : file-like object or str
new_batch()[source]

Register a new batch row for the current epoch.

new_epoch()[source]

Register a new epoch row.

record(attr, value)[source]

Add a new value to the given column for the current epoch.

record_batch(attr, value)[source]

Add a new value to the given column for the current batch.

to_file(f)[source]

Saves the history as a json file. In order to use this feature, the history must only contain JSON encodable Python data structures. Numpy and PyTorch types should not be in the history.

Parameters:
f : file-like object or str
to_list()[source]

Return history object as a list.
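
For instance, a minimal round trip through a file (the file name is arbitrary, and net is a fitted NeuralNet):

>>> net.history.to_file('history.json')
>>> hist = History.from_file('history.json')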

skorch.net

Neural net classes.

class skorch.net.NeuralNet(module, criterion, optimizer=<class 'torch.optim.sgd.SGD'>, lr=0.01, max_epochs=10, batch_size=128, iterator_train=<class 'torch.utils.data.dataloader.DataLoader'>, iterator_valid=<class 'torch.utils.data.dataloader.DataLoader'>, dataset=<class 'skorch.dataset.Dataset'>, train_split=<skorch.dataset.CVSplit object>, callbacks=None, predict_nonlinearity='auto', warm_start=False, verbose=1, device='cpu', **kwargs)[source]

NeuralNet base class.

The base class covers more generic cases. Depending on your use case, you might want to use NeuralNetClassifier or NeuralNetRegressor.

In addition to the parameters listed below, there are parameters with specific prefixes that are handled separately. To illustrate this, here is an example:

>>> net = NeuralNet(
...    ...,
...    optimizer=torch.optim.SGD,
...    optimizer__momentum=0.95,
...)

This way, when optimizer is initialized, NeuralNet will take care of setting the momentum parameter to 0.95.

(Note that the double underscore notation in optimizer__momentum means that the parameter momentum should be set on the object optimizer. This is the same semantic as used by sklearn.)

Furthermore, this allows to change those parameters later:

net.set_params(optimizer__momentum=0.99)

This can be useful when you want to change certain parameters using a callback, when using the net in an sklearn grid search, etc.

By default an EpochTimer, BatchScoring (for both training and validation datasets), and PrintLog callbacks are installed for the user’s convenience.

Parameters:
module : torch module (class or instance)

A PyTorch Module. In general, the uninstantiated class should be passed, although instantiated modules will also work.

criterion : torch criterion (class)

The uninitialized criterion (loss) used to optimize the module.

optimizer : torch optim (class, default=torch.optim.SGD)

The uninitialized optimizer (update rule) used to optimize the module

lr : float (default=0.01)

Learning rate passed to the optimizer. You may use lr instead of using optimizer__lr, which would result in the same outcome.

max_epochs : int (default=10)

The number of epochs to train for each fit call. Note that you may keyboard-interrupt training at any time.

batch_size : int (default=128)

Mini-batch size. Use this instead of setting iterator_train__batch_size and iterator_valid__batch_size, which would result in the same outcome. If batch_size is -1, a single batch with all the data will be used during training and validation.

iterator_train : torch DataLoader

The default PyTorch DataLoader used for training data.

iterator_valid : torch DataLoader

The default PyTorch DataLoader used for validation and test data, i.e. during inference.

dataset : torch Dataset (default=skorch.dataset.Dataset)

The dataset is necessary for the incoming data to work with pytorch’s DataLoader. It has to implement the __len__ and __getitem__ methods. The provided dataset should be capable of dealing with a lot of data types out of the box, so only change this if your data is not supported. You should generally pass the uninitialized Dataset class and define additional arguments to X and y by prefixing them with dataset__. It is also possible to pass an initialized Dataset, in which case no additional arguments may be passed.

train_split : None or callable (default=skorch.dataset.CVSplit(5))

If None, there is no train/validation split. Else, train_split should be a function or callable that is called with X and y data and should return the tuple dataset_train, dataset_valid. The validation data may be None.

callbacks : None, “disable”, or list of Callback instances (default=None)

Which callbacks to enable. There are three possible values:

If callbacks=None, only use default callbacks, those returned by get_default_callbacks.

If callbacks="disable", disable all callbacks, i.e. do not run any of the callbacks.

If callbacks is a list of callbacks, use those callbacks in addition to the default callbacks. Each callback should be an instance of Callback.

Callback names are inferred from the class name. Name conflicts are resolved by appending a count suffix starting with 1, e.g. EpochScoring_1. Alternatively, a tuple (name, callback) can be passed, where name should be unique. Callbacks may or may not be instantiated. The callback name can be used to set parameters on specific callbacks (e.g., for the callback with name 'print_log', use net.set_params(callbacks__print_log__keys_ignored=['epoch', 'train_loss'])).

predict_nonlinearity : callable, None, or ‘auto’ (default=’auto’)

The nonlinearity to be applied to the prediction. When set to ‘auto’, infers the correct nonlinearity based on the criterion (softmax for CrossEntropyLoss and sigmoid for BCEWithLogitsLoss). If it cannot be inferred or if the parameter is None, just use the identity function. Don’t pass a lambda function if you want the net to be pickleable.

In case a callable is passed, it should accept the output of the module (the first output if there is more than one), which is a PyTorch tensor, and return the transformed PyTorch tensor.

This can be useful, e.g., when predict_proba() should return probabilities but a criterion is used that does not expect probabilities. In that case, the module can return whatever is required by the criterion and the predict_nonlinearity transforms this output into probabilities.

The nonlinearity is applied only when calling predict() or predict_proba() but not anywhere else – notably, the loss is unaffected by this nonlinearity.

warm_start : bool (default=False)

Whether each fit call should lead to a re-initialization of the module (cold start) or whether the module should be trained further (warm start).

verbose : int (default=1)

This parameter controls how much print output is generated by the net and its callbacks. For instance, setting it to 0 suppresses the summary scores printed at the end of each epoch, which can be useful when running a hyperparameter search. The summary scores are always logged in the history attribute, regardless of the verbose setting.

device : str, torch.device (default=’cpu’)

The compute device to be used. If set to ‘cuda’, data in torch tensors will be pushed to cuda tensors before being sent to the module. If set to None, then all compute devices will be left unmodified.

Attributes:
prefixes_ : list of str

Contains the prefixes to special parameters. E.g., since there is the 'module' prefix, it is possible to set parameters like so: NeuralNet(..., optimizer__momentum=0.95).

cuda_dependent_attributes_ : list of str

Contains a list of all attribute prefixes whose values depend on a CUDA device. If a NeuralNet trained with a CUDA-enabled device is unpickled on a machine without CUDA or with CUDA disabled, the listed attributes are mapped to CPU. Expand this list if you want to add other cuda-dependent attributes.

initialized_ : bool

Whether the NeuralNet was initialized.

module_ : torch module (instance)

The instantiated module.

criterion_ : torch criterion (instance)

The instantiated criterion.

callbacks_ : list of tuples

The complete (i.e. default and other), initialized callbacks, in a tuple with unique names.

Methods

check_is_fitted([attributes]) Checks whether the net is initialized
evaluation_step(Xi[, training]) Perform a forward step to produce the output used for prediction and scoring.
fit(X[, y]) Initialize and fit the module.
fit_loop(X[, y, epochs]) The proper fit loop.
forward(X[, training, device]) Gather and concatenate the output from forward call with input data.
forward_iter(X[, training, device]) Yield outputs of module forward calls on each batch of data.
get_dataset(X[, y]) Get a dataset that contains the input data and is passed to the iterator.
get_iterator(dataset[, training]) Get an iterator that allows to loop over the batches of the given data.
get_loss(y_pred, y_true[, X, training]) Return the loss for this batch.
get_params_for(prefix) Collect and return init parameters for an attribute.
get_params_for_optimizer(prefix, …) Collect and return init parameters for an optimizer.
get_split_datasets(X[, y]) Get internal train and validation datasets.
get_train_step_accumulator() Return the train step accumulator.
infer(x, **fit_params) Perform a single inference step on a batch of data.
initialize() Initializes all components of the NeuralNet and returns self.
initialize_callbacks() Initializes all callbacks and save the result in the callbacks_ attribute.
initialize_criterion() Initializes the criterion.
initialize_history() Initializes the history.
initialize_module() Initializes the module.
initialize_optimizer([triggered_directly]) Initialize the model optimizer.
load_params([f_params, f_optimizer, …]) Loads the module’s parameters, history, and optimizer, not the whole object.
notify(method_name, **cb_kwargs) Call the callback method specified in method_name with parameters specified in cb_kwargs.
on_batch_begin(net[, Xi, yi, training])
on_epoch_begin(net[, dataset_train, …])
on_epoch_end(net[, dataset_train, dataset_valid])
on_train_begin(net[, X, y])
on_train_end(net[, X, y])
partial_fit(X[, y, classes]) Fit the module.
predict(X) Where applicable, return class labels for samples in X.
predict_proba(X) Return the output of the module’s forward method as a numpy array.
run_single_epoch(dataset, training, prefix, …) Compute a single epoch of train or validation.
save_params([f_params, f_optimizer, …]) Saves the module’s parameters, history, and optimizer, not the whole object.
set_params(**kwargs) Set the parameters of this class.
train_step(Xi, yi, **fit_params) Prepares a loss function callable and passes it to the optimizer, hence performing one optimization step.
train_step_single(Xi, yi, **fit_params) Compute y_pred, loss value, and update net’s gradients.
validation_step(Xi, yi, **fit_params) Perform a forward step using batched data and return the resulting loss.
check_data  
get_default_callbacks  
get_params  
initialize_virtual_params  
on_batch_end  
on_grad_computed  
check_is_fitted(attributes=None, *args, **kwargs)[source]

Checks whether the net is initialized

Parameters:
attributes : iterable of str or None (default=None)

All the attributes that are strictly required of a fitted net. By default, this is the module_ attribute.

Other arguments as in sklearn.utils.validation.check_is_fitted.
Raises:
skorch.exceptions.NotInitializedError

When the given attributes are not present.

evaluation_step(Xi, training=False)[source]

Perform a forward step to produce the output used for prediction and scoring.

Therefore, the module is set to evaluation mode beforehand by default; this can be overridden to re-enable features like dropout by setting training=True.

fit(X, y=None, **fit_params)[source]

Initialize and fit the module.

If the module was already initialized, by calling fit, the module will be re-initialized (unless warm_start is True).

Parameters:
X : input data, compatible with skorch.dataset.Dataset

By default, you should be able to pass:

  • numpy arrays
  • torch tensors
  • pandas DataFrame or Series
  • scipy sparse CSR matrices
  • a dictionary of the former three
  • a list/tuple of the former three
  • a Dataset

If this doesn’t work with your data, you have to pass a Dataset that can deal with the data.

y : target data, compatible with skorch.dataset.Dataset

The same data types as for X are supported. If your X is a Dataset that contains the target, y may be set to None.

**fit_params : dict

Additional parameters passed to the forward method of the module and to the self.train_split call.

fit_loop(X, y=None, epochs=None, **fit_params)[source]

The proper fit loop.

Contains the logic of what actually happens during the fit loop.

Parameters:
X : input data, compatible with skorch.dataset.Dataset

By default, you should be able to pass:

  • numpy arrays
  • torch tensors
  • pandas DataFrame or Series
  • scipy sparse CSR matrices
  • a dictionary of the former three
  • a list/tuple of the former three
  • a Dataset

If this doesn’t work with your data, you have to pass a Dataset that can deal with the data.

y : target data, compatible with skorch.dataset.Dataset

The same data types as for X are supported. If your X is a Dataset that contains the target, y may be set to None.

epochs : int or None (default=None)

If int, train for this number of epochs; if None, use self.max_epochs.

**fit_params : dict

Additional parameters passed to the forward method of the module and to the self.train_split call.

forward(X, training=False, device='cpu')[source]

Gather and concatenate the output from forward call with input data.

The outputs from self.module_.forward are gathered on the compute device specified by device and then concatenated using PyTorch cat(). If multiple outputs are returned by self.module_.forward, each one of them must be able to be concatenated this way.

Parameters:
X : input data, compatible with skorch.dataset.Dataset

By default, you should be able to pass:

  • numpy arrays
  • torch tensors
  • pandas DataFrame or Series
  • scipy sparse CSR matrices
  • a dictionary of the former three
  • a list/tuple of the former three
  • a Dataset

If this doesn’t work with your data, you have to pass a Dataset that can deal with the data.

training : bool (default=False)

Whether to set the module to train mode or not.

device : string (default=’cpu’)

The device to store each inference result on. This defaults to CPU memory since there is generally more memory available there. For performance reasons this might be changed to a specific CUDA device, e.g. ‘cuda:0’.

Returns:
y_infer : torch tensor

The result from the forward step.

forward_iter(X, training=False, device='cpu')[source]

Yield outputs of module forward calls on each batch of data. The storage device of the yielded tensors is determined by the device parameter.

Parameters:
X : input data, compatible with skorch.dataset.Dataset

By default, you should be able to pass:

  • numpy arrays
  • torch tensors
  • pandas DataFrame or Series
  • scipy sparse CSR matrices
  • a dictionary of the former three
  • a list/tuple of the former three
  • a Dataset

If this doesn’t work with your data, you have to pass a Dataset that can deal with the data.

training : bool (default=False)

Whether to set the module to train mode or not.

device : string (default=’cpu’)

The device to store each inference result on. This defaults to CPU memory since there is generally more memory available there. For performance reasons this might be changed to a specific CUDA device, e.g. ‘cuda:0’.

Yields:
yp : torch tensor

Result from a forward call on an individual batch.
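
A minimal sketch of consuming the generator (net is assumed to be a fitted NeuralNet):

>>> for yp in net.forward_iter(X):
...     print(yp.shape)  # one tensor per batch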

get_dataset(X, y=None)[source]

Get a dataset that contains the input data and is passed to the iterator.

Override this if you want to initialize your dataset differently.

Parameters:
X : input data, compatible with skorch.dataset.Dataset

By default, you should be able to pass:

  • numpy arrays
  • torch tensors
  • pandas DataFrame or Series
  • scipy sparse CSR matrices
  • a dictionary of the former three
  • a list/tuple of the former three
  • a Dataset

If this doesn’t work with your data, you have to pass a Dataset that can deal with the data.

y : target data, compatible with skorch.dataset.Dataset

The same data types as for X are supported. If your X is a Dataset that contains the target, y may be set to None.

Returns:
dataset

The initialized dataset.

get_iterator(dataset, training=False)[source]

Get an iterator that allows to loop over the batches of the given data.

If self.iterator_train__batch_size and/or self.iterator_valid__batch_size are not set, use self.batch_size instead.

Parameters:
dataset : torch Dataset (default=skorch.dataset.Dataset)

Usually, self.dataset, initialized with the corresponding data, is passed to get_iterator.

training : bool (default=False)

Whether to use iterator_train or iterator_valid.

Returns:
iterator

An instantiated iterator that allows to loop over the mini-batches.

get_loss(y_pred, y_true, X=None, training=False)[source]

Return the loss for this batch.

Parameters:
y_pred : torch tensor

Predicted target values

y_true : torch tensor

True target values.

X : input data, compatible with skorch.dataset.Dataset

By default, you should be able to pass:

  • numpy arrays
  • torch tensors
  • pandas DataFrame or Series
  • scipy sparse CSR matrices
  • a dictionary of the former three
  • a list/tuple of the former three
  • a Dataset

If this doesn’t work with your data, you have to pass a Dataset that can deal with the data.

training : bool (default=False)

Whether train mode should be used or not.

get_params_for(prefix)[source]

Collect and return init parameters for an attribute.

Attributes could be, for instance, pytorch modules, criteria, or data loaders (for optimizers, use get_params_for_optimizer() instead). Use the returned arguments to initialize the given attribute like this:

# inside initialize_module method
kwargs = self.get_params_for('module')
self.module_ = self.module(**kwargs)

Proceed analogously for the criterion etc.

The reason to use this method is so that it’s possible to change the init parameters with set_params(), which in turn makes grid search and other similar things work.

Note that in general, as a user, you never have to deal with this method because initialize_module() etc. are already taking care of this. You only need to deal with this if you override initialize_module() (or similar methods) because you have some custom code that requires it.
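
As a hedged sketch (not skorch’s actual implementation), a custom initialize_module() using this method could look like this:

# inside a NeuralNet subclass
def initialize_module(self):
    kwargs = self.get_params_for('module')
    module = self.module(**kwargs)
    # moving the module to self.device is an assumption made for this sketch
    self.module_ = module.to(self.device)
    return self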

Parameters:
prefix : str

The name of the attribute whose arguments should be returned. E.g. for the module, it should be 'module'.

Returns:
kwargs : dict

Keyword arguments to be used as init parameters.

get_params_for_optimizer(prefix, named_parameters)[source]

Collect and return init parameters for an optimizer.

Parse kwargs configuration for the optimizer identified by the given prefix. Supports param group assignment using wildcards:

optimizer__lr=0.05,
optimizer__param_groups=[
    ('rnn*.period', {'lr': 0.3, 'momentum': 0}),
    ('rnn0', {'lr': 0.1}),
]

Generally, use this method like this:

# inside initialize_optimizer method
named_params = self.module_.named_parameters()
pgroups, kwargs = self.get_params_for_optimizer('optimizer', named_params)
if 'lr' not in kwargs:
    kwargs['lr'] = self.lr
self.optimizer_ = self.optimizer(*pgroups, **kwargs)

The reason to use this method is so that it’s possible to change the init parameters with set_params(), which in turn makes grid search and other similar things work.

Note that in general, as a user, you never have to deal with this method because initialize_optimizer() is already taking care of this. You only need to deal with this if you override initialize_optimizer() because you have some custom code that requires it.

Parameters:
prefix : str

The name of the optimizer whose arguments should be returned. Typically, this should just be 'optimizer'. There can be exceptions, however, e.g. if you want to use more than one optimizer.

named_parameters : iterator

Iterator over the parameters of the module that is intended to be optimized. It’s the return value of my_module.named_parameters().

Returns:
args : tuple

All positional arguments for this optimizer (right now only one, the parameter groups).

kwargs : dict

All other parameters for this optimizer, e.g. the learning rate.

get_split_datasets(X, y=None, **fit_params)[source]

Get internal train and validation datasets.

The validation dataset can be None if self.train_split is set to None; then internal validation will be skipped.

Override this if you want to change how the net splits incoming data into train and validation part.

Parameters:
X : input data, compatible with skorch.dataset.Dataset

By default, you should be able to pass:

  • numpy arrays
  • torch tensors
  • pandas DataFrame or Series
  • scipy sparse CSR matrices
  • a dictionary of the former three
  • a list/tuple of the former three
  • a Dataset

If this doesn’t work with your data, you have to pass a Dataset that can deal with the data.

y : target data, compatible with skorch.dataset.Dataset

The same data types as for X are supported. If your X is a Dataset that contains the target, y may be set to None.

**fit_params : dict

Additional parameters passed to the self.train_split call.

Returns:
dataset_train

The initialized training dataset.

dataset_valid

The initialized validation dataset or None

get_train_step_accumulator()[source]

Return the train step accumulator.

By default, the accumulator stores and retrieves the first value from the optimizer call. Most optimizers make only one call, so the first value is at the same time the only value.

In case of some optimizers, e.g. LBFGS, train_step_calc_gradient is called multiple times, as the loss function is evaluated multiple times per optimizer call. If you don’t want to return the first value in that case, override this method to return your custom accumulator.
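
As a hedged sketch, assuming the accumulator interface consists of store_step and get_step (as used by the default accumulator), a custom accumulator that keeps the last step instead of the first could look like this:

# illustrative only; keeps the last stored step instead of the first one
class LastStepAccumulator:
    def __init__(self):
        self.step = None

    def store_step(self, step):
        self.step = step

    def get_step(self):
        return self.step

# inside a NeuralNet subclass
def get_train_step_accumulator(self):
    return LastStepAccumulator()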

infer(x, **fit_params)[source]

Perform a single inference step on a batch of data.

Parameters:
x : input data

A batch of the input data.

**fit_params : dict

Additional parameters passed to the forward method of the module and to the self.train_split call.

initialize()[source]

Initializes all components of the NeuralNet and returns self.

initialize_callbacks()[source]

Initializes all callbacks and save the result in the callbacks_ attribute.

Both default_callbacks and callbacks are used (in that order). Callbacks may either be initialized or not, and if they don’t have a name, the name is inferred from the class name. The initialize method is called on all callbacks.

The final result will be a list of tuples, where each tuple consists of a name and an initialized callback. If names are not unique, a ValueError is raised.

initialize_criterion()[source]

Initializes the criterion.

initialize_history()[source]

Initializes the history.

initialize_module()[source]

Initializes the module.

Note that if the module has learned parameters, those will be reset.

initialize_optimizer(triggered_directly=True)[source]

Initialize the model optimizer. If self.optimizer__lr is not set, use self.lr instead.

Parameters:
triggered_directly : bool (default=True)

Only relevant when optimizer is re-initialized. Initialization of the optimizer can be triggered directly (e.g. when lr was changed) or indirectly (e.g. when the module was re-initialized). If and only if the former happens, the user should receive a message informing them about the parameters that caused the re-initialization.

load_params(f_params=None, f_optimizer=None, f_criterion=None, f_history=None, checkpoint=None, **kwargs)[source]

Loads the module’s parameters, history, and optimizer, not the whole object.

To save and load the whole object, use pickle.

f_params, f_optimizer, etc. use PyTorch’s load().

If you’ve created a custom module, e.g. net.mymodule_, you can save that as well by passing f_mymodule.

Parameters:
f_params : file-like object, str, None (default=None)

Path of module parameters. Pass None to not load.

f_optimizer : file-like object, str, None (default=None)

Path of optimizer. Pass None to not load.

f_criterion : file-like object, str, None (default=None)

Path of criterion. Pass None to not load.

f_history : file-like object, str, None (default=None)

Path to history. Pass None to not load.

checkpoint : Checkpoint, None (default=None)

Checkpoint to load params from. If both a checkpoint and an f_* path are passed in, the f_* path will be loaded. Pass None to not load.

Examples

>>> before = NeuralNetClassifier(mymodule)
>>> before.save_params(f_params='model.pkl',
...                    f_optimizer='optimizer.pkl',
...                    f_history='history.json')
>>> after = NeuralNetClassifier(mymodule).initialize()
>>> after.load_params(f_params='model.pkl',
...                   f_optimizer='optimizer.pkl',
...                   f_history='history.json')
notify(method_name, **cb_kwargs)[source]

Call the callback method specified in method_name with parameters specified in cb_kwargs.

Method names can be one of:

  • on_train_begin
  • on_train_end
  • on_epoch_begin
  • on_epoch_end
  • on_batch_begin
  • on_batch_end

partial_fit(X, y=None, classes=None, **fit_params)[source]

Fit the module.

If the module is initialized, it is not re-initialized, which means that this method should be used if you want to continue training a model (warm start).

Parameters:
X : input data, compatible with skorch.dataset.Dataset

By default, you should be able to pass:

  • numpy arrays
  • torch tensors
  • pandas DataFrame or Series
  • scipy sparse CSR matrices
  • a dictionary of the former three
  • a list/tuple of the former three
  • a Dataset

If this doesn’t work with your data, you have to pass a Dataset that can deal with the data.

y : target data, compatible with skorch.dataset.Dataset

The same data types as for X are supported. If your X is a Dataset that contains the target, y may be set to None.

classes : array, shape (n_classes,)

Solely for sklearn compatibility, currently unused.

**fit_params : dict

Additional parameters passed to the forward method of the module and to the self.train_split call.

predict(X)[source]

Where applicable, return class labels for samples in X.

If the module’s forward method returns multiple outputs as a tuple, it is assumed that the first output contains the relevant information and the other values are ignored. If all values are relevant, consider using forward() instead.

Parameters:
X : input data, compatible with skorch.dataset.Dataset

By default, you should be able to pass:

  • numpy arrays
  • torch tensors
  • pandas DataFrame or Series
  • scipy sparse CSR matrices
  • a dictionary of the former three
  • a list/tuple of the former three
  • a Dataset

If this doesn’t work with your data, you have to pass a Dataset that can deal with the data.

Returns:
y_pred : numpy ndarray
predict_proba(X)[source]

Return the output of the module’s forward method as a numpy array.

If the module’s forward method returns multiple outputs as a tuple, it is assumed that the first output contains the relevant information and the other values are ignored. If all values are relevant, consider using forward() instead.

Parameters:
X : input data, compatible with skorch.dataset.Dataset

By default, you should be able to pass:

  • numpy arrays
  • torch tensors
  • pandas DataFrame or Series
  • scipy sparse CSR matrices
  • a dictionary of the former three
  • a list/tuple of the former three
  • a Dataset

If this doesn’t work with your data, you have to pass a Dataset that can deal with the data.

Returns:
y_proba : numpy ndarray
run_single_epoch(dataset, training, prefix, step_fn, **fit_params)[source]

Compute a single epoch of train or validation.

Parameters:
dataset : torch Dataset

The initialized dataset to loop over.

training : bool

Whether to set the module to train mode or not.

prefix : str

Prefix to use when saving to the history.

step_fn : callable

Function to call for each batch.

**fit_params : dict

Additional parameters passed to the step_fn.

save_params(f_params=None, f_optimizer=None, f_criterion=None, f_history=None, **kwargs)[source]

Saves the module’s parameters, history, and optimizer, not the whole object.

To save the whole object, use pickle. This is necessary when you need additional learned attributes on the net, e.g. the classes_ attribute on skorch.classifier.NeuralNetClassifier.

f_params, f_optimizer, etc. use PyTorch’s save().

If you’ve created a custom module, e.g. net.mymodule_, you can save that as well by passing f_mymodule.

Parameters:
f_params : file-like object, str, None (default=None)

Path of module parameters. Pass None to not save

f_optimizer : file-like object, str, None (default=None)

Path of optimizer. Pass None to not save

f_criterion : file-like object, str, None (default=None)

Path of criterion. Pass None to not save

f_history : file-like object, str, None (default=None)

Path to history. Pass None to not save

Examples

>>> before = NeuralNetClassifier(mymodule)
>>> before.save_params(f_params='model.pkl',
...                    f_optimizer='optimizer.pkl',
...                    f_history='history.json')
>>> after = NeuralNetClassifier(mymodule).initialize()
>>> after.load_params(f_params='model.pkl',
...                   f_optimizer='optimizer.pkl',
...                   f_history='history.json')
set_params(**kwargs)[source]

Set the parameters of this class.

Valid parameter keys can be listed with get_params().

Returns:
self
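
For example, a minimal sketch (net is an initialized NeuralNet and the values are arbitrary):

>>> net.set_params(lr=0.05, max_epochs=20)
>>> net.set_params(optimizer__momentum=0.9)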
train_step(Xi, yi, **fit_params)[source]

Prepares a loss function callable and passes it to the optimizer, hence performing one optimization step.

Loss function callable as required by some optimizers (and accepted by all of them): https://pytorch.org/docs/master/optim.html#optimizer-step-closure

The module is set to be in train mode (e.g. dropout is applied).

Parameters:
Xi : input data

A batch of the input data.

yi : target data

A batch of the target data.

**fit_params : dict

Additional parameters passed to the forward method of the module and to the train_split call.

train_step_single(Xi, yi, **fit_params)[source]

Compute y_pred, loss value, and update net’s gradients.

The module is set to be in train mode (e.g. dropout is applied).

Parameters:
Xi : input data

A batch of the input data.

yi : target data

A batch of the target data.

**fit_params : dict

Additional parameters passed to the forward method of the module and to the self.train_split call.

validation_step(Xi, yi, **fit_params)[source]

Perform a forward step using batched data and return the resulting loss.

The module is set to be in evaluation mode (e.g. dropout is not applied).

Parameters:
Xi : input data

A batch of the input data.

yi : target data

A batch of the target data.

**fit_params : dict

Additional parameters passed to the forward method of the module and to the self.train_split call.

skorch.regressor

NeuralNet subclasses for regression tasks.

class skorch.regressor.NeuralNetRegressor(module, *args, criterion=<class 'torch.nn.modules.loss.MSELoss'>, **kwargs)[source]

NeuralNet for regression tasks

Use this specifically if you have a standard regression task, with input data X and target y. y must be 2d.

In addition to the parameters listed below, there are parameters with specific prefixes that are handled separately. To illustrate this, here is an example:

>>> net = NeuralNet(
...    ...,
...    optimizer=torch.optim.SGD,
...    optimizer__momentum=0.95,
...)

This way, when optimizer is initialized, NeuralNet will take care of setting the momentum parameter to 0.95.

(Note that the double underscore notation in optimizer__momentum means that the parameter momentum should be set on the object optimizer. This is the same semantic as used by sklearn.)

Furthermore, this allows to change those parameters later:

net.set_params(optimizer__momentum=0.99)

This can be useful when you want to change certain parameters using a callback, when using the net in an sklearn grid search, etc.

By default an EpochTimer, BatchScoring (for both training and validation datasets), and PrintLog callbacks are installed for the user’s convenience.

Parameters:
module : torch module (class or instance)

A PyTorch Module. In general, the uninstantiated class should be passed, although instantiated modules will also work.

criterion : torch criterion (class, default=torch.nn.MSELoss)

Mean squared error loss.

optimizer : torch optim (class, default=torch.optim.SGD)

The uninitialized optimizer (update rule) used to optimize the module

lr : float (default=0.01)

Learning rate passed to the optimizer. You may use lr instead of using optimizer__lr, which would result in the same outcome.

max_epochs : int (default=10)

The number of epochs to train for each fit call. Note that you may keyboard-interrupt training at any time.

batch_size : int (default=128)

Mini-batch size. Use this instead of setting iterator_train__batch_size and iterator_valid__batch_size, which would result in the same outcome. If batch_size is -1, a single batch with all the data will be used during training and validation.

iterator_train : torch DataLoader

The default PyTorch DataLoader used for training data.

iterator_valid : torch DataLoader

The default PyTorch DataLoader used for validation and test data, i.e. during inference.

dataset : torch Dataset (default=skorch.dataset.Dataset)

The dataset is necessary for the incoming data to work with pytorch’s DataLoader. It has to implement the __len__ and __getitem__ methods. The provided dataset should be capable of dealing with a lot of data types out of the box, so only change this if your data is not supported. You should generally pass the uninitialized Dataset class and define additional arguments to X and y by prefixing them with dataset__. It is also possible to pass an initialized Dataset, in which case no additional arguments may be passed.

train_split : None or callable (default=skorch.dataset.CVSplit(5))

If None, there is no train/validation split. Else, train_split should be a function or callable that is called with X and y data and should return the tuple dataset_train, dataset_valid. The validation data may be None.

callbacks : None, “disable”, or list of Callback instances (default=None)

Which callbacks to enable. There are three possible values:

If callbacks=None, only use default callbacks, those returned by get_default_callbacks.

If callbacks="disable", disable all callbacks, i.e. do not run any of the callbacks.

If callbacks is a list of callbacks, use those callbacks in addition to the default callbacks. Each callback should be an instance of Callback.

Callback names are inferred from the class name. Name conflicts are resolved by appending a count suffix starting with 1, e.g. EpochScoring_1. Alternatively, a tuple (name, callback) can be passed, where name should be unique. Callbacks may or may not be instantiated. The callback name can be used to set parameters on specific callbacks (e.g., for the callback with name 'print_log', use net.set_params(callbacks__print_log__keys_ignored=['epoch', 'train_loss'])).

predict_nonlinearity : callable, None, or ‘auto’ (default=’auto’)

The nonlinearity to be applied to the prediction. When set to ‘auto’, infers the correct nonlinearity based on the criterion (softmax for CrossEntropyLoss and sigmoid for BCEWithLogitsLoss). If it cannot be inferred or if the parameter is None, just use the identity function. Don’t pass a lambda function if you want the net to be pickleable.

In case a callable is passed, it should accept the output of the module (the first output if there is more than one), which is a PyTorch tensor, and return the transformed PyTorch tensor.

This can be useful, e.g., when predict_proba() should return probabilities but a criterion is used that does not expect probabilities. In that case, the module can return whatever is required by the criterion and the predict_nonlinearity transforms this output into probabilities.

The nonlinearity is applied only when calling predict() or predict_proba() but not anywhere else – notably, the loss is unaffected by this nonlinearity.

warm_start : bool (default=False)

Whether each fit call should lead to a re-initialization of the module (cold start) or whether the module should be trained further (warm start).

verbose : int (default=1)

This parameter controls how much print output is generated by the net and its callbacks. For instance, setting it to 0 suppresses the summary scores printed at the end of each epoch, which can be useful when running a hyperparameter search. The summary scores are always logged in the history attribute, regardless of the verbose setting.

device : str, torch.device (default=’cpu’)

The compute device to be used. If set to ‘cuda’, data in torch tensors will be pushed to cuda tensors before being sent to the module. If set to None, then all compute devices will be left unmodified.

Attributes:
prefixes_ : list of str

Contains the prefixes to special parameters. E.g., since there is the 'module' prefix, it is possible to set parameters like so: NeuralNet(..., optimizer__momentum=0.95).

cuda_dependent_attributes_ : list of str

Contains a list of all attribute prefixes whose values depend on a CUDA device. If a NeuralNet trained with a CUDA-enabled device is unpickled on a machine without CUDA or with CUDA disabled, the listed attributes are mapped to CPU. Expand this list if you want to add other cuda-dependent attributes.

initialized_ : bool

Whether the NeuralNet was initialized.

module_ : torch module (instance)

The instantiated module.

criterion_ : torch criterion (instance)

The instantiated criterion.

callbacks_ : list of tuples

The complete (i.e. default and other), initialized callbacks, in a tuple with unique names.

Methods

check_data(X, y)
check_is_fitted([attributes]) Checks whether the net is initialized
evaluation_step(Xi[, training]) Perform a forward step to produce the output used for prediction and scoring.
fit(X, y, **fit_params) See NeuralNet.fit.
fit_loop(X[, y, epochs]) The proper fit loop.
forward(X[, training, device]) Gather and concatenate the output from forward call with input data.
forward_iter(X[, training, device]) Yield outputs of module forward calls on each batch of data.
get_dataset(X[, y]) Get a dataset that contains the input data and is passed to the iterator.
get_iterator(dataset[, training]) Get an iterator that allows to loop over the batches of the given data.
get_loss(y_pred, y_true[, X, training]) Return the loss for this batch.
get_params_for(prefix) Collect and return init parameters for an attribute.
get_params_for_optimizer(prefix, …) Collect and return init parameters for an optimizer.
get_split_datasets(X[, y]) Get internal train and validation datasets.
get_train_step_accumulator() Return the train step accumulator.
infer(x, **fit_params) Perform a single inference step on a batch of data.
initialize() Initializes all components of the NeuralNet and returns self.
initialize_callbacks() Initializes all callbacks and save the result in the callbacks_ attribute.
initialize_criterion() Initializes the criterion.
initialize_history() Initializes the history.
initialize_module() Initializes the module.
initialize_optimizer([triggered_directly]) Initialize the model optimizer.
load_params([f_params, f_optimizer, …]) Loads the module’s parameters, history, and optimizer, not the whole object.
notify(method_name, **cb_kwargs) Call the callback method specified in method_name with parameters specified in cb_kwargs.
on_batch_begin(net[, Xi, yi, training])
on_epoch_begin(net[, dataset_train, …])
on_epoch_end(net[, dataset_train, dataset_valid])
on_train_begin(net[, X, y])
on_train_end(net[, X, y])
partial_fit(X[, y, classes]) Fit the module.
predict(X) Where applicable, return class labels for samples in X.
predict_proba(X) Return the output of the module’s forward method as a numpy array.
run_single_epoch(dataset, training, prefix, …) Compute a single epoch of train or validation.
save_params([f_params, f_optimizer, …]) Saves the module’s parameters, history, and optimizer, not the whole object.
score(X, y[, sample_weight]) Return the coefficient of determination R² of the prediction.
set_params(**kwargs) Set the parameters of this class.
train_step(Xi, yi, **fit_params) Prepares a loss function callable and passes it to the optimizer, hence performing one optimization step.
train_step_single(Xi, yi, **fit_params) Compute y_pred, loss value, and update net’s gradients.
validation_step(Xi, yi, **fit_params) Perform a forward step using batched data and return the resulting loss.
get_default_callbacks  
get_params  
initialize_virtual_params  
on_batch_end  
on_grad_computed  
fit(X, y, **fit_params)[source]

See NeuralNet.fit.

In contrast to NeuralNet.fit, y is non-optional to avoid mistakenly forgetting about y. However, y can be set to None in case it is derived dynamically from X.

skorch.scoring

skorch.scoring.loss_scoring(net: skorch.net.NeuralNet, X, y=None, sample_weight=None)[source]

Calculate score using the criterion of the net.

Use the exact same logic as during model training to calculate the score.

This function can be used to implement the score method for a NeuralNet through sub-classing. This is useful, for example, when combining skorch models with sklearn objects that rely on the model’s score method. For example:

>>> class ScoredNet(skorch.NeuralNetClassifier):
...     def score(self, X, y=None):
...         return loss_scoring(self, X, y)
Parameters:
net : skorch.NeuralNet

A fitted Skorch NeuralNet object.

X : input data, compatible with skorch.dataset.Dataset

By default, you should be able to pass:

  • numpy arrays
  • torch tensors
  • pandas DataFrame or Series
  • scipy sparse CSR matrices
  • a dictionary of the former three
  • a list/tuple of the former three
  • a Dataset

If this doesn’t work with your data, you have to pass a Dataset that can deal with the data.

y : target data, compatible with skorch.dataset.Dataset

The same data types as for X are supported. If your X is a Dataset that contains the target, y may be set to None.

sample_weight : array-like of shape (n_samples,)

Sample weights.

Returns:
loss_value : float32 or np.ndarray

Return type depends on net.criterion_.reduction, and will be a float if reduction is 'sum' or 'mean'. If reduction is 'none' then this function returns a np.ndarray object.
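
A net defined this way can be plugged into sklearn utilities that call the estimator’s score method, for instance cross_val_score. A minimal sketch, assuming X and y are suitable training data and MyModule is your own PyTorch module (both names are illustrative):

>>> from sklearn.model_selection import cross_val_score
>>> net = ScoredNet(MyModule, max_epochs=10, lr=0.1)
>>> scores = cross_val_score(net, X, y)

Keep in mind that the returned value is a loss, so lower values are better, whereas many sklearn utilities assume that higher scores are better.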

skorch.toy

Contains toy functions and classes for quick prototyping and testing.

class skorch.toy.MLPModule(input_units=20, output_units=2, hidden_units=10, num_hidden=1, nonlin=ReLU(), output_nonlin=None, dropout=0, squeeze_output=False)[source]

A simple multi-layer perceptron module.

This can be adapted for usage in different contexts, e.g. binary and multi-class classification, regression, etc.

Parameters:
input_units : int (default=20)

Number of input units.

output_units : int (default=2)

Number of output units.

hidden_units : int (default=10)

Number of units in hidden layers.

num_hidden : int (default=1)

Number of hidden layers.

nonlin : torch.nn.Module instance (default=torch.nn.ReLU())

Non-linearity to apply after hidden layers.

output_nonlin : torch.nn.Module instance or None (default=None)

Non-linearity to apply after last layer, if any.

dropout : float (default=0)

Dropout rate. Dropout is applied between layers.

squeeze_output : bool (default=False)

Whether to squeeze output. Squeezing can be helpful if you wish your output to be 1-dimensional (e.g. for NeuralNetBinaryClassifier).
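
As a quick illustration (the layer sizes below are arbitrary), the module can be instantiated and applied to a batch of data directly; the resulting instance, or the class itself, can then be passed as the module of a NeuralNet:

>>> import torch
>>> from skorch.toy import MLPModule
>>> module = MLPModule(input_units=30, hidden_units=50, num_hidden=2, dropout=0.5)
>>> X = torch.randn(4, 30)
>>> y_pred = module(X)  # shape (4, 2), matching the default output_units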

Methods

add_module(name, module) Adds a child module to the current module.
apply(fn) Applies fn recursively to every submodule (as returned by .children()) as well as self.
bfloat16() Casts all floating point parameters and buffers to bfloat16 datatype.
buffers(recurse) Returns an iterator over module buffers.
children() Returns an iterator over immediate children modules.
cpu() Moves all model parameters and buffers to the CPU.
cuda([device]) Moves all model parameters and buffers to the GPU.
double() Casts all floating point parameters and buffers to double datatype.
eval() Sets the module in evaluation mode.
extra_repr() Set the extra representation of the module
float() Casts all floating point parameters and buffers to float datatype.
forward(X) Defines the computation performed at every call.
half() Casts all floating point parameters and buffers to half datatype.
load_state_dict(state_dict[, strict]) Copies parameters and buffers from state_dict into this module and its descendants.
modules() Returns an iterator over all modules in the network.
named_buffers(prefix, recurse) Returns an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.
named_children() Returns an iterator over immediate children modules, yielding both the name of the module as well as the module itself.
named_modules(memo, prefix) Returns an iterator over all modules in the network, yielding both the name of the module as well as the module itself.
named_parameters(prefix, recurse) Returns an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.
parameters(recurse) Returns an iterator over module parameters.
register_backward_hook(hook) Registers a backward hook on the module.
register_buffer(name, tensor, persistent) Adds a buffer to the module.
register_forward_hook(hook) Registers a forward hook on the module.
register_forward_pre_hook(hook) Registers a forward pre-hook on the module.
register_parameter(name, param) Adds a parameter to the module.
requires_grad_(requires_grad) Change if autograd should record operations on parameters in this module.
reset_params() (Re)set all parameters.
state_dict([destination, prefix, keep_vars]) Returns a dictionary containing a whole state of the module.
to(*args, **kwargs) Moves and/or casts the parameters and buffers.
train(mode) Sets the module in training mode.
type(dst_type) Casts all parameters and buffers to dst_type.
zero_grad(set_to_none) Sets gradients of all model parameters to zero.
__call__  
share_memory  
forward(X)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.

reset_params()[source]

(Re)set all parameters.

skorch.toy.make_binary_classifier(squeeze_output=True, **kwargs)[source]

Return a multi-layer perceptron to be used with NeuralNetBinaryClassifier.

Parameters:
input_units : int (default=20)

Number of input units.

output_units : int (default=2)

Number of output units.

hidden_units : int (default=10)

Number of units in hidden layers.

num_hidden : int (default=1)

Number of hidden layers.

nonlin : torch.nn.Module instance (default=torch.nn.ReLU())

Non-linearity to apply after hidden layers.

dropout : float (default=0)

Dropout rate. Dropout is applied between layers.

skorch.toy.make_classifier(output_nonlin=Softmax(dim=-1), **kwargs)[source]

Return a multi-layer perceptron to be used with NeuralNetClassifier.

Parameters:
input_units : int (default=20)

Number of input units.

output_units : int (default=2)

Number of output units.

hidden_units : int (default=10)

Number of units in hidden layers.

num_hidden : int (default=1)

Number of hidden layers.

nonlin : torch.nn.Module instance (default=torch.nn.ReLU())

Non-linearity to apply after hidden layers.

dropout : float (default=0)

Dropout rate. Dropout is applied between layers.

skorch.toy.make_regressor(output_units=1, **kwargs)[source]

Return a multi-layer perceptron to be used with NeuralNetRegressor.

Parameters:
input_units : int (default=20)

Number of input units.

output_units : int (default=1)

Number of output units.

hidden_units : int (default=10)

Number of units in hidden layers.

num_hidden : int (default=1)

Number of hidden layers.

nonlin : torch.nn.Module instance (default=torch.nn.ReLU())

Non-linearity to apply after hidden layers.

dropout : float (default=0)

Dropout rate. Dropout is applied between layers.
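
All three factories return a pre-configured multi-layer perceptron that can be passed as the module of the corresponding net class. A minimal sketch for the classifier case (the hyperparameters are illustrative; make_binary_classifier and make_regressor are used analogously with NeuralNetBinaryClassifier and NeuralNetRegressor):

>>> from skorch import NeuralNetClassifier
>>> from skorch.toy import make_classifier
>>> module = make_classifier(input_units=100, hidden_units=50, num_hidden=2)
>>> net = NeuralNetClassifier(module, max_epochs=10, lr=0.1)

Calling net.fit(X, y) then expects X to have shape (n_samples, 100) and dtype float32.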

skorch.utils

skorch utilities.

Should not have any dependency on other skorch packages.

class skorch.utils.Ansi[source]

An enumeration.

class skorch.utils.FirstStepAccumulator[source]

Store and retrieve the train step data.

This class simply stores the first step value and returns it.

For most uses, skorch.utils.FirstStepAccumulator is what you want, since the optimizer calls the train step exactly once. However, some optimizers, such as LBFGS, make more than one call. If, in that case, you don’t want the first value to be returned (but instead, say, the last value), implement your own accumulator and make sure it is returned by the NeuralNet.get_train_step_accumulator method (see the sketch below).
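
A minimal sketch of such a custom accumulator that keeps the last step instead of the first (LastStepAccumulator and MyNet are hypothetical names):

from skorch import NeuralNet

class LastStepAccumulator:
    """Accumulator that keeps the most recent train step."""
    def __init__(self):
        self.step = None

    def store_step(self, step):
        # Overwrite on every call so that the last step wins.
        self.step = step

    def get_step(self):
        return self.step

class MyNet(NeuralNet):
    def get_train_step_accumulator(self):
        return LastStepAccumulator()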

Methods

get_step() Return the stored step.
store_step(step) Store the first step.
get_step()[source]

Return the stored step.

store_step(step)[source]

Store the first step.

class skorch.utils.TeeGenerator(gen)[source]

Stores a generator and calls tee on it each time the TeeGenerator is iterated over, creating a fresh generator; this lets you iterate over the wrapped generator more than once.
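
For illustration, iterating over the same TeeGenerator twice yields the items twice, which a plain generator would not:

>>> from skorch.utils import TeeGenerator
>>> gen = TeeGenerator(i ** 2 for i in range(3))
>>> list(gen)
[0, 1, 4]
>>> list(gen)
[0, 1, 4]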

skorch.utils.check_indexing(data)[source]

Check how incoming data should be indexed and return an appropriate indexing function with signature f(data, index).

This is useful for determining upfront how data should be indexed instead of doing it repeatedly for each batch, thus saving some time.
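
A minimal sketch of the intended usage, assuming plain numpy indexing semantics for the returned function:

>>> import numpy as np
>>> from skorch.utils import check_indexing
>>> X = np.array([[1, 2], [3, 4], [5, 6]])
>>> indexing = check_indexing(X)
>>> indexing(X, 1)
array([3, 4])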

skorch.utils.check_is_fitted(estimator, attributes, msg=None, all_or_any=<built-in function all>)[source]

Checks whether the net is initialized.

Note: This calls sklearn.utils.validation.check_is_fitted under the hood, using exactly the same arguments and logic. The only difference is that this function has an adapted error message and raises a skorch.exception.NotInitializedError instead of an sklearn.exceptions.NotFittedError.

skorch.utils.data_from_dataset(dataset, X_indexing=None, y_indexing=None)[source]

Try to access X and y attribute from dataset.

Also works when dataset is a subset.

Parameters:
dataset : skorch.dataset.Dataset or torch.utils.data.Subset

The incoming dataset should be a skorch.dataset.Dataset or a torch.utils.data.Subset of a skorch.dataset.Dataset.

X_indexing : function/callable or None (default=None)

If not None, use this function for indexing into the X data. If None, try to automatically determine how to index data.

y_indexing : function/callable or None (default=None)

If not None, use this function for indexing into the y data. If None, try to automatically determine how to index data.
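
A minimal sketch of the typical round trip (the shapes are arbitrary):

>>> import numpy as np
>>> from skorch.dataset import Dataset
>>> from skorch.utils import data_from_dataset
>>> X, y = np.random.rand(5, 3), np.zeros(5)
>>> ds = Dataset(X, y)
>>> X2, y2 = data_from_dataset(ds)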

skorch.utils.duplicate_items(*collections)[source]

Search for duplicate items in all collections.

Examples

>>> duplicate_items([1, 2], [3])
set()
>>> duplicate_items({1: 'a', 2: 'a'})
set()
>>> duplicate_items(['a', 'b', 'a'])
{'a'}
>>> duplicate_items([1, 2], {3: 'hi', 4: 'ha'}, (2, 3))
{2, 3}
skorch.utils.freeze_parameter(param)[source]

Convenience function to freeze a passed torch parameter. Used by skorch.callbacks.Freezer.

skorch.utils.get_dim(y)[source]

Return the number of dimensions of a torch tensor or numpy array-like object.

skorch.utils.get_map_location(target_device, fallback_device='cpu')[source]

Determine the location to map loaded data (e.g., weights) for a given target device (e.g. ‘cuda’).

skorch.utils.is_skorch_dataset(ds)[source]

Checks if the supplied dataset is an instance of skorch.dataset.Dataset even when it is nested inside torch.utils.data.Subset.

skorch.utils.multi_indexing(data, i, indexing=None)[source]

Perform indexing on multiple data structures.

Currently supported data types:

  • numpy arrays
  • torch tensors
  • pandas NDFrame
  • a dictionary of the former three
  • a list/tuple of the former three

i can be an integer or a slice.

Parameters:
data

Data of a type mentioned above.

i : int or slice

Slicing index.

indexing : function/callable or None (default=None)

If not None, use this function for indexing into the data. If None, try to automatically determine how to index data.

Examples

>>> multi_indexing(np.asarray([1, 2, 3]), 0)
1
>>> multi_indexing(np.asarray([1, 2, 3]), np.s_[:2])
array([1, 2])
>>> multi_indexing(torch.arange(0, 4), np.s_[1:3])
tensor([ 1.,  2.])
>>> multi_indexing([[1, 2, 3], [4, 5, 6]], np.s_[:2])
[[1, 2], [4, 5]]
>>> multi_indexing({'a': [1, 2, 3], 'b': [4, 5, 6]}, np.s_[-2:])
{'a': [2, 3], 'b': [5, 6]}
>>> multi_indexing(pd.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6]}), [1, 2])
   a  b
1  2  5
2  3  6
skorch.utils.noop(*args, **kwargs)[source]

No-op function that does nothing and returns None.

This is useful for defining scoring callbacks that do not need a target extractor.

skorch.utils.open_file_like(f, mode)[source]

Wrapper for opening a file.

skorch.utils.params_for(prefix, kwargs)[source]

Extract parameters that belong to a given sklearn module prefix from kwargs. This is useful to obtain parameters that belong to a submodule.

Examples

>>> kwargs = {'encoder__a': 3, 'encoder__b': 4, 'decoder__a': 5}
>>> params_for('encoder', kwargs)
{'a': 3, 'b': 4}
skorch.utils.to_device(X, device)[source]

Generic function to modify the device type of the tensor(s) or module.

Parameters:
X : input data

Deals with X being a:

  • torch tensor
  • tuple of torch tensors
  • dict of torch tensors
  • PackedSequence instance
  • torch.nn.Module

device : str, torch.device

The compute device to be used. If device=None, return the input unmodified.
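
A minimal sketch with a dict of tensors (moving to 'cuda' works the same way, provided a GPU is available):

>>> import torch
>>> from skorch.utils import to_device
>>> batch = {'x': torch.zeros(2, 3), 'mask': torch.ones(2, 3)}
>>> batch = to_device(batch, 'cpu')  # returns a dict with tensors on the CPU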

skorch.utils.to_numpy(X)[source]

Generic function to convert a pytorch tensor to numpy.

This function tries to unpack the tensor(s) from supported data structures (e.g., dicts, lists, etc.) but doesn’t go beyond.

Returns X when it already is a numpy array.

skorch.utils.to_tensor(X, device, accept_sparse=False)[source]

Turn input data to torch tensor.

Parameters:
X : input data

Handles the cases:

  • PackedSequence
  • numpy array
  • torch Tensor
  • scipy sparse CSR matrix
  • list or tuple of one of the former
  • dict with values of one of the former

device : str, torch.device

The compute device to be used. If set to ‘cuda’, data in torch tensors will be pushed to cuda tensors before being sent to the module.

accept_sparse : bool (default=False)

Whether to accept scipy sparse matrices as input. If False, passing a sparse matrix raises an error. If True, it is converted to a torch COO tensor.

Returns:
output : torch Tensor
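
A minimal sketch of converting a numpy array to a tensor and back with to_numpy:

>>> import numpy as np
>>> from skorch.utils import to_tensor, to_numpy
>>> X = np.array([[1.0, 2.0], [3.0, 4.0]], dtype='float32')
>>> Xt = to_tensor(X, device='cpu')  # torch Tensor on the CPU
>>> X2 = to_numpy(Xt)                # back to a numpy array
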
skorch.utils.unfreeze_parameter(param)[source]

Convenience function to unfreeze a passed torch parameter. Used by skorch.callbacks.Unfreezer.

Indices and tables