skorch documentation

A scikit-learn compatible neural network library that wraps PyTorch.

Introduction

The goal of skorch is to make it possible to use PyTorch with sklearn. This is achieved by providing a wrapper around PyTorch that has an sklearn interface. In that sense, skorch is the spiritual successor to nolearn, but instead of using Lasagne and Theano, it uses PyTorch.

skorch does not re-invent the wheel, instead getting as much out of your way as possible. If you are familiar with sklearn and PyTorch, you don’t have to learn any new concepts, and the syntax should be well known. (If you’re not familiar with those libraries, it is worth getting familiarized.)

Additionally, skorch abstracts away the training loop, making a lot of boilerplate code obsolete. A simple net.fit(X, y) is enough. Out of the box, skorch works with many types of data, be it PyTorch Tensors, NumPy arrays, Python dicts, and so on. However, if you have other data, it is easy to extend skorch to accommodate it.

Overall, skorch aims at being as flexible as PyTorch while having as clean an interface as sklearn.

User’s Guide

Installation

pip installation

To install with pip, run:

pip install -U skorch

We recommend using a virtual environment for this.

From source

If you would like to use the most recent additions to skorch or help with development, you should install skorch from source.

Using conda

You need a working conda installation. Get the correct miniconda for your system from here.

If you just want to use skorch, use:

git clone https://github.com/dnouri/skorch.git
cd skorch
conda env create
source activate skorch
# install pytorch version for your system (see below)
python setup.py install

If you want to help developing, run:

git clone https://github.com/dnouri/skorch.git
cd skorch
conda env create
source activate skorch
# install pytorch version for your system (see below)
conda install --file requirements-dev.txt
python setup.py develop

py.test  # unit tests
pylint skorch  # static code checks

Using pip

If you just want to use skorch, use:

git clone https://github.com/dnouri/skorch.git
cd skorch
# create and activate a virtual environment
pip install -r requirements.txt
# install pytorch version for your system (see below)
python setup.py install

If you want to help developing, run:

git clone https://github.com/dnouri/skorch.git
cd skorch
# create and activate a virtual environment
pip install -r requirements.txt
# install pytorch version for your system (see below)
pip install -r requirements-dev.txt
python setup.py develop

py.test  # unit tests
pylint skorch  # static code checks

pytorch

PyTorch is not covered by the dependencies, since the PyTorch version you need is dependent on your system. For installation instructions for PyTorch, visit the pytorch website.

In general, this should work:

# using conda:
conda install pytorch cuda80 -c soumith
# using pip
pip install http://download.pytorch.org/whl/cu80/torch-0.2.0.post3-cp36-cp36m-manylinux1_x86_64.whl

Quickstart

Training a model

Below, we define our own PyTorch Module and train it on a toy classification dataset using skorch NeuralNetClassifier:

import numpy as np
from sklearn.datasets import make_classification
from torch import nn
import torch.nn.functional as F

from skorch import NeuralNetClassifier


X, y = make_classification(1000, 20, n_informative=10, random_state=0)
X = X.astype(np.float32)
y = y.astype(np.int64)

class MyModule(nn.Module):
    def __init__(self, num_units=10, nonlin=F.relu):
        super(MyModule, self).__init__()

        self.dense0 = nn.Linear(20, num_units)
        self.nonlin = nonlin
        self.dropout = nn.Dropout(0.5)
        self.dense1 = nn.Linear(num_units, 10)
        self.output = nn.Linear(10, 2)

    def forward(self, X, **kwargs):
        X = self.nonlin(self.dense0(X))
        X = self.dropout(X)
        X = F.relu(self.dense1(X))
        X = F.softmax(self.output(X), dim=-1)
        return X


net = NeuralNetClassifier(
    MyModule,
    max_epochs=10,
    lr=0.1,
)

net.fit(X, y)
y_proba = net.predict_proba(X)

In an sklearn Pipeline

Since NeuralNetClassifier provides an sklearn-compatible interface, it is possible to put it into an sklearn Pipeline:

from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler


pipe = Pipeline([
    ('scale', StandardScaler()),
    ('net', net),
])

pipe.fit(X, y)
y_proba = pipe.predict_proba(X)

What's next?

Please visit the Tutorials page to explore additional examples on using skorch!

Tutorials

The following are examples and notebooks on how to use skorch.

NeuralNet

Using NeuralNet

NeuralNet and the derived classes are the main touch point for the user. They wrap the PyTorch Module while providing an interface that should be familiar for sklearn users.

Define your Module the same way as you always do. Then pass it to NeuralNet, in conjunction with a PyTorch criterion. Finally, you can call fit() and predict(), as with an sklearn estimator. The finished code could look something like this:

class MyModule(torch.nn.Module):
    ...

net = NeuralNet(
    module=MyModule,
    criterion=torch.nn.NLLLoss,
)
net.fit(X, y)
y_pred = net.predict(X_valid)

Let’s see what skorch did for us here:

  • wraps the PyTorch Module in an sklearn interface
  • converts numpy.ndarrays to PyTorch Tensors
  • abstracts away the fit loop
  • takes care of batching the data

You therefore have a lot less boilerplate code, letting you focus on what matters. At the same time, skorch is very flexible and can be extended with ease, getting out of your way as much as possible.

Initialization

In general, when you instantiate the NeuralNet instance, only the given arguments are stored. They are stored exactly as you pass them to NeuralNet. For instance, the module will remain uninstantiated. This is to make sure that the arguments you pass are not touched afterwards, which, among other things, makes it possible to clone the NeuralNet instance.

Only when the fit() or initialize() method is called are the different attributes of the net, such as the module, initialized. An initialized attribute's name always ends with an underscore; e.g., the initialized module is called module_. (This is the same nomenclature as sklearn uses.) Therefore, you always know which attributes you set and which ones were created by NeuralNet.

The only exception is the history attribute, which is not set by the user.
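
A minimal sketch of this behavior, reusing MyModule from the examples above:

net = NeuralNet(
    module=MyModule,
    criterion=torch.nn.NLLLoss,
)
net.module    # still the MyModule class, exactly as passed
net.initialize()
net.module_   # now an instantiated MyModule with weights and biases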

Most important arguments and methods

A complete explanation of all arguments and methods of NeuralNet can be found in the skorch API documentation. Here we focus on the main ones.

module

This is where you pass your PyTorch Module. Ideally, it should not be instantiated. Instead, the init arguments for your module should be passed to NeuralNet with the module__ prefix. E.g., if your module takes the arguments num_units and dropout, the code would look like this:

class MyModule(torch.nn.Module):
    def __init__(self, num_units, dropout):
        ...

net = NeuralNet(
    module=MyModule,
    module__num_units=100,
    module__dropout=0.5,
    criterion=torch.nn.NLLLoss,
)

It is, however, also possible to pass an instantiated module, e.g. a PyTorch Sequential instance.

Note that skorch does not automatically apply any nonlinearities to the outputs (except internally when determining the PyTorch NLLLoss, see below). That means that if you have a classification task, you should make sure that the final output nonlinearity is a softmax. Otherwise, when you call predict_proba(), you won’t get actual probabilities.

criterion

This should be a PyTorch (-compatible) criterion.

When you use the NeuralNetClassifier, the criterion is set to PyTorch NLLLoss by default. Furthermore, if you don't change it to another criterion, NeuralNetClassifier assumes that the module returns probabilities and will automatically apply a logarithm to them (which is what NLLLoss expects).

For NeuralNetRegressor, the default criterion is PyTorch MSELoss.

After initializing the NeuralNet, the initialized criterion will be stored in the criterion_ attribute.

optimizer

This should be a PyTorch optimizer, e.g. SGD. After initializing the NeuralNet, the initialized optimizer will be stored in the optimizer_ attribute. During initialization you can define param groups, for example to set different learning rates for certain parameters. The parameters are selected by name with support for wildcards (globbing):

optimizer__param_groups=[
    ('embedding.*', {'lr': 0.0}),
    ('linear0.bias', {'lr': 1}),
]

lr

The learning rate. This argument exists for convenience, since it could also be set by optimizer__lr instead. However, it is used so often that we provide this shortcut. If you set both lr and optimizer__lr, the latter has precedence.
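
For instance, in the sketch below, the effective learning rate would be 0.05, since optimizer__lr takes precedence:

net = NeuralNet(
    module=MyModule,
    lr=0.1,
    optimizer__lr=0.05,  # takes precedence over lr
)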

max_epochs

The maximum number of epochs to train with each fit() call. When you call fit(), the net will train for this many epochs, except if you interrupt training before the end (e.g. by using an early stopping callback or interrupt manually with ctrl+c).

If you want to change the number of epochs to train, you can either set a different value for max_epochs, or you call fit_loop() instead of fit() and pass the desired number of epochs explicitly:

net.fit_loop(X, y, epochs=20)

batch_size

This argument controls the batch size for iterator_train and iterator_valid at the same time. batch_size=128 is thus a convenient shortcut for explicitly typing iterator_train__batch_size=128 and iterator_valid__batch_size=128. If you set all three arguments, the latter two will have precedence.

train_split

This determines the NeuralNet’s internal train/validation split. By default, 20% of the incoming data is reserved for validation. If you set this value to None, all the data is used for training.

For more details, please look at dataset.

callbacks

For more details on the callback classes, please look at callbacks.

By default, NeuralNet and its subclasses start with a couple of useful callbacks. Those are defined in the get_default_callbacks() method and include, for instance, callbacks for measuring and printing model performance.

In addition to the default callbacks, you may provide your own callbacks. There are a couple of ways to pass callbacks to the NeuralNet instance. The easiest way is to pass a list of all your callbacks to the callbacks argument:

net = NeuralNet(
    module=MyModule,
    callbacks=[
        MyCallback1(...),
        MyCallback2(...),
    ],
)

Inside the NeuralNet instance, each callback will receive a separate name. Since we provide no name in the example above, the class name will be taken, which will lead to a name collision in case of two or more callbacks of the same class. This is why it is better to initialize the callbacks with a list of tuples of name and callback instance, like this:

net = NeuralNet(
    module=MyModule,
    callbacks=[
        ('cb1', MyCallback1(...)),
        ('cb2', MyCallback2(...)),
    ],
)

This approach of passing a list of (name, instance) tuples should be familiar to users of sklearn Pipelines and FeatureUnions.

An additional advantage of this way of passing callbacks is that it allows you to pass arguments to the callbacks by name (using the double-underscore notation):

net.set_params(callbacks__cb1__foo=123, callbacks__cb2__bar=456)

Use this, for instance, when trying out different callback parameters in a grid search.

Note: The user-defined callbacks are always called in the same order as they appeared in the list. If there are dependencies between the callbacks, the user has to make sure that the order respects them. Also note that the user-defined callbacks will be called after the default callbacks so that they can make use of the things provided by the default callbacks. The only exception is the default callback PrintLog, which is always called last.

warm_start

This argument determines whether each fit() call leads to a re-initialization of the NeuralNet or not. By default, when calling fit(), the parameters of the net are initialized, so your previous training progress is lost (consistent with the sklearn fit() calls). In contrast, with warm_start=True, each fit() call will continue from the most recent state.
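
A short sketch of the difference, reusing MyModule from above:

net = NeuralNet(module=MyModule, warm_start=True, max_epochs=10)
net.fit(X, y)  # trains for 10 epochs
net.fit(X, y)  # trains for another 10 epochs, continuing from before

# with the default warm_start=False, the second fit() call would
# re-initialize the net and start training from scratch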

device

As the name suggests, this determines which computation device should be used. If set to cuda, the incoming data will be transferred to CUDA before being passed to the PyTorch Module. The device parameter adheres to the general syntax of the PyTorch device parameter.

initialize()

As mentioned earlier, upon instantiating the NeuralNet instance, the net’s components are not yet initialized. That means, e.g., that the weights and biases of the layers are not yet set. This only happens after the initialize() call. However, when you call fit() and the net is not yet initialized, initialize() is called automatically. You thus rarely need to call it manually.

The initialize() method itself calls a couple of other initialization methods that are specific to each component. E.g., initialize_module() is responsible for initializing the PyTorch module. Therefore, if you have special needs for initializing the module, it is enough to override initialize_module(), you don’t need to override the whole initialize() method.

fit(X, y)

This is one of the main methods you will use. It contains everything required to train the model, be it batching of the data, triggering the callbacks, or handling the internal validation set.

In general, we assume there to be an X and a y. If you have more input data than just one array, it is possible for X to be a list or dictionary of data (see dataset). And if your task does not have an actual y, you may pass y=None.

If you fit with a PyTorch Dataset and don’t explicitly pass y, several components down the line might not work anymore, since sklearn sometimes requires an explicit y (e.g. for scoring). In general, PyTorch Datasets should work, though.

In addition to fit(), there is also the partial_fit() method, known from some sklearn estimators. partial_fit() allows you to continue training from your current status, even if you set warm_start=False. A further use case for partial_fit() is when your data does not fit into memory and you thus need to have several training steps.
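
A short sketch of the distinction:

net.fit(X, y)          # re-initializes the net and trains from scratch
net.partial_fit(X, y)  # continues training from the current state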

Tip : skorch gracefully catches the KeyboardInterrupt exception. Therefore, during a training run, you can send a KeyboardInterrupt signal without the Python process exiting (typically, KeyboardInterrupt can be triggered by ctrl+c or, in a Jupyter notebook, by clicking Kernel -> Interrupt). This way, when your model has reached a good score before max_epochs have been reached, you can dynamically stop training.

predict(X) and predict_proba(X)

These methods perform an inference step on the input data and return numpy.ndarrays. By default, predict_proba() will return whatever it is that the module’s forward() method returns, cast to a numpy.ndarray. If forward() returns multiple outputs as a tuple, only the first output is used, the rest is discarded.

If the forward() output cannot be cast to a numpy.ndarray, or if you need access to all outputs in the multiple-outputs case, consider using the forward() or forward_iter() methods to generate outputs from the module. Alternatively, you may directly call net.module_(X).

In case of NeuralNetClassifier, the predict() method tries to return the class labels by applying the argmax over the last axis of the result of predict_proba(). Obviously, this only makes sense if predict_proba() returns class probabilities. If this is not true, you should just use predict_proba().

saving and loading

skorch provides two ways to persist your model. First it is possible to store the model using Python’s pickle function. This saves the whole model, including hyperparameters. This is useful when you don’t want to initialize your model before loading its parameters, or when your NeuralNet is part of an sklearn Pipeline:

net = NeuralNet(
    module=MyModule,
    criterion=torch.nn.NLLLoss,
)

model = Pipeline([
    ('my-features', get_features()),
    ('net', net),
])
model.fit(X, y)

# saving
with open('some-file.pkl', 'wb') as f:
    pickle.dump(model, f)

# loading
with open('some-file.pkl', 'rb') as f:
    model = pickle.load(f)

The disadvantage of pickling is that if your underlying code changes, unpickling might raise errors. Also, some Python code (e.g. lambda functions) cannot be pickled.

For this reason, we provide a second method for persisting your model. To use it, call the save_params() and load_params() method on NeuralNet. Under the hood, this saves the module’s state_dict, i.e. only the weights and biases of the module. This is more robust to changes in the code but requires you to initialize a NeuralNet to load the parameters again:

net = NeuralNet(
    module=MyModule,
    criterion=torch.nn.NLLLoss,
)

model = Pipeline([
    ('my-features', get_features()),
    ('net', net),
])
model.fit(X, y)

net.save_params(f_params='some-file.pkl')

new_net = NeuralNet(
    module=MyModule,
    criterion=torch.nn.NLLLoss,
)
new_net.initialize()  # This is important!
new_net.load_params(f_params='some-file.pkl')

In addition to saving the model parameters, the history and optimizer state can be saved by including the f_history and f_optimizer keywords to save_params() and load_params() on NeuralNet. This feature can be used to continue training:

net = NeuralNet(
    module=MyModule,
    criterion=torch.nn.NLLLoss,
)

net.fit(X, y, epochs=2) # Train for 2 epochs

net.save_params(
    f_params='model.pkl', f_optimizer='opt.pkl', f_history='history.json')

new_net = NeuralNet(
    module=MyModule,
    criterion=torch.nn.NLLLoss,
)
new_net.initialize() # This is important!
new_net.load_params(
    f_params='model.pkl', f_optimizer='opt.pkl', f_history='history.json')

new_net.fit(X, y, epochs=2) # Train for another 2 epochs

Note

In order to use this feature, the history must only contain JSON encodable Python data structures. Numpy and PyTorch types should not be in the history.

Special arguments

In addition to the arguments explicitly listed for NeuralNet, there are some arguments with special prefixes, as shown below:

class MyModule(torch.nn.Module):
    def __init__(self, num_units, dropout):
        ...

net = NeuralNet(
    module=MyModule,
    module__num_units=100,
    module__dropout=0.5,
    criterion=torch.nn.NLLLoss,
    criterion__weight=weight,
    optimizer=torch.optim.SGD,
    optimizer__momentum=0.9,
)

Those arguments are used to initialize your module, criterion, etc. They are not fixed because we cannot know them in advance; in fact, you can define any parameter for your module or other components.

All special prefixes are stored in the prefixes_ class attribute of NeuralNet. Currently, they are:

  • module
  • iterator_train
  • iterator_valid
  • optimizer
  • criterion
  • callbacks
  • dataset

Subclassing NeuralNet

Apart from the NeuralNet base class, we provide NeuralNetClassifier and NeuralNetRegressor for typical classification and regression tasks. They should work as drop-in replacements for sklearn classifiers and regressors.

The NeuralNet class is a little less opinionated about the incoming data, e.g. it does not determine a loss function by default. Therefore, if you want to write your own subclass for a special use case, you would typically subclass from NeuralNet.

skorch aims at making subclassing as easy as possible, so that it doesn’t stand in your way. For instance, all components (module, optimizer, etc.) have their own initialization method (initialize_module, initialize_optimizer, etc.). That way, if you want to modify the initialization of a component, you can easily do so.

Additionally, NeuralNet has a couple of get_* methods for when a component is retrieved repeatedly. E.g., get_loss() is called when the loss is determined. Below we show an example of overriding get_loss() to add L1 regularization to our total loss:

class RegularizedNet(NeuralNet):
    def __init__(self, *args, lambda1=0.01, **kwargs):
        super().__init__(*args, **kwargs)
        self.lambda1 = lambda1

    def get_loss(self, y_pred, y_true, X=None, training=False):
        loss = super().get_loss(y_pred, y_true, X=X, training=training)
        loss += self.lambda1 * sum([w.abs().sum() for w in self.module_.parameters()])
        return loss

Note

This example also regularizes the biases, which you typically don’t need to do.

Callbacks

Callbacks provide a flexible way to customize the behavior of your NeuralNet training without the need to write subclasses.

You will often find callbacks writing to or reading from the history attribute. Therefore, if you would like to log the net’s behavior or do something based on the past behavior, consider using net.history.

This page will not explain all existing callbacks. For that, please look at skorch.callbacks.

Callback base class

The base class for each callback is Callback. If you would like to write your own callbacks, you should inherit from this class. A guide and practical example on how to write your own callbacks is shown in this notebook. In general, remember this:

  • They should inherit from the base class.
  • They should implement at least one of the on_ methods provided by the parent class (see below).
  • As arguments, the methods first receive the NeuralNet instance and, where appropriate, the local data (e.g. the data from the current batch). The methods should also have **kwargs in their signature for potentially unused arguments (a minimal sketch follows this list).
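
For illustration, here is a minimal sketch of a custom callback that prints the latest train loss after every epoch (the class name EpochLogger is made up for this example):

from skorch.callbacks import Callback

class EpochLogger(Callback):
    def on_epoch_end(self, net, dataset_train=None, dataset_valid=None, **kwargs):
        # read the most recent train loss from the history
        print('train loss:', net.history[-1, 'train_loss'])

net = NeuralNet(
    module=MyModule,
    callbacks=[('epoch_logger', EpochLogger())],
)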

Callback methods to override

The following methods could potentially be overridden when implementing your own callbacks.

initialize()

If you have attributes that should be reset when the model is re-initialized, those attributes should be set in this method.

on_train_begin(net, X, y)

Called once at the start of the training process (e.g. when calling fit).

on_train_end(net, X, y)

Called once at the end of the training process.

on_epoch_begin(net, dataset_train, dataset_valid)

Called once at the start of the epoch, i.e. possibly several times per fit call. Gets training and validation data as additional input.

on_epoch_end(net, dataset_train, dataset_valid)

Called once at the end of the epoch, i.e. possibly several times per fit call. Gets training and validation data as additional input.

on_batch_begin(net, Xi, yi, training)

Called once before each batch of data is processed, i.e. possibly several times per epoch. Gets batch data as additional input. Also includes a bool indicating if this is a training batch or not.

on_batch_end(net, Xi, yi, training, loss, y_pred)

Called once after each batch of data is processed, i.e. possibly several times per epoch. Gets batch data as additional input.

on_grad_computed(net, named_parameters, Xi, yi)

Called once per batch after gradients have been computed but before an update step was performed. Gets the module parameters as additional input as well as the batch data. Useful if you want to tinker with gradients.
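
As an illustration, here is a sketch of a callback that clips the gradient norm from within on_grad_computed (skorch ships its own gradient clipping callback, so this is purely illustrative; the name ClipGradients and the default max_norm are made up):

import torch
from skorch.callbacks import Callback

class ClipGradients(Callback):
    def __init__(self, max_norm=1.0):
        self.max_norm = max_norm

    def on_grad_computed(self, net, named_parameters, **kwargs):
        # named_parameters yields (name, parameter) pairs of the module
        params = (p for _, p in named_parameters)
        torch.nn.utils.clip_grad_norm_(params, self.max_norm)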

Setting callback parameters

You can set specific callback parameters using the usual set_params interface on the network by using the callbacks__ prefix and the callback's name. For example, to change the scoring order of the train loss, you can write this:

net = NeuralNet()
net.set_params(callbacks__train_loss__lower_is_better=False)

Changes will be applied on initialization and callbacks that are changed using set_params will be re-initialized.

The name you use to address the callback can be chosen during initialization of the network and defaults to the class name. If there is a conflict, the conflicting names will be made unique by appending a count suffix starting at 1, e.g. EpochScoring_1, EpochScoring_2, etc.

Deactivating callbacks

If you would like to (temporarily) deactivate a callback, you can do so by setting its parameter to None. E.g., if you have a callback called ‘my_callback’, you can deactivate it like this:

net = NeuralNet(
    module=MyModule,
    callbacks=[('my_callback', MyCallback())],
)
# now deactivate 'my_callback':
net.set_params(callbacks__my_callback=None)

This also works with default callbacks.

Deactivating callbacks can be especially useful when you do a parameter search (say with sklearn GridSearchCV). If, for instance, you use a callback for learning rate scheduling (e.g. via LRScheduler) and want to test its usefulness, you can compare the performance once with and once without the callback.

Scoring

skorch provides two scoring callbacks by default, EpochScoring and BatchScoring. They work basically in the same way, except that EpochScoring calculates scores after each epoch and BatchScoring after each batch. Use the former if averaging of batch-wise scores is imprecise (say for AUC score) and the latter if you are very tight for memory.

In general, the scoring callbacks are useful when the default scores determined by the NeuralNet are not enough. They allow you to easily add new metrics to be logged during training. For an example of how to add a new score to your model, look at this notebook.

The first argument to both callbacks is name and should be a string. This determines the column name of the score shown by the PrintLog after each epoch.

Next comes the scoring parameter. For eager sklearn users, this should be familiar, since it works exactly the same as in sklearn GridSearchCV, RandomizedSearchCV, cross_val_score(), etc. For those who are unfamiliar, here is a short explanation:

  • If you pass a string, sklearn makes a look-up for a score with that name. Examples would be 'f1' and 'roc_auc'.
  • If you pass None, the model’s score method is used. By default, NeuralNet and its subclasses don’t provide a score method, but you can easily implement your own. If you do, it should take X and y (the target) as input and return a scalar as output.
  • Finally, you can pass a function/callable. In that case, this function should have the signature func(net, X, y) and return a scalar.

More on sklearn’s model evaluation can be found in this notebook.

The lower_is_better parameter determines whether lower scores should be considered as better (e.g. log loss) or worse (e.g. accuracy). This information is used to write a <name>_best value to the net's history. E.g., if your score is the f1 score and is called 'f1', you should set lower_is_better=False. The history will then contain an entry for 'f1', which is the score itself, and an entry for 'f1_best', which says whether this is the best f1 score so far.

on_train is used to indicate whether training or validation data should be used to determine the score. By default, the validation data is used.

Finally, you may have to provide your own target_extractor. This should be a function or callable that is applied to the target before it is passed to the scoring function. The main reason why we need this is that sometimes, the target is not of a form expected by sklearn and we need to process it before passing it on.
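
To illustrate, here is a sketch that adds an f1 score on the validation data, once via the sklearn scoring string and once via a custom callable (the callback names and the column name 'f1_custom' are arbitrary):

from sklearn.metrics import f1_score
from skorch.callbacks import EpochScoring

def custom_f1(net, X, y):
    # any callable with the signature func(net, X, y) returning a scalar works
    return f1_score(y, net.predict(X))

net = NeuralNetClassifier(
    MyModule,
    callbacks=[
        ('f1', EpochScoring(scoring='f1', name='f1', lower_is_better=False)),
        ('f1_custom', EpochScoring(scoring=custom_f1, name='f1_custom',
                                   lower_is_better=False)),
    ],
)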

Checkpoint

The Checkpoint callback creates a checkpoint of your model after each epoch that meets certain criteria. By default, the condition is that the validation loss has improved; however, you may change this by specifying the monitor parameter. It can take three types of arguments:

  • None: The model is saved after each epoch;
  • string: The model checks whether the last entry in the model history for that key is truthy. This is useful in conjunction with scores determined by a scoring callback. They write a <score>_best entry to the history, which can be used for checkpointing. By default, the Checkpoint callback looks at 'valid_loss_best';
  • function or callable: In that case, the function should take the NeuralNet instance as sole input and return a bool as output.

To specify where and how your model is saved, change the arguments starting with f_:

  • f_params: to save model parameters
  • f_optimizer: to save optimizer state
  • f_history: to save training history
  • f_pickle: to pickle the entire model object.

Please refer to Saving and Loading for more information about restoring your network from a checkpoint.
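
As a sketch, the following checkpoint saves the model parameters whenever the validation accuracy reaches a new best value (the file name best_params.pt is arbitrary):

from skorch.callbacks import Checkpoint

cp = Checkpoint(monitor='valid_acc_best', f_params='best_params.pt')

net = NeuralNetClassifier(
    MyModule,
    callbacks=[cp],
)
net.fit(X, y)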

Dataset

This module contains classes and functions related to data handling.

CVSplit

This class is responsible for performing the NeuralNet’s internal cross validation. For this, it sticks closely to the sklearn standards. For more information on how sklearn handles cross validation, look here.

The first argument that CVSplit takes is cv. It works analogously to the cv argument from sklearn GridSearchCV, cross_val_score(), etc. For those not familiar, here is a short explanation of what you may pass:

  • None: Use the default 3-fold cross validation.
  • integer: Specifies the number of folds in a (Stratified)KFold,
  • float: Represents the proportion of the dataset to include in the validation split (e.g. 0.2 for 20%).
  • An object to be used as a cross-validation generator.
  • An iterable yielding train, validation splits.

Furthermore, CVSplit takes a stratified argument that determines whether a stratified split should be made (only makes sense for discrete targets), and a random_state argument, which is used in case the cross validation split has a random component.
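
For example, a stratified split derived from a 5-fold strategy with a fixed random state could be requested like this:

from skorch.dataset import CVSplit

net = NeuralNetClassifier(
    MyModule,
    train_split=CVSplit(cv=5, stratified=True, random_state=0),
)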

One difference to sklearn's cross validation is that skorch makes only a single split. In sklearn, you would expect that in a 5-fold cross validation, the model is trained 5 times on the different combinations of folds. This is often not desirable for neural networks, since training takes a lot of time. Therefore, skorch only ever makes one split.

If you would like to have all splits, you can still use skorch in conjunction with the sklearn functions, as you would do with any other sklearn-compatible estimator. Just remember to set train_split=None, so that the whole dataset is used for training. Below is an example of making out-of-fold predictions with skorch and sklearn:

net = NeuralNetClassifier(
    module=MyModule,
    train_split=None,
)

from sklearn.model_selection import cross_val_predict

y_pred = cross_val_predict(net, X, y, cv=5)

Dataset

In PyTorch, we have the concept of a Dataset and a DataLoader. The former is purely the container of the data and only needs to implement __len__() and __getitem__(<int>). The latter does the heavy lifting, such as sampling, shuffling, and distributed processing.

skorch uses the PyTorch DataLoaders by default. However, the Datasets provided by PyTorch are not sufficient for our use case; for instance, they don't work with numpy.ndarrays. That's why we provide our own Dataset class. This container works with:

  • numpy.ndarrays
  • PyTorch Tensors
  • pandas DataFrames and Series

In addition, you can pass dictionaries or lists of one of those data types, e.g. a dictionary of numpy.ndarrays. When you pass dictionaries, the keys of the dictionaries are used as the argument name for the forward() method of the net’s module. Similarly, the column names of pandas DataFrames are used as argument names. The example below should illustrate how to use this feature:

import numpy as np
import torch
import torch.nn.functional as F

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

        self.dense_a = torch.nn.Linear(10, 100)
        self.dense_b = torch.nn.Linear(20, 100)
        self.output = torch.nn.Linear(200, 2)

    def forward(self, key_a, key_b):
        hid_a = F.relu(self.dense_a(key_a))
        hid_b = F.relu(self.dense_b(key_b))
        concat = torch.cat((hid_a, hid_b), dim=1)
        out = F.softmax(self.output(concat), dim=-1)
        return out

net = NeuralNetClassifier(MyModule)

X = {
    'key_a': np.random.random((1000, 10)).astype(np.float32),
    'key_b': np.random.random((1000, 20)).astype(np.float32),
}
y = np.random.randint(0, 2, size=1000)

net.fit(X, y)

Note that the keys in the dictionary X exactly match the argument names in the forward() method. This way, you can easily work with several different types of input features.

The Dataset from skorch makes the assumption that you always have an X and a y, where X represents the input data and y the target. However, you may leave y=None, in which case Dataset returns a dummy variable.

Dataset applies a final transform to the data before passing it on to the PyTorch DataLoader. By default, it replaces y with a dummy variable in case it is None. If you would like to apply your own transformation to the data, you should subclass Dataset and override the transform() method, then pass your custom class to NeuralNet as the dataset argument.
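
A minimal sketch of such a subclass; the float32 cast here is just an arbitrary stand-in for whatever transformation you need:

import numpy as np
from skorch.dataset import Dataset

class MyDataset(Dataset):
    def transform(self, X, y):
        # custom step: make sure the input is float32
        X = np.asarray(X, dtype=np.float32)
        # let the parent class apply its default transform (e.g. the dummy y)
        return super().transform(X, y)

net = NeuralNet(
    module=MyModule,
    dataset=MyDataset,
)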

Saving and Loading

skorch provides the callbacks Checkpoint, TrainEndCheckpoint, and LoadInitState to handle saving and loading models during training. To demonstrate these features, we generate a dataset and create a simple module:

import numpy as np
from sklearn.datasets import make_classification
from torch import nn

from skorch import NeuralNetClassifier

X, y = make_classification(1000, 10, n_informative=5, random_state=0)
X = X.astype(np.float32)
y = y.astype(np.int64)

class MyModule(nn.Sequential):
    def __init__(self, num_units=10):
        super().__init__(
            nn.Linear(10, num_units),
            nn.ReLU(inplace=True),
            nn.Dropout(0.2),
            nn.Linear(num_units, 10),
            nn.Linear(10, 2),
            nn.Softmax(dim=-1)
        )

We create a checkpoint setting dirname to 'exp1'. This will configure the checkpoint to save the model parameters, optimizer, and history into a directory named 'exp1'.

from skorch.callbacks import Checkpoint, TrainEndCheckpoint

cp = Checkpoint(dirname='exp1')
final_cp = TrainEndCheckpoint(dirname='exp1')
net = NeuralNetClassifier(
    MyModule, lr=0.5, callbacks=[cp, final_cp]
)

_ = net.fit(X, y)

# prints
  epoch    train_loss    valid_acc    valid_loss    cp     dur
-------  ------------  -----------  ------------  ----  ------
      1        0.6200       0.8209        0.4765     +  0.0232
      2        0.3644       0.8557        0.3474     +  0.0238
      3        0.2875       0.8806        0.3201     +  0.0214
      4        0.2514       0.8905        0.3080     +  0.0237
      5        0.2333       0.9154        0.2844     +  0.0203
      6        0.2177       0.9403        0.2164     +  0.0215
      7        0.2194       0.9403        0.2159     +  0.0220
      8        0.2027       0.9403        0.2299        0.0202
      9        0.1864       0.9254        0.2313        0.0196
     10        0.2024       0.9353        0.2333        0.0221

By default, the checkpoint observes valid_loss and will save the model when the valid_loss is lowest. This can be seen by the + mark in the cp column of the logs.

On our first run, the validation loss did not improve after the 7th epoch. We can continue training from this checkpoint with a lower learning rate by using LoadInitState:

from skorch.callbacks import LoadInitState

cp = Checkpoint(dirname='exp1')
load_state = LoadInitState(cp)
net = NeuralNetClassifier(
    MyModule, lr=0.1, callbacks=[cp, load_state]
)

_ = net.fit(X, y)

# prints

  epoch    train_loss    valid_acc    valid_loss    cp     dur
-------  ------------  -----------  ------------  ----  ------
      8        0.1939       0.9055        0.2626     +  0.0238
      9        0.2055       0.9353        0.2031     +  0.0239
     10        0.1992       0.9453        0.2101        0.0182
     11        0.2033       0.9453        0.1947     +  0.0211
     12        0.1825       0.9104        0.2515        0.0185
     13        0.2010       0.9453        0.1927     +  0.0187
     14        0.1508       0.9453        0.1952        0.0198
     15        0.1679       0.9502        0.1905     +  0.0181
     16        0.1516       0.9453        0.1864     +  0.0192
     17        0.1576       0.9453        0.1804     +  0.0184

Since we started from the previous checkpoint which ended at epoch 7, the second run starts at epoch 8, continuing from the first checkpoint. With a lower learning rate, the validation loss was able to improve!

Notice in the first run, we included a TrainEndCheckpoint in the callbacks. This checkpoint saves the model at the end of training. This checkpoint can be passed to LoadInitState to continue training:

cp_from_final = Checkpoint(dirname='exp1', fn_prefix='from_final_')
load_state = LoadInitState(final_cp)
net = NeuralNetClassifier(
    MyModule, lr=0.1, callbacks=[cp_from_final, load_state]
)

_ = net.fit(X, y)

In this run, training started at epoch 11, continuing from the end of the first run, which ended at epoch 10. We created a new checkpoint with fn_prefix set to 'from_final_' to prefix the saved filenames, making sure that this checkpoint does not overwrite the validation checkpoint.

Since our MyModule class allows num_units to be adjusted, we can start a new experiment by changing the dirname:

cp = Checkpoint(dirname='exp2')
load_state = LoadInitState(cp)
net = NeuralNetClassifier(
    MyModule, lr=0.5,
    callbacks=[cp, load_state],
    module__num_units=20,
)

_ = net.fit(X, y)

# print

  epoch    train_loss    valid_acc    valid_loss    cp     dur
-------  ------------  -----------  ------------  ----  ------
      1        0.5256       0.8856        0.3624     +  0.0181
      2        0.2956       0.8756        0.3416     +  0.0222
      3        0.2280       0.9453        0.2299     +  0.0211
      4        0.1948       0.9303        0.2136     +  0.0232
      5        0.1800       0.9055        0.2696        0.0223
      6        0.1605       0.9403        0.1906     +  0.0190
      7        0.1594       0.9403        0.2027        0.0184
      8        0.1319       0.9303        0.1910        0.0220
      9        0.1558       0.9254        0.1923        0.0189
     10        0.1432       0.9303        0.2219        0.0192

This stores the model in the 'exp2' directory. Since this is the first run, the LoadInitState callback does not do anything. If we were to run the above script again, the LoadInitState callback would load the model from the checkpoint.

In the run above, the last checkpoint was created at epoch 6; we can load this checkpoint to predict with it:

net = NeuralNetClassifier(
    MyModule, lr=0.5, module__num_units=20,
)
net.initialize()
net.load_params(checkpoint=cp)

y_pred = net.predict(X)

In this case, it is important to initialize the neural net before running NeuralNet.load_params().

History

A NeuralNet object logs training progress internally using a History object, stored in the history attribute. Among other use cases, history is used to print the training progress after each epoch:

net.fit(X, y)

# prints
  epoch    train_loss    valid_acc    valid_loss     dur
-------  ------------  -----------  ------------  ------
      1        0.7111       0.5100        0.6894  0.1345
      2        0.6928       0.5500        0.6803  0.0608
      3        0.6833       0.5650        0.6741  0.0620
      4        0.6763       0.5850        0.6674  0.0594

All this information (and more) is stored in and can be accessed through net.history. It is thus best practice to make use of history for storing training-related data.

In general, History works like a list of dictionaries, where each item in the list corresponds to one epoch, and each key of the dictionary to one column. Thus, if you would like to access the 'train_loss' of the last epoch, you can call net.history[-1]['train_loss']. To make the history more accessible, though, it is possible to just pass the indices separated by a comma: net.history[-1, 'train_loss'].

Moreover, History stores the results from each individual batch under the batches key during each epoch. So to get the train loss of the 3rd batch of the 7th epoch, use net.history[7, 'batches', 3, 'train_loss'].

Here are some examples showing how to index history:

# history of a fitted neural net
history = net.history
# get current epoch, a dict
history[-1]
# get train losses from all epochs, a list of floats
history[:, 'train_loss']
# get train and valid losses from all epochs, a list of tuples
history[:, ('train_loss', 'valid_loss')]
# get current batches, a list of dicts
history[-1, 'batches']
# get latest batch, a dict
history[-1, 'batches', -1]
# get train losses from current batch, a list of floats
history[-1, 'batches', :, 'train_loss']
# get train and valid losses from current batch, a list of tuples
history[-1, 'batches', :, ('train_loss', 'valid_loss')]

As History essentially is a list of dictionaries, you can also write to it as if it were a list of dictionaries. Here too, skorch provides some convenience functions to make life easier. First there is new_epoch(), which will add a new epoch dictionary to the end of the list. Also, there is new_batch() for adding new batches to the current epoch.

To add a new item to the current epoch, use history.record('foo', 123). This will set the value 123 for the key foo of the current epoch. To write a value to the current batch, use history.record_batch('bar', 456). Below are some more examples:

# history of a fitted neural net
history = net.history
# add new epoch row
history.new_epoch()
# add an entry to current epoch
history.record('my-score', 123)
# add a batch row to the current epoch
history.new_batch()
# add an entry to the current batch
history.record_batch('my-batch-score', 456)
# overwrite entry of current batch
history.record_batch('my-batch-score', 789)

Toy

This module contains helper functions and classes that allow you to prototype quickly or that can be used for writing tests.

MLPModule

MLPModule is a simple PyTorch Module that implements a multi-layer perceptron. It allows you to set the number of input, hidden, and output units, as well as the nonlinearity and the use of dropout. You can use this module directly in conjunction with NeuralNet.

Additionally, the functions make_classifier(), make_binary_classifier(), and make_regressor() can be used to return an MLPModule with the defaults adjusted for use in multi-class classification, binary classification, and regression, respectively.
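
For example, a quick classification baseline could look like the sketch below; the parameter names input_units, hidden_units, and output_units are assumptions about the MLPModule signature, so check the API reference for the exact arguments:

from skorch import NeuralNetClassifier
from skorch.toy import make_classifier

# make_classifier returns an MLPModule preconfigured for classification
module = make_classifier(input_units=20, hidden_units=50, output_units=2)

net = NeuralNetClassifier(module, max_epochs=10, lr=0.1)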

Helper

This module provides helper functions and classes for the user. They make working with skorch easier but are not used by skorch itself.

SliceDict

A SliceDict is a wrapper for Python dictionaries that makes them behave a little bit like numpy.ndarrays. That way, you can slice your dictionary across values, len() will show the length of the arrays and not the number of keys, and you get a shape attribute. This is useful because if your data is in a dict, you would normally not be able to use sklearn GridSearchCV and similar things; with SliceDict, this works.
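
A short sketch, reusing the two-input module from the Dataset section above (the parameter grid and scorer are arbitrary):

import numpy as np
from sklearn.model_selection import GridSearchCV
from skorch.helper import SliceDict

X = SliceDict(
    key_a=np.random.random((1000, 10)).astype(np.float32),
    key_b=np.random.random((1000, 20)).astype(np.float32),
)
y = np.random.randint(0, 2, size=1000)

# net wraps the module with key_a and key_b arguments shown earlier
gs = GridSearchCV(net, {'lr': [0.01, 0.1]}, cv=3, scoring='accuracy')
gs.fit(X, y)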

REST Service

In this section we'll take the RNN sentiment classifier from the example Predicting sentiment on the IMDB dataset and use it to demonstrate how to easily expose your PyTorch module on the web using skorch and another library called Palladium.

With Palladium, you define the Palladium dataset, the model, and Palladium provides the framework to fit, test, and serve your model on the web. Palladium comes with its own documentation and a tutorial, which you may want to check out to learn more about what you can do with it.

The way to make the dataset and model known to Palladium is through its configuration file. Here’s the part of the configuration that defines the dataset and model:

{
    'dataset_loader_train': {
        '__factory__': 'model.DatasetLoader',
        'path': 'aclImdb/train/',
    },

    'dataset_loader_test': {
        '__factory__': 'model.DatasetLoader',
        'path': 'aclImdb/test/',
    },

    'model': {
        '__factory__': 'model.create_pipeline',
        'use_cuda': True,
    },

    'model_persister': {
        '__factory__': 'palladium.persistence.File',
        'path': 'rnn-model-{version}',
    },

    'scoring': 'accuracy',
}

You can save this configuration as palladium-config.py.

The dataset_loader_train and dataset_loader_test entries define where the data comes from. They refer to a Python class defined inside the model module. Let's create a file called model.py and put it in the same directory as the configuration file. We'll start off with defining the dataset loader:

import os
from urllib.request import urlretrieve
import tarfile

import numpy as np
from palladium.interfaces import DatasetLoader as IDatasetLoader
from sklearn.datasets import load_files

DATA_URL = 'http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz'
DATA_FN = DATA_URL.rsplit('/', 1)[1]


def download():
    if not os.path.exists('aclImdb'):
        # unzip data if it does not exist
        if not os.path.exists(DATA_FN):
            urlretrieve(DATA_URL, DATA_FN)
        with tarfile.open(DATA_FN, 'r:gz') as f:
            f.extractall()


class DatasetLoader(IDatasetLoader):
    def __init__(self, path='aclImdb/train/'):
        self.path = path

    def __call__(self):
        download()
        dataset = load_files(self.path, categories=['pos', 'neg'])
        X, y = dataset['data'], dataset['target']
        X = np.asarray([x.decode() for x in X])  # decode from bytes
        return X, y

The most interesting bit here is that our Palladium DatasetLoader defines a __call__ method that will return the data and the target (X and y). Easy. Note that in the configuration file, we refer to our DatasetLoader twice, once for the training set and once for the test set.

Our configuration also refers to a function create_pipeline which we’ll create next:

from dstoolbox.transformers import Padder2d
from dstoolbox.transformers import TextFeaturizer
from sklearn.pipeline import Pipeline
from skorch import NeuralNetClassifier
import torch


def create_pipeline(
    vocab_size=1000,
    max_len=50,
    use_cuda=False,
    **kwargs
):
    return Pipeline([
        ('to_idx', TextFeaturizer(max_features=vocab_size)),
        ('pad', Padder2d(max_len=max_len, pad_value=vocab_size, dtype=int)),
        ('net', NeuralNetClassifier(
            RNNClassifier,
            device=('cuda' if use_cuda else 'cpu'),
            max_epochs=5,
            lr=0.01,
            optimizer=torch.optim.RMSprop,
            module__vocab_size=vocab_size,
            **kwargs,
        ))
    ])

You’ve noticed that this function’s job is to create the model and return it. Here, we’re defining a pipeline that wraps skorch’s NeuralNetClassifier, which in turn is a wrapper around our PyTorch module, as it’s defined in the predicting sentiment tutorial. We’ll also add the RNNClassifier to model.py:

from torch import nn
F = nn.functional


class RNNClassifier(nn.Module):
    def __init__(
        self,
        embedding_dim=128,
        rec_layer_type='lstm',
        num_units=128,
        num_layers=2,
        dropout=0,
        vocab_size=1000,
    ):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.rec_layer_type = rec_layer_type.lower()
        self.num_units = num_units
        self.num_layers = num_layers
        self.dropout = dropout

        self.emb = nn.Embedding(
            vocab_size + 1, embedding_dim=self.embedding_dim)

        rec_layer = {'lstm': nn.LSTM, 'gru': nn.GRU}[self.rec_layer_type]
        # We have to make sure that the recurrent layer is batch_first,
        # since sklearn assumes the batch dimension to be the first
        self.rec = rec_layer(
            self.embedding_dim, self.num_units,
            num_layers=num_layers, batch_first=True,
            )

        self.output = nn.Linear(self.num_units, 2)

    def forward(self, X):
        embeddings = self.emb(X)
        # from the recurrent layer, only take the activities from the
        # last sequence step
        if self.rec_layer_type == 'gru':
            _, rec_out = self.rec(embeddings)
        else:
            _, (rec_out, _) = self.rec(embeddings)
        rec_out = rec_out[-1]  # take output of last RNN layer
        drop = F.dropout(rec_out, p=self.dropout)
        # Remember that the final non-linearity should be softmax, so
        # that our predict_proba method outputs actual probabilities!
        out = F.softmax(self.output(drop), dim=-1)
        return out

You can find the full contents of the model.py file in the skorch/examples/rnn_classifer folder of skorch’s source code.

Now with dataset and model in place, it’s time to try Palladium out. You can install Palladium and another dependency we use with pip install palladium dstoolbox.

From within the directory that contains model.py and palladium-config.py now run the following command:

PALLADIUM_CONFIG=palladium-config.py pld-fit --evaluate

You should see output similar to this:

INFO:palladium:Loading data...
INFO:palladium:Loading data done in 0.607 sec.
INFO:palladium:Fitting model...
  epoch    train_loss    valid_acc    valid_loss     dur
-------  ------------  -----------  ------------  ------
      1        0.7679       0.5008        0.7617  3.1300
      2        0.6385       0.7100        0.5840  3.1247
      3        0.5430       0.7438        0.5518  3.1317
      4        0.4736       0.7480        0.5424  3.1373
      5        0.4253       0.7448        0.5832  3.1433
INFO:palladium:Fitting model done in 29.060 sec.
DEBUG:palladium:Evaluating model on train set...
INFO:palladium:Train score: 0.83068
DEBUG:palladium:Evaluating model on train set done in 6.743 sec.
DEBUG:palladium:Evaluating model on test set...
INFO:palladium:Test score:  0.75428
DEBUG:palladium:Evaluating model on test set done in 6.476 sec.
INFO:palladium:Writing model...
INFO:palladium:Writing model done in 0.694 sec.
INFO:palladium:Wrote model with version 1.

Congratulations, you’ve trained your first model with Palladium! Note that in the output you see a train score (accuracy) of 0.83 and a test score of about 0.75. These refer to how well your model did on the training set (defined by dataset_loader_train in the configuration) and on the test set (dataset_loader_test).

You're now ready to serve the model on the web. Add this piece of configuration to the palladium-config.py configuration file (and make sure it lives within the outermost brackets):

{
    # ...

    'predict_service': {
        '__factory__': 'palladium.server.PredictService',
        'mapping': [
            ('text', 'str'),
        ],
        'predict_proba': True,
        'unwrap_sample': True,
    },

    # ...
}

With this piece of information inside the configuration, we’re ready to launch the web server using:

PALLADIUM_CONFIG=palladium-config.py pld-devserver

You can now try out the web service at this address: http://localhost:5000/predict?text=this+movie+was+brilliant

You should see a JSON string returned that looks something like this:

{
    "metadata": {"error_code": 0, "status": "OK"},
    "result": [0.326442807912827, 0.673557221889496],
}

The result entry has the probabilities. Our model assigns 67% probability to the sentence “this movie was brilliant” to be positive. By the way, the skorch tutorial itself has tips on how to improve this model.

The takeaway is that Palladium helps you reduce the boilerplate code that's needed to get your machine learning project started. Palladium has routines to fit, test, and serve models so you don't have to worry about that, and you can concentrate on the actual machine learning part. Configuration and code are separated with Palladium, which helps organize your experiments and work on ideas in parallel. Check out the Palladium documentation for more.

Parallelism

Skorch supports distributing work among a cluster of workers via dask.distributed. In this section we'll describe how to use Dask to efficiently distribute a grid search or a randomized search on hyperparameters across multiple GPUs and potentially multiple hosts.

Let's assume that you have two GPUs on which you want to run a hyperparameter search.

The key here is using the CUDA environment variable CUDA_VISIBLE_DEVICES to limit which devices are visible to our CUDA application. We’ll set up Dask workers that, using this environment variable, each see one GPU only. On the PyTorch side, we’ll have to make sure to set the device to cuda when we initialize the NeuralNet class.

Let's run through the steps. First, install Dask and dask.distributed:

pip install dask distributed

Next, assuming you have two GPUs on your machine, let’s start up a Dask scheduler and two Dask workers. Make sure the Dask workers are started up in the right environment, that is, with access to all packages required to do the work:

dask-scheduler
CUDA_VISIBLE_DEVICES=0 dask-worker 127.0.0.1:8786 --nthreads 1
CUDA_VISIBLE_DEVICES=1 dask-worker 127.0.0.1:8786 --nthreads 1

In your code, use joblib's parallel_backend to choose the Dask backend for grid searches and the like. Remember to also import the distributed.joblib module, as that will register the joblib backend. Let's see what this could look like:

import distributed.joblib  # imported for side effects
from sklearn.externals.joblib import parallel_backend

X, y = load_my_data()
model = get_that_model()

gs = GridSearchCV(
    model,
    param_grid={'net__lr': [0.01, 0.03]},
    scoring='accuracy',
    )
with parallel_backend('dask.distributed', scheduler_host='127.0.0.1:8786'):
    gs.fit(X, y)
print(gs.cv_results_)

You can also use Palladium to do the job. An example is included in the source in the examples/rnn_classifier folder. Change into that directory and run the following command, after having set up your Dask workers:

PALLADIUM_CONFIG=palladium-config.py,dask-config.py pld-grid-search

FAQ

How do I apply L2 regularization?

To apply L2 regularization (aka weight decay), PyTorch supplies the weight_decay parameter, which must be supplied to the optimizer. To pass this variable in skorch, use the double-underscore notation for the optimizer:

net = NeuralNet(
    ...,
    optimizer__weight_decay=0.01,
)

How can I continue training my model?

By default, when you call fit() more than once, the training starts from scratch instead of from where it left off. This is in line with sklearn's behavior but not always desired. If you would like to continue training, use partial_fit() instead of fit(). Alternatively, there is the warm_start argument, which is False by default. Set it to True instead and you should be fine.

How do I shuffle my train batches?

skorch uses DataLoader from PyTorch under the hood. This class takes a couple of arguments, for instance shuffle. We therefore need to pass the shuffle argument to DataLoader, which we achieve by using the double-underscore notation (as known from sklearn):

net = NeuralNet(
    ...,
    iterator_train__shuffle=True,
)

Note that we have an iterator_train for the training data and an iterator_valid for validation and test data. In general, you only want to shuffle the train data, which is what the code above does.

How do I use sklearn GridSearchCV when my data is in a dictionary?

skorch supports dicts as input but sklearn doesn't. To get around that, try to wrap your dictionary into a SliceDict. This is a data container that partly behaves like a dict and partly like an ndarray. For more details on how to do this, have a look at the corresponding data section in the notebook.

I want to use sample_weight, how can I do this?

Some scikit-learn models support passing a sample_weight argument to fit() calls as part of the fit_params. This allows you to give different samples different weights in the final loss calculation.

In general, skorch supports fit_params, but unfortunately just calling net.fit(X, y, sample_weight=sample_weight) is not enough, because the fit_params are not split into train and valid, and are not batched, resulting in a mismatch with the training batches.

Fortunately, skorch supports passing dictionaries as arguments, which are actually split into train and valid and then batched. Therefore, the best solution is to pass the sample_weight with X as a dictionary. Below, there is example code on how to achieve this:

X, y = get_data()
# put your X into a dict if not already a dict
X = {'data': X}
# add sample_weight to the X dict
X['sample_weight'] = sample_weight

class MyModule(nn.Module):
    ...
    def forward(self, data, sample_weight):
        # when X is a dict, its keys are passed as kwargs to forward, thus
        # our forward has to have the arguments 'data' and 'sample_weight';
        # usually, sample_weight can be ignored here
        ...

class MyNet(NeuralNet):
    def get_loss(self, y_pred, y_true, X, *args, **kwargs):
        # override get_loss to use the sample_weight from X
        loss_unreduced = super().get_loss(y_pred, y_true, X, *args, **kwargs)
        sample_weight = X['sample_weight']
        loss_reduced = (sample_weight * loss_unreduced).mean()
        return loss_reduced

# make sure to pass reduce=False to your criterion, since we need the loss
# for each sample so that it can be weighted
net = MyNet(MyModule, ..., criterion__reduce=False)
net.fit(X, y)

I already split my data into training and validation sets, how can I use them?

If you have predefined training and validation datasets that are subclasses of PyTorch Dataset, you can use predefined_split() to wrap your validation dataset and pass it to NeuralNet’s train_split parameter:

from skorch.helper import predefined_split

net = NeuralNet(
    ...,
    train_split=predefined_split(valid_ds)
)
net.fit(train_ds)

If you split your data by using train_test_split(), you can create your own skorch Dataset, and then pass it to predefined_split():

from sklearn.model_selection import train_test_split
from skorch.helper import predefined_split
from skorch.dataset import Dataset

X_train, X_test, y_train, y_test = train_test_split(X, y)

valid_ds = Dataset(X_test, y_test)

net = NeuralNet(
    ...,
    train_split=predefined_split(valid_ds)
)

net.fit(X_train, y_train)

API Reference

If you are looking for information on a specific function, class or method, this part of the documentation is for you.

skorch

skorch.callbacks

This module serves to elevate callbacks in submodules to the skorch.callbacks namespace. Remember to define __all__ in each submodule.

class skorch.callbacks.Callback[source]

Base class for callbacks.

All custom callbacks should inherit from this class. The subclass may override any of the on_... methods. It is, however, not necessary to override all of them, since it’s okay if they don’t have any effect.

Classes that inherit from this also gain the get_params and set_params methods.

Methods

initialize() (Re-)Set the initial state of the callback.
on_batch_begin(net[, Xi, yi, training]) Called at the beginning of each batch.
on_batch_end(net[, Xi, yi, training]) Called at the end of each batch.
on_epoch_begin(net[, dataset_train, …]) Called at the beginning of each epoch.
on_epoch_end(net[, dataset_train, dataset_valid]) Called at the end of each epoch.
on_grad_computed(net, named_parameters[, …]) Called once per batch after gradients have been computed but before an update step was performed.
on_train_begin(net[, X, y]) Called at the beginning of training.
on_train_end(net[, X, y]) Called at the end of training.
get_params  
set_params  
initialize()[source]

(Re-)Set the initial state of the callback. Use this e.g. if the callback tracks some state that should be reset when the model is re-initialized.

This method should return self.

on_batch_begin(net, Xi=None, yi=None, training=None, **kwargs)[source]

Called at the beginning of each batch.

on_batch_end(net, Xi=None, yi=None, training=None, **kwargs)[source]

Called at the end of each batch.

on_epoch_begin(net, dataset_train=None, dataset_valid=None, **kwargs)[source]

Called at the beginning of each epoch.

on_epoch_end(net, dataset_train=None, dataset_valid=None, **kwargs)[source]

Called at the end of each epoch.

on_grad_computed(net, named_parameters, Xi=None, yi=None, training=None, **kwargs)[source]

Called once per batch after gradients have been computed but before an update step was performed.

on_train_begin(net, X=None, y=None, **kwargs)[source]

Called at the beginning of training.

on_train_end(net, X=None, y=None, **kwargs)[source]

Called at the end of training.
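
As an illustration, here is a minimal sketch of a custom callback (the name LossWatcher and the printed message are made up for this example):

>>> class LossWatcher(Callback):
...     def on_epoch_end(self, net, **kwargs):
...         # read the last train loss from the history and report it
...         print('train loss:', net.history[-1, 'train_loss'])
>>> net = NeuralNet(MyModule, callbacks=[LossWatcher()])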

class skorch.callbacks.EpochTimer(**kwargs)[source]

Measures the duration of each epoch and writes it to the history with the name dur.

Methods

initialize() (Re-)Set the initial state of the callback.
on_batch_begin(net[, Xi, yi, training]) Called at the beginning of each batch.
on_batch_end(net[, Xi, yi, training]) Called at the end of each batch.
on_epoch_begin(net, **kwargs) Called at the beginning of each epoch.
on_epoch_end(net, **kwargs) Called at the end of each epoch.
on_grad_computed(net, named_parameters[, …]) Called once per batch after gradients have been computed but before an update step was performed.
on_train_begin(net[, X, y]) Called at the beginning of training.
on_train_end(net[, X, y]) Called at the end of training.
get_params  
set_params  
on_epoch_begin(net, **kwargs)[source]

Called at the beginning of each epoch.

on_epoch_end(net, **kwargs)[source]

Called at the end of each epoch.

class skorch.callbacks.PrintLog(keys_ignored=None, sink=<built-in function print>, tablefmt='simple', floatfmt='.4f', stralign='right')[source]

Print useful information from the model’s history as a table.

By default, PrintLog prints everything from the history except for 'batches'.

To determine the best loss, PrintLog looks for keys that end on '_best' and associates them with the corresponding loss. E.g., 'train_loss_best' will be matched with 'train_loss'. The Scoring callback takes care of creating those entries, which is why PrintLog works best in conjunction with that callback.

PrintLog treats keys with the 'event_' prefix in a special way. They are assumed to contain information about occasionally occurring events. The False or None entries (indicating that an event did not occur) are not printed, resulting in empty cells in the table, and True entries are printed with a + symbol. PrintLog groups all event columns together and pushes them to the right, just before the 'dur' column.

Note: PrintLog will not result in good outputs if the number of columns varies between epochs, e.g. if the valid loss is only present on every other epoch.

Parameters:
keys_ignored : str or list of str (default=None)

Key or list of keys that should not be part of the printed table. Note that keys ending on ‘_best’ are also ignored.

sink : callable (default=print)

The target that the output string is sent to. By default, the output is printed to stdout, but the sink could also be a logger, etc.

tablefmt : str (default=’simple’)

The format of the table. See the documentation of the tabulate package for more detail. Can be ‘plain’, ‘grid’, ‘pipe’, ‘html’, ‘latex’, among others.

floatfmt : str (default=’.4f’)

The number formatting. See the documentation of the tabulate package for more details.

stralign : str (default=’right’)

The alignment of columns with strings. Can be ‘left’, ‘center’, ‘right’, or None (disable alignment). Default is ‘right’ (to be consistent with numerical columns).

Methods

format_row(row, key, color) For a given row from the table, format it (i.e.
initialize() (Re-)Set the initial state of the callback.
on_batch_begin(net[, Xi, yi, training]) Called at the beginning of each batch.
on_batch_end(net[, Xi, yi, training]) Called at the end of each batch.
on_epoch_begin(net[, dataset_train, …]) Called at the beginning of each epoch.
on_epoch_end(net, **kwargs) Called at the end of each epoch.
on_grad_computed(net, named_parameters[, …]) Called once per batch after gradients have been computed but before an update step was performed.
on_train_begin(net[, X, y]) Called at the beginning of training.
on_train_end(net[, X, y]) Called at the end of training.
get_params  
set_params  
table  
format_row(row, key, color)[source]

For a given row from the table, format it (i.e. floating points and color if applicable).

initialize()[source]

(Re-)Set the initial state of the callback. Use this e.g. if the callback tracks some state that should be reset when the model is re-initialized.

This method should return self.

on_epoch_end(net, **kwargs)[source]

Called at the end of each epoch.

class skorch.callbacks.ProgressBar(batches_per_epoch='auto', detect_notebook=True, postfix_keys=None)[source]

Display a progress bar for each epoch.

The progress bar includes elapsed and estimated remaining time for the current epoch, the number of batches processed, and other user-defined metrics. The progress bar is erased once the epoch is completed.

ProgressBar needs to know the total number of batches per epoch in order to display a meaningful progress bar. By default, this number is determined automatically using the dataset length and the batch size. If this heuristic does not work for some reason, you may either specify the number of batches explicitly or let the ProgressBar count the actual number of batches in the previous epoch.

For Jupyter notebooks, a non-ASCII progress bar can be printed instead. To use this feature, you need to have ipywidgets installed.

Parameters:
batches_per_epoch : int, str (default=’auto’)

Either a concrete number or a string specifying the method used to determine the number of batches per epoch automatically. 'auto' means that the number is computed from the length of the dataset and the batch size. 'count' means that the number is determined by counting the batches in the previous epoch. Note that this will leave you without a progress bar at the first epoch.

detect_notebook : bool (default=True)

If enabled, the progress bar determines if its current environment is a jupyter notebook and switches to a non-ASCII progress bar.

postfix_keys : list of str (default=[‘train_loss’, ‘valid_loss’])

You can use this list to specify additional info displayed in the progress bar, such as metrics and losses. A prerequisite is that these values already reside in the history at the batch level, i.e. they must be accessible via

>>> net.history[-1, 'batches', -1, key]
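
Example (a minimal sketch of adding the progress bar to a net):

>>> from skorch.callbacks import ProgressBar
>>> net = NeuralNet(MyModule, callbacks=[ProgressBar()])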

Methods

initialize() (Re-)Set the initial state of the callback.
on_batch_begin(net[, Xi, yi, training]) Called at the beginning of each batch.
on_batch_end(net, **kwargs) Called at the end of each batch.
on_epoch_begin(net[, dataset_train, …]) Called at the beginning of each epoch.
on_epoch_end(net, **kwargs) Called at the end of each epoch.
on_grad_computed(net, named_parameters[, …]) Called once per batch after gradients have been computed but before an update step was performed.
on_train_begin(net[, X, y]) Called at the beginning of training.
on_train_end(net[, X, y]) Called at the end of training.
get_params  
in_ipynb  
set_params  
on_batch_end(net, **kwargs)[source]

Called at the end of each batch.

on_epoch_begin(net, dataset_train=None, dataset_valid=None, **kwargs)[source]

Called at the beginning of each epoch.

on_epoch_end(net, **kwargs)[source]

Called at the end of each epoch.

class skorch.callbacks.LRScheduler(policy='WarmRestartLR', monitor='train_loss', **kwargs)[source]

Callback that sets the learning rate of each parameter group according to some policy.

Parameters:
policy : str or _LRScheduler class (default=’WarmRestartLR’)

Learning rate policy name or scheduler to be used.

monitor : str or callable (default=’train_loss’)

Value of the history to monitor or function/callable. In the latter case, the callable receives the net instance as argument and is expected to return the score (float) used to determine the learning rate adjustment.

kwargs

Additional arguments passed to the lr scheduler.
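
Example (a sketch; it assumes that the name 'StepLR' resolves to PyTorch’s StepLR scheduler and that the extra keyword arguments are forwarded to it):

>>> from skorch.callbacks import LRScheduler
>>> scheduler = LRScheduler(policy='StepLR', step_size=10, gamma=0.5)
>>> net = NeuralNet(MyModule, callbacks=[scheduler])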

Attributes:
kwargs

Methods

initialize() (Re-)Set the initial state of the callback.
on_batch_begin(net, **kwargs) Called at the beginning of each batch.
on_batch_end(net, **kwargs) Called at the end of each batch.
on_epoch_begin(net, **kwargs) Called at the beginning of each epoch.
on_epoch_end(net[, dataset_train, dataset_valid]) Called at the end of each epoch.
on_grad_computed(net, named_parameters[, …]) Called once per batch after gradients have been computed but before an update step was performed.
on_train_begin(net, **kwargs) Called at the beginning of training.
on_train_end(net[, X, y]) Called at the end of training.
simulate(steps, initial_lr) Simulates the learning rate scheduler.
get_params  
set_params  
initialize()[source]

(Re-)Set the initial state of the callback. Use this e.g. if the callback tracks some state that should be reset when the model is re-initialized.

This method should return self.

on_batch_begin(net, **kwargs)[source]

Called at the beginning of each batch.

on_batch_end(net, **kwargs)[source]

Called at the end of each batch.

on_epoch_begin(net, **kwargs)[source]

Called at the beginning of each epoch.

on_train_begin(net, **kwargs)[source]

Called at the beginning of training.

simulate(steps, initial_lr)[source]

Simulates the learning rate scheduler.

Parameters:
steps: int

Number of steps to simulate

initial_lr: float

Initial learning rate

Returns:
lrs: numpy ndarray

Simulated learning rates

class skorch.callbacks.WarmRestartLR(optimizer, min_lr=1e-06, max_lr=0.05, base_period=10, period_mult=2, last_epoch=-1)[source]

Stochastic Gradient Descent with Warm Restarts (SGDR) scheduler.

This scheduler sets the learning rate of each parameter group according to the stochastic gradient descent with warm restarts (SGDR) policy. This policy simulates periodic warm restarts of SGD, where in each restart the learning rate is initialized to some value and is scheduled to decrease.

Parameters:
optimizer : torch.optim.Optimizer instance.

Optimizer algorithm.

min_lr : float or list of float (default=1e-6)

Minimum allowed learning rate during each period for all param groups (float) or each group (list).

max_lr : float or list of float (default=0.05)

Maximum allowed learning rate during each period for all param groups (float) or each group (list).

base_period : int (default=10)

Initial restart period to be multiplied at each restart.

period_mult : int (default=2)

Multiplicative factor to increase the period between restarts.

last_epoch : int (default=-1)

The index of the last valid epoch.

References

[1] Ilya Loshchilov and Frank Hutter, 2017, “Stochastic Gradient Descent with Warm Restarts”, ICLR. https://arxiv.org/pdf/1608.03983.pdf

Methods

load_state_dict(state_dict) Loads the schedulers state.
state_dict() Returns the state of the scheduler as a dict.
get_lr  
step  
class skorch.callbacks.CyclicLR(optimizer, base_lr=0.001, max_lr=0.006, step_size_up=2000, step_size_down=None, mode='triangular', gamma=1.0, scale_fn=None, scale_mode='cycle', last_batch_idx=-1, step_size=None)[source]

Sets the learning rate of each parameter group according to cyclical learning rate policy (CLR). The policy cycles the learning rate between two boundaries with a constant frequency, as detailed in the paper. The distance between the two boundaries can be scaled on a per-iteration or per-cycle basis.

Cyclical learning rate policy changes the learning rate after every batch. batch_step should be called after a batch has been used for training. To resume training, save last_batch_idx and use it to instantiate CyclicLR.

This class has three built-in policies, as put forth in the paper:

“triangular”:
A basic triangular cycle w/ no amplitude scaling.
“triangular2”:
A basic triangular cycle that scales initial amplitude by half each cycle.
“exp_range”:
A cycle that scales initial amplitude by gamma**(cycle iterations) at each cycle iteration.

This implementation was adapted from the github repo: bckenstler/CLR

Parameters:
optimizer : torch.optim.Optimizer instance.

Optimizer algorithm.

base_lr : float or list of float (default=1e-3)

Initial learning rate which is the lower boundary in the cycle for all param groups (float) or each group (list).

max_lr : float or list of float (default=6e-3)

Upper boundaries in the cycle for each parameter group (float) or each group (list). Functionally, it defines the cycle amplitude (max_lr - base_lr). The lr at any cycle is the sum of base_lr and some scaling of the amplitude; therefore max_lr may not actually be reached depending on scaling function.

step_size_up : int (default=2000)

Number of training iterations in the increasing half of a cycle.

step_size_down : int (default=None)

Number of training iterations in the decreasing half of a cycle. If step_size_down is None, it is set to step_size_up.

mode : str (default=’triangular’)

One of {triangular, triangular2, exp_range}. Values correspond to policies detailed above. If scale_fn is not None, this argument is ignored.

gamma : float (default=1.0)

Constant in ‘exp_range’ scaling function: gamma**(cycle iterations)

scale_fn : function (default=None)

Custom scaling policy defined by a single-argument lambda function, where 0 <= scale_fn(x) <= 1 for all x >= 0. If scale_fn is specified, the mode parameter is ignored.

scale_mode : str (default=’cycle’)

One of {‘cycle’, ‘iterations’}. Defines whether scale_fn is evaluated on cycle number or cycle iterations (training iterations since start of cycle).

last_batch_idx : int (default=-1)

The index of the last batch.

References

[1] Leslie N. Smith, 2017, “Cyclical Learning Rates for Training Neural Networks”, ICLR. https://arxiv.org/abs/1506.01186

Examples

>>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
>>> scheduler = CyclicLR(optimizer)
>>> data_loader = torch.utils.data.DataLoader(...)
>>> for epoch in range(10):
...     for batch in data_loader:
...         scheduler.batch_step()
...         train_batch(...)

Methods

batch_step([batch_idx]) Updates the learning rate for the batch index: batch_idx.
get_lr() Calculates the learning rate at batch index: self.last_batch_idx.
step([epoch]) Not used by CyclicLR, use batch_step instead.
batch_step(batch_idx=None)[source]

Updates the learning rate for the batch index: batch_idx. If batch_idx is None, CyclicLR will use an internal batch index to keep track of the index.

get_lr()[source]

Calculates the learning rate at batch index: self.last_batch_idx.

step(epoch=None)[source]

Not used by CyclicLR, use batch_step instead.

class skorch.callbacks.GradientNormClipping(gradient_clip_value=None, gradient_clip_norm_type=2)[source]

Clips gradient norm of a module’s parameters.

The norm is computed over all gradients together, as if they were concatenated into a single vector. Gradients are modified in-place.

See torch.nn.utils.clip_grad_norm_() for more information.

Parameters:
gradient_clip_value : float (default=None)

If not None, clip the norm of all model parameter gradients to this value. The type of the norm is determined by the gradient_clip_norm_type parameter and defaults to L2.

gradient_clip_norm_type : float (default=2)

Norm to use when gradient clipping is active. The default is to use L2-norm. Can be ‘inf’ for infinity norm.
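
Example (a minimal sketch of adding the callback to a net):

>>> from skorch.callbacks import GradientNormClipping
>>> net = NeuralNet(
...     MyModule,
...     callbacks=[GradientNormClipping(gradient_clip_value=1.0)],
... )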

Methods

initialize() (Re-)Set the initial state of the callback.
on_batch_begin(net[, Xi, yi, training]) Called at the beginning of each batch.
on_batch_end(net[, Xi, yi, training]) Called at the end of each batch.
on_epoch_begin(net[, dataset_train, …]) Called at the beginning of each epoch.
on_epoch_end(net[, dataset_train, dataset_valid]) Called at the end of each epoch.
on_grad_computed(_, named_parameters, **kwargs) Called once per batch after gradients have been computed but before an update step was performed.
on_train_begin(net[, X, y]) Called at the beginning of training.
on_train_end(net[, X, y]) Called at the end of training.
get_params  
set_params  
on_grad_computed(_, named_parameters, **kwargs)[source]

Called once per batch after gradients have been computed but before an update step was performed.

class skorch.callbacks.BatchScoring(scoring, lower_is_better=True, on_train=False, name=None, target_extractor=<function to_numpy>, use_caching=True)[source]

Callback that performs generic scoring on batches.

This callback determines the score after each batch and stores it in the net’s history in the column given by name. At the end of the epoch, the average of the scores is determined and also stored in the history. Furthermore, it is determined whether this average score is the best score yet, and that information is also stored in the history.

In contrast to EpochScoring, this callback determines the score for each batch and then averages the score at the end of the epoch. This can be disadvantageous for some scores if the batch size is small – e.g. area under the ROC will return incorrect scores in this case. Therefore, it is recommended to use EpochScoring unless you really need the scores for each batch.

If y is None, the scoring function with signature (model, X, y) must be able to handle X as a Tensor and y=None.

Parameters:
scoring : None, str, or callable

If None, use the score method of the model. If str, it should be a valid sklearn metric (e.g. “f1_score”, “accuracy_score”). If a callable, it should have the signature (model, X, y), and it should return a scalar. This works analogously to the scoring parameter in sklearn’s GridSearchCV et al.

lower_is_better : bool (default=True)

Whether lower (e.g. log loss) or higher (e.g. accuracy) scores are better

on_train : bool (default=False)

Whether this should be called during train or validation.

name : str or None (default=None)

If not an explicit string, tries to infer the name from the scoring argument.

target_extractor : callable (default=to_numpy)

This is called on y before it is passed to scoring.

use_caching : bool (default=True)

Re-use the model’s prediction for computing the loss to calculate the score. Turning this off will result in an additional inference step for each batch.
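
Example (a sketch; my_accuracy is a made-up scoring callable following the (model, X, y) signature described above):

>>> from sklearn.metrics import accuracy_score
>>> def my_accuracy(net, X, y):
...     y_pred = net.predict(X)
...     return accuracy_score(y, y_pred)
>>> net = NeuralNetClassifier(
...     MyModule,
...     callbacks=[BatchScoring(
...         my_accuracy, lower_is_better=False, name='train_acc', on_train=True)],
... )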

Methods

initialize() (Re-)Set the initial state of the callback.
on_batch_begin(net[, Xi, yi, training]) Called at the beginning of each batch.
on_batch_end(net, X, y, training, **kwargs) Called at the end of each batch.
on_epoch_begin(net[, dataset_train, …]) Called at the beginning of each epoch.
on_epoch_end(net, **kwargs) Called at the end of each epoch.
on_grad_computed(net, named_parameters[, …]) Called once per batch after gradients have been computed but before an update step was performed.
on_train_begin(net, X, y, **kwargs) Called at the beginning of training.
on_train_end(net[, X, y]) Called at the end of training.
get_avg_score  
get_params  
set_params  
on_batch_end(net, X, y, training, **kwargs)[source]

Called at the end of each batch.

on_epoch_end(net, **kwargs)[source]

Called at the end of each epoch.

class skorch.callbacks.EpochScoring(scoring, lower_is_better=True, on_train=False, name=None, target_extractor=<function to_numpy>, use_caching=True)[source]

Callback that performs generic scoring on predictions.

At the end of each epoch, this callback makes a prediction on train or validation data, determines the score for that prediction and whether it is the best yet, and stores the result in the net’s history.

In case you already computed a score value for each batch, you can omit the score computation step by returning the value from the history. For example:

>>> def my_score(net, X=None, y=None):
...     losses = net.history[-1, 'batches', :, 'my_score']
...     batch_sizes = net.history[-1, 'batches', :, 'valid_batch_size']
...     return np.average(losses, weights=batch_sizes)
>>> net = MyNet(callbacks=[
...     ('my_score', EpochScoring(my_score, name='my_score'))])

If you fit with a custom dataset, this callback should work as expected as long as use_caching=True which enables the collection of y values from the dataset. If you decide to disable the caching of predictions and y values, you need to write your own scoring function that is able to deal with the dataset and returns a scalar, for example:

>>> def ds_accuracy(net, ds, y=None):
...     # assume ds yields (X, y), e.g. torchvision.datasets.MNIST
...     y_true = [y for _, y in ds]
...     y_pred = net.predict(ds)
...     return sklearn.metrics.accuracy_score(y_true, y_pred)
>>> net = MyNet(callbacks=[
...     EpochScoring(ds_accuracy, use_caching=False)])
>>> ds = torchvision.datasets.MNIST(root=mnist_path)
>>> net.fit(ds)
Parameters:
scoring : None, str, or callable (default=None)

If None, use the score method of the model. If str, it should be a valid sklearn scorer (e.g. “f1”, “accuracy”). If a callable, it should have the signature (model, X, y), and it should return a scalar. This works analogously to the scoring parameter in sklearn’s GridSearchCV et al.

lower_is_better : bool (default=True)

Whether lower scores should be considered better or worse.

on_train : bool (default=False)

Whether this should be computed on training or validation data.

name : str or None (default=None)

If not an explicit string, tries to infer the name from the scoring argument.

target_extractor : callable (default=to_numpy)

This is called on y before it is passed to scoring.

use_caching : bool (default=True)

Collect labels and predictions (y_true and y_pred) over the course of one epoch and use the cached values for computing the score. The cached values are shared between all EpochScoring instances. Disabling this will result in an additional inference step for each epoch and an inability to use arbitrary datasets as input (since we don’t know how to extract y_true from an arbitrary dataset).

Methods

get_test_data(dataset_train, dataset_valid) Return data needed to perform scoring.
initialize() (Re-)Set the initial state of the callback.
on_batch_begin(net[, Xi, yi, training]) Called at the beginning of each batch.
on_batch_end(net, y, y_pred, training, **kwargs) Called at the end of each batch.
on_epoch_begin(net, dataset_train, …) Called at the beginning of each epoch.
on_epoch_end(net, dataset_train, …) Called at the end of each epoch.
on_grad_computed(net, named_parameters[, …]) Called once per batch after gradients have been computed but before an update step was performed.
on_train_begin(net, X, y, **kwargs) Called at the beginning of training.
on_train_end(*args, **kwargs) Called at the end of training.
get_params  
set_params  
get_test_data(dataset_train, dataset_valid)[source]

Return data needed to perform scoring.

This is a convenience method that handles picking of train/valid, different types of input data, use of cache, etc. for you.

Parameters:
dataset_train

Incoming training data or dataset.

dataset_valid

Incoming validation data or dataset.

Returns:
X_test

Input data used for making the prediction.

y_test

Target ground truth. If caching was enabled, return cached y_test.

y_pred : list

The predicted targets. If caching was disabled, the list is empty. If caching was enabled, the list contains the batches of the predictions. It may thus be necessary to concatenate the output before working with it: y_pred = np.concatenate(y_pred)

initialize()[source]

(Re-)Set the initial state of the callback. Use this e.g. if the callback tracks some state that should be reset when the model is re-initialized.

This method should return self.

on_batch_end(net, y, y_pred, training, **kwargs)[source]

Called at the end of each batch.

on_epoch_begin(net, dataset_train, dataset_valid, **kwargs)[source]

Called at the beginning of each epoch.

on_epoch_end(net, dataset_train, dataset_valid, **kwargs)[source]

Called at the end of each epoch.

on_train_end(*args, **kwargs)[source]

Called at the end of training.

class skorch.callbacks.Checkpoint(target=None, monitor='valid_loss_best', f_params='params.pt', f_optimizer='optimizer.pt', f_history='history.json', f_pickle=None, fn_prefix='', dirname='', event_name='event_cp', sink=<function noop>)[source]

Save the model during training if the given metric improved.

This callback works by default in conjunction with the validation scoring callback since it creates a valid_loss_best value in the history which the callback uses to determine if this epoch is save-worthy.

You can also specify your own metric to monitor or supply a callback that dynamically evaluates whether the model should be saved in this epoch.

Some or all of the following can be saved:

  • model parameters (see f_params parameter);
  • optimizer state (see f_optimizer parameter);
  • training history (see f_history parameter);
  • entire model object (see f_pickle parameter).

You can implement your own save protocol by subclassing Checkpoint and overriding save_model().

This callback writes a bool flag to the history column event_cp indicating whether a checkpoint was created or not.

Example:

>>> net = MyNet(callbacks=[Checkpoint()])
>>> net.fit(X, y)

Example using a custom monitor where models are saved only in epochs where the validation and the train losses are best:

>>> monitor = lambda net: all(net.history[-1, (
...     'train_loss_best', 'valid_loss_best')])
>>> net = MyNet(callbacks=[Checkpoint(monitor=monitor)])
>>> net.fit(X, y)
Parameters:
target : deprecated
monitor : str, function, None

Value of the history to monitor or callback that determines whether this epoch should lead to a checkpoint. The callback takes the network instance as parameter.

In case monitor is set to None, the callback will save the network at every epoch.

Note: If you supply a lambda expression as monitor, you cannot pickle the wrapper anymore as lambdas cannot be pickled. You can mitigate this problem by using importable functions instead.

f_params : file-like object, str, None (default=’params.pt’)

File path to the file or file-like object where the model parameters should be saved. Pass None to disable saving model parameters.

If the value is a string you can also use format specifiers to, for example, indicate the current epoch. Accessible format values are net, last_epoch and last_batch. Example to include last epoch number in file name:

>>> cb = Checkpoint(f_params="params_{last_epoch[epoch]}.pt")
f_optimizer : file-like object, str, None (default=’optimizer.pt’)

File path to the file or file-like object where the optimizer state should be saved. Pass None to disable saving the optimizer state.

Supports the same format specifiers as f_params.

f_history : file-like object, str, None (default=’history.json’)

File path to the file or file-like object where the model training history should be saved. Pass None to disable saving history.

f_pickle : file-like object, str, None (default=None)

File path to the file or file-like object where the entire model object should be pickled. Pass None to disable pickling.

Supports the same format specifiers as f_params.

fn_prefix: str (default=’‘)

Prefix for filenames. If f_params, f_optimizer, f_history, or f_pickle are strings, they will be prefixed by fn_prefix.

dirname: str (default=’‘)

Directory where files are stored.

event_name: str, (default=’event_cp’)

Name of event to be placed in history when checkpoint is triggered. Pass None to disable placing events in history.

sink : callable (default=noop)

The target that the information about created checkpoints is sent to. This can be a logger or print function (to send to stdout). By default the output is discarded.

Attributes:
f_history_

Methods

get_formatted_files(net) Returns a dictionary of formatted filenames
initialize() (Re-)Set the initial state of the callback.
on_batch_begin(net[, Xi, yi, training]) Called at the beginning of each batch.
on_batch_end(net[, Xi, yi, training]) Called at the end of each batch.
on_epoch_begin(net[, dataset_train, …]) Called at the beginning of each epoch.
on_epoch_end(net, **kwargs) Called at the end of each epoch.
on_grad_computed(net, named_parameters[, …]) Called once per batch after gradients have been computed but before an update step was performed.
on_train_begin(net[, X, y]) Called at the beginning of training.
on_train_end(net[, X, y]) Called at the end of training.
save_model(net) Save the model.
get_params  
set_params  
get_formatted_files(net)[source]

Returns a dictionary of formatted filenames

initialize()[source]

(Re-)Set the initial state of the callback. Use this e.g. if the callback tracks some state that should be reset when the model is re-initialized.

This method should return self.

on_epoch_end(net, **kwargs)[source]

Called at the end of each epoch.

save_model(net)[source]

Save the model.

This function saves some or all of the following:

  • model parameters;
  • optimizer state;
  • training history;
  • entire model object.
class skorch.callbacks.EarlyStopping(monitor='valid_loss', patience=5, threshold=0.0001, threshold_mode='rel', lower_is_better=True, sink=<built-in function print>)[source]

Callback for stopping training when scores don’t improve.

Stop training early if a specified monitor metric did not improve in patience number of epochs by at least threshold.

Parameters:
monitor : str (default=’valid_loss’)

Value of the history to monitor to decide whether to stop training or not. The value is expected to be a float and is commonly provided by scoring callbacks such as skorch.callbacks.EpochScoring.

lower_is_better : bool (default=True)

Whether lower scores should be considered better or worse.

patience : int (default=5)

Number of epochs to wait for improvement of the monitor value until the training process is stopped.

threshold : float (default=1e-4)

Ignore score improvements smaller than threshold.

threshold_mode : str (default=’rel’)

One of rel, abs. Decides whether the threshold value is interpreted in absolute terms or as a fraction of the best score so far (relative).

sink : callable (default=print)

The target that the information about early stopping is sent to. By default, the output is printed to stdout, but the sink could also be a logger or noop().
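
Example (a minimal sketch; monitor and patience are shown with their default values):

>>> from skorch.callbacks import EarlyStopping
>>> net = NeuralNet(
...     MyModule,
...     callbacks=[EarlyStopping(monitor='valid_loss', patience=5)],
... )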

Methods

initialize() (Re-)Set the initial state of the callback.
on_batch_begin(net[, Xi, yi, training]) Called at the beginning of each batch.
on_batch_end(net[, Xi, yi, training]) Called at the end of each batch.
on_epoch_begin(net[, dataset_train, …]) Called at the beginning of each epoch.
on_epoch_end(net, **kwargs) Called at the end of each epoch.
on_grad_computed(net, named_parameters[, …]) Called once per batch after gradients have been computed but before an update step was performed.
on_train_begin(net, **kwargs) Called at the beginning of training.
on_train_end(net[, X, y]) Called at the end of training.
get_params  
set_params  
on_epoch_end(net, **kwargs)[source]

Called at the end of each epoch.

on_train_begin(net, **kwargs)[source]

Called at the beginning of training.

class skorch.callbacks.Freezer(*args, **kwargs)[source]

Freeze matching parameters at the start of the first epoch. Using the at parameter, you may specify a different point in time (either an epoch number or a callable) at which the parameters are frozen.

See ParamMapper for details.
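
Example (a sketch; the pattern 'embedding*.weight' is only illustrative):

>>> from skorch.callbacks import Freezer
>>> net = NeuralNet(
...     MyModule,
...     callbacks=[Freezer('embedding*.weight')],
... )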

Methods

initialize() (Re-)Set the initial state of the callback.
on_batch_begin(net[, Xi, yi, training]) Called at the beginning of each batch.
on_batch_end(net[, Xi, yi, training]) Called at the end of each batch.
on_epoch_begin(net, **kwargs) Called at the beginning of each epoch.
on_epoch_end(net[, dataset_train, dataset_valid]) Called at the end of each epoch.
on_grad_computed(net, named_parameters[, …]) Called once per batch after gradients have been computed but before an update step was performed.
on_train_begin(net[, X, y]) Called at the beginning of training.
on_train_end(net[, X, y]) Called at the end of training.
filter_parameters  
get_params  
named_parameters  
set_params  
class skorch.callbacks.Unfreezer(*args, **kwargs)[source]

Inverse operation of Freezer.
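
Example (a sketch; it assumes the matching parameters were frozen earlier and should be unfrozen from epoch 5 on):

>>> from skorch.callbacks import Freezer, Unfreezer
>>> net = NeuralNet(
...     MyModule,
...     callbacks=[
...         Freezer('embedding*.weight'),
...         Unfreezer('embedding*.weight', at=5),
...     ],
... )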

Methods

initialize() (Re-)Set the initial state of the callback.
on_batch_begin(net[, Xi, yi, training]) Called at the beginning of each batch.
on_batch_end(net[, Xi, yi, training]) Called at the end of each batch.
on_epoch_begin(net, **kwargs) Called at the beginning of each epoch.
on_epoch_end(net[, dataset_train, dataset_valid]) Called at the end of each epoch.
on_grad_computed(net, named_parameters[, …]) Called once per batch after gradients have been computed but before an update step was performed.
on_train_begin(net[, X, y]) Called at the beginning of training.
on_train_end(net[, X, y]) Called at the end of training.
filter_parameters  
get_params  
named_parameters  
set_params  
class skorch.callbacks.Initializer(*args, **kwargs)[source]

Apply any function on matching parameters in the first epoch.

Examples

Use Initializer to initialize all dense layer weights with values sampled from a uniform distribution at the beginning of the first epoch:

>>> init_fn = partial(torch.nn.init.uniform_, a=-1e-3, b=1e-3)
>>> cb = Initializer('dense*.weight', fn=init_fn)
>>> net = Net(myModule, callbacks=[cb])

Methods

initialize() (Re-)Set the initial state of the callback.
on_batch_begin(net[, Xi, yi, training]) Called at the beginning of each batch.
on_batch_end(net[, Xi, yi, training]) Called at the end of each batch.
on_epoch_begin(net, **kwargs) Called at the beginning of each epoch.
on_epoch_end(net[, dataset_train, dataset_valid]) Called at the end of each epoch.
on_grad_computed(net, named_parameters[, …]) Called once per batch after gradients have been computed but before an update step was performed.
on_train_begin(net[, X, y]) Called at the beginning of training.
on_train_end(net[, X, y]) Called at the end of training.
filter_parameters  
get_params  
named_parameters  
set_params  
class skorch.callbacks.ParamMapper(patterns, fn=<function noop>, at=1, schedule=None)[source]

Map arbitrary functions over module parameters filtered by pattern matching.

In the simplest case the function is only applied once at the beginning of a given epoch (at on_epoch_begin), but more complex execution schemes (e.g. periodic application) are possible using at and schedule.

Parameters:
patterns : str or callable or list

The pattern(s) to match parameter names against. Patterns are UNIX globbing patterns as understood by fnmatch(). Patterns can also be callables which will get called with the parameter name and are regarded as a match when the callable returns a truthy value.

This parameter also supports lists of str or callables so that one ParamMapper can match a group of parameters.

Example: 'linear*.weight' or ['linear0.*', 'linear1.bias'] or lambda name: name.startswith('linear').

fn : function

The function to apply to each parameter separately.

at : int or callable

If at is an integer, it represents the epoch at which the function fn is applied to the parameters. If at is a callable, it receives the net as parameter, and fn is applied to the parameters once at returns True.

schedule : callable or None

If specified this callable supersedes the static at/fn combination by dynamically returning the function that is applied on the matched parameters. This way you can, for example, create a schedule that periodically freezes and unfreezes layers.

The callable’s signature is schedule(net: NeuralNet) -> callable.

Notes

When starting the training process after saving and loading a model, ParamMapper might re-initialize parts of your model when the history is not saved along with the model. To avoid this, in case you use ParamMapper (or subclasses, e.g. Initializer) and want to save your model make sure to either (a) use pickle, (b) save and load the history or (c) remove the parameter mapper callbacks before continuing training.

Examples

Initialize a layer on first epoch before the first training step:

>>> init = partial(torch.nn.init.uniform_, a=0, b=1)
>>> cb = ParamMapper('linear*.weight', at=1, fn=init)
>>> net = Net(myModule, callbacks=[cb])

Reset layer initialization if train loss reaches a certain value (e.g. re-initialize on overfit):

>>> at = lambda net: net.history[-1, 'train_loss'] < 0.1
>>> init = partial(torch.nn.init.uniform_, a=0, b=1)
>>> cb = ParamMapper('linear0.weight', at=at, fn=init)
>>> net = Net(myModule, callbacks=[cb])

Periodically freeze and unfreeze all embedding layers:

>>> def my_sched(net):
...    if len(net.history) % 2 == 0:
...        return skorch.utils.freeze_parameter
...    else:
...        return skorch.utils.unfreeze_parameter
>>> cb = ParamMapper('embedding*.weight', schedule=my_sched)
>>> net = Net(myModule, callbacks=[cb])

Methods

initialize() (Re-)Set the initial state of the callback.
on_batch_begin(net[, Xi, yi, training]) Called at the beginning of each batch.
on_batch_end(net[, Xi, yi, training]) Called at the end of each batch.
on_epoch_begin(net, **kwargs) Called at the beginning of each epoch.
on_epoch_end(net[, dataset_train, dataset_valid]) Called at the end of each epoch.
on_grad_computed(net, named_parameters[, …]) Called once per batch after gradients have been computed but before an update step was performed.
on_train_begin(net[, X, y]) Called at the beginning of training.
on_train_end(net[, X, y]) Called at the end of training.
filter_parameters  
get_params  
named_parameters  
set_params  
initialize()[source]

(Re-)Set the initial state of the callback. Use this e.g. if the callback tracks some state that should be reset when the model is re-initialized.

This method should return self.

on_epoch_begin(net, **kwargs)[source]

Called at the beginning of each epoch.

class skorch.callbacks.LoadInitState(checkpoint)[source]

Loads the model, optimizer, and history from a checkpoint into a NeuralNet when training begins.

Parameters:
checkpoint: :class:`.Checkpoint`

Checkpoint to get filenames from.

Examples

Consider running the following example multiple times:

>>> cp = Checkpoint(monitor='valid_loss_best')
>>> load_state = LoadInitState(cp)
>>> net = NeuralNet(..., callbacks=[cp, load_state])
>>> net.fit(X, y)

On the first run, the Checkpoint saves the model, optimizer, and history when the validation loss is minimized. During the first run, there are no files on disk, thus LoadInitState will not load anything. When running the example a second time, LoadInitState will load the best model from the first run and continue training from there.

Methods

initialize() (Re-)Set the initial state of the callback.
on_batch_begin(net[, Xi, yi, training]) Called at the beginning of each batch.
on_batch_end(net[, Xi, yi, training]) Called at the end of each batch.
on_epoch_begin(net[, dataset_train, …]) Called at the beginning of each epoch.
on_epoch_end(net[, dataset_train, dataset_valid]) Called at the end of each epoch.
on_grad_computed(net, named_parameters[, …]) Called once per batch after gradients have been computed but before an update step was performed.
on_train_begin(net[, X, y]) Called at the beginning of training.
on_train_end(net[, X, y]) Called at the end of training.
get_params  
set_params  
initialize()[source]

(Re-)Set the initial state of the callback. Use this e.g. if the callback tracks some state that should be reset when the model is re-initialized.

This method should return self.

on_train_begin(net, X=None, y=None, **kwargs)[source]

Called at the beginning of training.

class skorch.callbacks.TrainEndCheckpoint(f_params='params.pt', f_optimizer='optimizer.pt', f_history='history.json', f_pickle=None, fn_prefix='final_', dirname='', sink=<function noop>)[source]

Saves the model parameters, optimizer state, and history at the end of training. The default fn_prefix is ‘final_’.

Parameters:
f_params : file-like object, str, None (default=’params.pt’)

File path to the file or file-like object where the model parameters should be saved. Pass None to disable saving model parameters.

If the value is a string you can also use format specifiers to, for example, indicate the current epoch. Accessible format values are net, last_epoch and last_batch. Example to include last epoch number in file name:

>>> cb = TrainEndCheckpoint(f_params="params_{last_epoch[epoch]}.pt")
f_optimizer : file-like object, str, None (default=’optimizer.pt’)

File path to the file or file-like object where the optimizer state should be saved. Pass None to disable saving the optimizer state.

Supports the same format specifiers as f_params.

f_history : file-like object, str, None (default=’history.json’)

File path to the file or file-like object where the model training history should be saved. Pass None to disable saving history.

f_pickle : file-like object, str, None (default=None)

File path to the file or file-like object where the entire model object should be pickled. Pass None to disable pickling.

Supports the same format specifiers as f_params.

fn_prefix: str (default=’final_’)

Prefix for filenames. If f_params, f_optimizer, f_history, or f_pickle are strings, they will be prefixed by fn_prefix.

dirname: str (default=’‘)

Directory where files are stored.

sink : callable (default=noop)

The target that the information about created checkpoints is sent to. This can be a logger or print function (to send to stdout). By default the output is discarded.

Examples

Consider running the following example multiple times:

>>> final_cp = TrainEndCheckpoint(dirname='exp1')
>>> load_state = LoadInitState(final_cp)
>>> net = NeuralNet(..., callbacks=[final_cp, load_state])
>>> net.fit(X, y)

After the first run, model parameters, optimizer state, and history are saved into a directory named exp1. On the next run, LoadInitState will load the state from the first run and continue training.

Attributes:
f_history_

Methods

get_formatted_files(net) Returns a dictionary of formatted filenames
initialize() (Re-)Set the initial state of the callback.
on_batch_begin(net[, Xi, yi, training]) Called at the beginning of each batch.
on_batch_end(net[, Xi, yi, training]) Called at the end of each batch.
on_epoch_begin(net[, dataset_train, …]) Called at the beginning of each epoch.
on_epoch_end(net, **kwargs) Called at the end of each epoch.
on_grad_computed(net, named_parameters[, …]) Called once per batch after gradients have been computed but before an update step was performed.
on_train_begin(net[, X, y]) Called at the beginning of training.
on_train_end(net, **kwargs) Called at the end of training.
save_model(net) Save the model.
get_params  
set_params  
on_epoch_end(net, **kwargs)[source]

Called at the end of each epoch.

on_train_end(net, **kwargs)[source]

Called at the end of training.

skorch.classifier

NeuralNet subclasses for classification tasks.

class skorch.classifier.NeuralNetBinaryClassifier(module, *args, criterion=<class 'torch.nn.modules.loss.BCEWithLogitsLoss'>, train_split=<skorch.dataset.CVSplit object>, threshold=0.5, **kwargs)[source]

NeuralNet for binary classification tasks

Use this specifically if you have a binary classification task, with input data X and target y. y must be 1d.

In addition to the parameters listed below, there are parameters with specific prefixes that are handled separately. To illustrate this, here is an example:

>>> net = NeuralNet(
...    ...,
...    optimizer=torch.optim.SGD,
...    optimizer__momentum=0.95,
...)

This way, when optimizer is initialized, NeuralNet will take care of setting the momentum parameter to 0.95.

(Note that the double underscore notation in optimizer__momentum means that the parameter momentum should be set on the object optimizer. This is the same semantic as used by sklearn.)

Furthermore, this allows you to change those parameters later:

net.set_params(optimizer__momentum=0.99)

This can be useful when you want to change certain parameters using a callback, when using the net in an sklearn grid search, etc.

By default an EpochTimer, BatchScoring (for both training and validation datasets), and PrintLog callbacks are installed for the user’s convenience.
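
Example (a sketch; MyBinaryModule is a made-up module that returns raw logits of shape (batch_size,), and y is assumed to be a 1d binary target):

>>> net = NeuralNetBinaryClassifier(
...     MyBinaryModule,
...     max_epochs=20,
...     lr=0.1,
... )
>>> net.fit(X, y)
>>> y_proba = net.predict_proba(X)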

Parameters:
module : torch module (class or instance)

A PyTorch Module. In general, the uninstantiated class should be passed, although instantiated modules will also work.

criterion : torch criterion (class, default=torch.nn.BCEWithLogitsLoss)

Binary cross entropy loss with logits. Note that the module should return the logit of probabilities with shape (batch_size, ).

threshold : float (default=0.5)

Probabilities above this threshold are classified as 1. threshold is used by predict and predict_proba for classification.

optimizer : torch optim (class, default=torch.optim.SGD)

The uninitialized optimizer (update rule) used to optimize the module

lr : float (default=0.01)

Learning rate passed to the optimizer. You may use lr instead of using optimizer__lr, which would result in the same outcome.

max_epochs : int (default=10)

The number of epochs to train for each fit call. Note that you may keyboard-interrupt training at any time.

batch_size : int (default=128)

Mini-batch size. Use this instead of setting iterator_train__batch_size and iterator_test__batch_size, which would result in the same outcome. If batch_size is -1, a single batch with all the data will be used during training and validation.

iterator_train : torch DataLoader

The default PyTorch DataLoader used for training data.

iterator_valid : torch DataLoader

The default PyTorch DataLoader used for validation and test data, i.e. during inference.

dataset : torch Dataset (default=skorch.dataset.Dataset)

The dataset is necessary for the incoming data to work with pytorch’s DataLoader. It has to implement the __len__ and __getitem__ methods. The provided dataset should be capable of dealing with a lot of data types out of the box, so only change this if your data is not supported. You should generally pass the uninitialized Dataset class and define additional arguments to X and y by prefixing them with dataset__. It is also possible to pass an initialized Dataset, in which case no additional arguments may be passed.

train_split : None or callable (default=skorch.dataset.CVSplit(5))

If None, there is no train/validation split. Else, train_split should be a function or callable that is called with X and y data and should return the tuple dataset_train, dataset_valid. The validation data may be None.

callbacks : None or list of Callback instances (default=None)

More callbacks, in addition to those returned by get_default_callbacks. Each callback should inherit from Callback. If not None, a list of callbacks is expected where the callback names are inferred from the class name. Name conflicts are resolved by appending a count suffix starting with 1, e.g. EpochScoring_1. Alternatively, a tuple (name, callback) can be passed, where name should be unique. Callbacks may or may not be instantiated. The callback name can be used to set parameters on specific callbacks (e.g., for the callback with name 'print_log', use net.set_params(callbacks__print_log__keys_ignored=['epoch', 'train_loss'])).

warm_start : bool (default=False)

Whether each fit call should lead to a re-initialization of the module (cold start) or whether the module should be trained further (warm start).

verbose : int (default=1)

Control the verbosity level.

device : str, torch.device (default=’cpu’)

The compute device to be used. If set to ‘cuda’, data in torch tensors will be pushed to cuda tensors before being sent to the module.

Attributes:
prefixes_ : list of str

Contains the prefixes to special parameters. E.g., since there is the 'module' prefix, it is possible to set parameters like so: NeuralNet(..., optimizer__momentum=0.95).

cuda_dependent_attributes_ : list of str

Contains a list of all attributes whose values depend on a CUDA device. If a NeuralNet trained with a CUDA-enabled device is unpickled on a machine without CUDA or with CUDA disabled, the listed attributes are mapped to CPU. Expand this list if you want to add other cuda-dependent attributes.

initialized_ : bool

Whether the NeuralNet was initialized.

module_ : torch module (instance)

The instantiated module.

criterion_ : torch criterion (instance)

The instantiated criterion.

callbacks_ : list of tuples

The complete (i.e. default and other), initialized callbacks, in a tuple with unique names.

Methods

check_data(X, y)
evaluation_step(Xi[, training]) Perform a forward step to produce the output used for prediction and scoring.
fit(X, y, **fit_params) See NeuralNet.fit.
fit_loop(X[, y, epochs]) The proper fit loop.
forward(X[, training, device]) Gather and concatenate the output from forward call with input data.
forward_iter(X[, training, device]) Yield outputs of module forward calls on each batch of data.
get_dataset(X[, y]) Get a dataset that contains the input data and is passed to the iterator.
get_iterator(dataset[, training]) Get an iterator that allows to loop over the batches of the given data.
get_loss(y_pred, y_true[, X, training]) Return the loss for this batch.
get_split_datasets(X[, y]) Get internal train and validation datasets.
get_train_step_accumulator() Return the train step accumulator.
infer(x, **fit_params) Perform a single inference step on a batch of data.
initialize() Initializes all components of the NeuralNet and returns self.
initialize_callbacks() Initializes all callbacks and save the result in the callbacks_ attribute.
initialize_criterion() Initializes the criterion.
initialize_history() Initializes the history.
initialize_module() Initializes the module.
initialize_optimizer() Initialize the model optimizer.
load_history(f) Load the history of a NeuralNet from a json file.
load_params([f, f_params, f_optimizer, …]) Loads the module’s parameters, history, and optimizer, not the whole object.
notify(method_name, **cb_kwargs) Call the callback method specified in method_name with parameters specified in cb_kwargs.
on_batch_begin(net[, Xi, yi, training])
on_epoch_begin(net[, dataset_train, …])
on_epoch_end(net[, dataset_train, dataset_valid])
on_train_begin(net[, X, y])
on_train_end(net[, X, y])
partial_fit(X[, y, classes]) Fit the module.
predict(X) Where applicable, return class labels for samples in X.
predict_proba(X) Where applicable, return probability estimates for samples.
save_history(f) Saves the history of NeuralNet as a json file.
save_params([f, f_params, f_optimizer, …]) Saves the module’s parameters, history, and optimizer, not the whole object.
set_params(**kwargs) Set the parameters of this class.
train_step(Xi, yi, **fit_params) Prepares a loss function callable and passes it to the optimizer, hence performing one optimization step.
train_step_single(Xi, yi, **fit_params) Compute y_pred, loss value, and update net’s gradients.
validation_step(Xi, yi, **fit_params) Perform a forward step using batched data and return the resulting loss.
get_default_callbacks  
get_params  
on_batch_end  
on_grad_computed  
fit(X, y, **fit_params)[source]

See NeuralNet.fit.

In contrast to NeuralNet.fit, y is non-optional to avoid mistakenly forgetting about y. However, y can be set to None in case it is derived dynamically from X.

predict(X)[source]

Where applicable, return class labels for samples in X.

If the module’s forward method returns multiple outputs as a tuple, it is assumed that the first output contains the relevant information and the other values are ignored. If all values are relevant, consider using forward() instead.

Parameters:
X : input data, compatible with skorch.dataset.Dataset

By default, you should be able to pass:

  • numpy arrays
  • torch tensors
  • pandas DataFrame or Series
  • a dictionary of the former three
  • a list/tuple of the former three
  • a Dataset

If this doesn’t work with your data, you have to pass a Dataset that can deal with the data.

Returns:
y_pred : numpy ndarray
predict_proba(X)[source]

Where applicable, return probability estimates for samples.

If the module’s forward method returns multiple outputs as a tuple, it is assumed that the first output contains the relevant information and the other values are ignored. If all values are relevant, consider using forward() instead.

Parameters:
X : input data, compatible with skorch.dataset.Dataset

By default, you should be able to pass:

  • numpy arrays
  • torch tensors
  • pandas DataFrame or Series
  • a dictionary of the former three
  • a list/tuple of the former three
  • a Dataset

If this doesn’t work with your data, you have to pass a Dataset that can deal with the data.

Returns:
y_proba : numpy ndarray
class skorch.classifier.NeuralNetClassifier(module, *args, criterion=<class 'torch.nn.modules.loss.NLLLoss'>, train_split=<skorch.dataset.CVSplit object>, **kwargs)[source]

NeuralNet for classification tasks

Use this specifically if you have a standard classification task, with input data X and target y.

In addition to the parameters listed below, there are parameters with specific prefixes that are handled separately. To illustrate this, here is an example:

>>> net = NeuralNet(
...    ...,
...    optimizer=torch.optim.SGD,
...    optimizer__momentum=0.95,
...)

This way, when optimizer is initialized, NeuralNet will take care of setting the momentum parameter to 0.95.

(Note that the double underscore notation in optimizer__momentum means that the parameter momentum should be set on the object optimizer. This is the same semantic as used by sklearn.)

Furthermore, this allows you to change those parameters later:

net.set_params(optimizer__momentum=0.99)

This can be useful when you want to change certain parameters using a callback, when using the net in an sklearn grid search, etc.

By default an EpochTimer, BatchScoring (for both training and validation datasets), and PrintLog callbacks are installed for the user’s convenience.

Parameters:
module : torch module (class or instance)

A PyTorch Module. In general, the uninstantiated class should be passed, although instantiated modules will also work.

criterion : torch criterion (class, default=torch.nn.NLLLoss)

Negative log likelihood loss. Note that the module should return probabilities; the log is applied during get_loss.

optimizer : torch optim (class, default=torch.optim.SGD)

The uninitialized optimizer (update rule) used to optimize the module

lr : float (default=0.01)

Learning rate passed to the optimizer. You may use lr instead of using optimizer__lr, which would result in the same outcome.

max_epochs : int (default=10)

The number of epochs to train for each fit call. Note that you may keyboard-interrupt training at any time.

batch_size : int (default=128)

Mini-batch size. Use this instead of setting iterator_train__batch_size and iterator_valid__batch_size, which would result in the same outcome. If batch_size is -1, a single batch with all the data will be used during training and validation.

iterator_train : torch DataLoader

The default PyTorch DataLoader used for training data.

iterator_valid : torch DataLoader

The default PyTorch DataLoader used for validation and test data, i.e. during inference.

dataset : torch Dataset (default=skorch.dataset.Dataset)

The dataset is necessary for the incoming data to work with PyTorch's DataLoader. It has to implement the __len__ and __getitem__ methods. The provided dataset should be capable of dealing with a lot of data types out of the box, so only change this if your data is not supported. You should generally pass the uninitialized Dataset class and define additional arguments to X and y by prefixing them with dataset__. It is also possible to pass an initialized Dataset, in which case no additional arguments may be passed.

train_split : None or callable (default=skorch.dataset.CVSplit(5))

If None, there is no train/validation split. Else, train_split should be a function or callable that is called with X and y data and should return the tuple dataset_train, dataset_valid. The validation data may be None.

callbacks : None or list of Callback instances (default=None)

More callbacks, in addition to those returned by get_default_callbacks. Each callback should inherit from Callback. If not None, a list of callbacks is expected, where the callback names are inferred from the class name. Name conflicts are resolved by appending a count suffix starting with 1, e.g. EpochScoring_1. Alternatively, a tuple (name, callback) can be passed, where name should be unique. Callbacks may or may not be instantiated. The callback name can be used to set parameters on specific callbacks (e.g., for the callback with name 'print_log', use net.set_params(callbacks__print_log__keys_ignored=['epoch', 'train_loss'])).

warm_start : bool (default=False)

Whether each fit call should lead to a re-initialization of the module (cold start) or whether the module should be trained further (warm start).

verbose : int (default=1)

Control the verbosity level.

device : str, torch.device (default=’cpu’)

The compute device to be used. If set to ‘cuda’, data in torch tensors will be pushed to cuda tensors before being sent to the module.
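As noted for the criterion parameter above, NeuralNetClassifier expects the module to output probabilities, since the log is applied in get_loss. Below is a minimal, hypothetical module sketch illustrating this; the layer sizes and names are illustrative only:

>>> from torch import nn
>>> import torch.nn.functional as F
>>> class MyClassifierModule(nn.Module):
...     def __init__(self, n_features=20, n_classes=2):
...         super().__init__()
...         self.dense = nn.Linear(n_features, n_classes)
...     def forward(self, X):
...         # return probabilities, not raw logits
...         return F.softmax(self.dense(X), dim=-1)
>>> net = NeuralNetClassifier(MyClassifierModule)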

Attributes:
prefixes_ : list of str

Contains the prefixes to special parameters. E.g., since there is the 'module' prefix, it is possible to set parameters like so: NeuralNet(..., optimizer__momentum=0.95).

cuda_dependent_attributes_ : list of str

Contains a list of all attributes whose values depend on a CUDA device. If a NeuralNet trained with a CUDA-enabled device is unpickled on a machine without CUDA or with CUDA disabled, the listed attributes are mapped to CPU. Expand this list if you want to add other cuda-dependent attributes.

initialized_ : bool

Whether the NeuralNet was initialized.

module_ : torch module (instance)

The instantiated module.

criterion_ : torch criterion (instance)

The instantiated criterion.

callbacks_ : list of tuples

The complete (i.e. default and other), initialized callbacks, in a tuple with unique names.

Methods

check_data(X, y)
evaluation_step(Xi[, training]) Perform a forward step to produce the output used for prediction and scoring.
fit(X, y, **fit_params) See NeuralNet.fit.
fit_loop(X[, y, epochs]) The proper fit loop.
forward(X[, training, device]) Gather and concatenate the output from forward call with input data.
forward_iter(X[, training, device]) Yield outputs of module forward calls on each batch of data.
get_dataset(X[, y]) Get a dataset that contains the input data and is passed to the iterator.
get_iterator(dataset[, training]) Get an iterator that allows to loop over the batches of the given data.
get_loss(y_pred, y_true, *args, **kwargs) Return the loss for this batch.
get_split_datasets(X[, y]) Get internal train and validation datasets.
get_train_step_accumulator() Return the train step accumulator.
infer(x, **fit_params) Perform a single inference step on a batch of data.
initialize() Initializes all components of the NeuralNet and returns self.
initialize_callbacks() Initializes all callbacks and save the result in the callbacks_ attribute.
initialize_criterion() Initializes the criterion.
initialize_history() Initializes the history.
initialize_module() Initializes the module.
initialize_optimizer() Initialize the model optimizer.
load_history(f) Load the history of a NeuralNet from a json file.
load_params([f, f_params, f_optimizer, …]) Loads the module’s parameters, history, and optimizer, not the whole object.
notify(method_name, **cb_kwargs) Call the callback method specified in method_name with parameters specified in cb_kwargs.
on_batch_begin(net[, Xi, yi, training])
on_epoch_begin(net[, dataset_train, …])
on_epoch_end(net[, dataset_train, dataset_valid])
on_train_begin(net[, X, y])
on_train_end(net[, X, y])
partial_fit(X[, y, classes]) Fit the module.
predict(X) Where applicable, return class labels for samples in X.
predict_proba(X) Where applicable, return probability estimates for samples.
save_history(f) Saves the history of NeuralNet as a json file.
save_params([f, f_params, f_optimizer, …]) Saves the module’s parameters, history, and optimizer, not the whole object.
set_params(**kwargs) Set the parameters of this class.
train_step(Xi, yi, **fit_params) Prepares a loss function callable and passes it to the optimizer, hence performing one optimization step.
train_step_single(Xi, yi, **fit_params) Compute y_pred, loss value, and update net’s gradients.
validation_step(Xi, yi, **fit_params) Perform a forward step using batched data and return the resulting loss.
get_default_callbacks  
get_params  
on_batch_end  
on_grad_computed  
fit(X, y, **fit_params)[source]

See NeuralNet.fit.

In contrast to NeuralNet.fit, y is non-optional to avoid mistakenly forgetting about y. However, y can be set to None in case it is derived dynamically from X.

get_loss(y_pred, y_true, *args, **kwargs)[source]

Return the loss for this batch.

Parameters:
y_pred : torch tensor

Predicted target values

y_true : torch tensor

True target values.

X : input data, compatible with skorch.dataset.Dataset

By default, you should be able to pass:

  • numpy arrays
  • torch tensors
  • pandas DataFrame or Series
  • a dictionary of the former three
  • a list/tuple of the former three
  • a Dataset

If this doesn’t work with your data, you have to pass a Dataset that can deal with the data.

training : bool (default=False)

Whether train mode should be used or not.

predict(X)[source]

Where applicable, return class labels for samples in X.

If the module’s forward method returns multiple outputs as a tuple, it is assumed that the first output contains the relevant information and the other values are ignored. If all values are relevant, consider using forward() instead.

Parameters:
X : input data, compatible with skorch.dataset.Dataset

By default, you should be able to pass:

  • numpy arrays
  • torch tensors
  • pandas DataFrame or Series
  • a dictionary of the former three
  • a list/tuple of the former three
  • a Dataset

If this doesn’t work with your data, you have to pass a Dataset that can deal with the data.

Returns:
y_pred : numpy ndarray
predict_proba(X)[source]

Where applicable, return probability estimates for samples.

If the module’s forward method returns multiple outputs as a tuple, it is assumed that the first output contains the relevant information and the other values are ignored. If all values are relevant, consider using forward() instead.

Parameters:
X : input data, compatible with skorch.dataset.Dataset

By default, you should be able to pass:

  • numpy arrays
  • torch tensors
  • pandas DataFrame or Series
  • a dictionary of the former three
  • a list/tuple of the former three
  • a Dataset

If this doesn’t work with your data, you have to pass a Dataset that can deal with the data.

Returns:
y_proba : numpy ndarray

skorch.dataset

Contains custom skorch Dataset and CVSplit.

class skorch.dataset.CVSplit(cv=5, stratified=False, random_state=None)[source]

Class that performs the internal train/valid split on a dataset.

The cv argument here works similarly to the regular sklearn cv parameter in, e.g., GridSearchCV. However, instead of cycling through all splits, only one fixed split (the first one) is used. To get a full cycle through the splits, don’t use NeuralNet’s internal validation but instead the corresponding sklearn functions (e.g. cross_val_score).

We additionally support a float, similar to sklearn’s train_test_split.
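For illustration, here is a minimal, hypothetical sketch (MyModule is a placeholder for your own module) that reserves 20% of the training data for a stratified validation split:

>>> from skorch import NeuralNetClassifier
>>> from skorch.dataset import CVSplit
>>> net = NeuralNetClassifier(
...     MyModule,
...     train_split=CVSplit(0.2, stratified=True),
... )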

Parameters:
cv : int, float, cross-validation generator or an iterable, optional

(Refer to sklearn’s User Guide on cross-validation for the various cross-validation strategies that can be used here.)

Determines the cross-validation splitting strategy. Possible inputs for cv are:

  • None, to use the default 3-fold cross validation,
  • integer, to specify the number of folds in a (Stratified)KFold,
  • float, to represent the proportion of the dataset to include in the validation split.
  • An object to be used as a cross-validation generator.
  • An iterable yielding train, validation splits.
stratified : bool (default=False)

Whether the split should be stratified. This only works if the target y is suitable for binary or multiclass classification.

random_state : int, RandomState instance, or None (default=None)

Control the random state in case that (Stratified)ShuffleSplit is used (which is when a float is passed to cv). For more information, look at the sklearn documentation of (Stratified)ShuffleSplit.

Methods

__call__(dataset[, y, groups]) Call self as a function.
check_cv(y) Resolve which cross validation strategy is used.
check_cv(y)[source]

Resolve which cross validation strategy is used.

class skorch.dataset.Dataset(X, y=None, device=None, length=None)[source]

General dataset wrapper that can be used in conjunction with PyTorch DataLoader.

The dataset will always yield a tuple of two values, the first from the data (X) and the second from the target (y). However, the target is allowed to be None. In that case, Dataset will currently return a dummy tensor, since DataLoader does not work with Nones.

Dataset currently works with the following data types:

  • numpy arrays
  • PyTorch Tensors
  • pandas NDFrame
  • a dictionary of the former three
  • a list/tuple of the former three
Parameters:
X : see above

Everything pertaining to the input data.

y : see above or None (default=None)

Everything pertaining to the target, if there is anything.

length : int or None (default=None)

If not None, determines the length (len) of the data. Should usually be left at None, in which case the length is determined by the data itself.

Methods

transform(X, y) Additional transformations on X and y.
transform(X, y)[source]

Additional transformations on X and y.

By default, they are cast to PyTorch Tensors. Override this if you want a different behavior.

Note: If you use this in conjunction with PyTorch DataLoader, the latter will call the dataset for each row separately, which means that the incoming X and y each are single rows.
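A minimal, hypothetical sketch of overriding transform; the scaling applied here is illustrative only:

>>> from skorch.dataset import Dataset
>>> class MyDataset(Dataset):
...     def transform(self, X, y):
...         # custom per-row transformation, then fall back to the default
...         # behavior (casting to PyTorch tensors)
...         return super().transform(X * 2, y)
>>> net = NeuralNet(..., dataset=MyDataset)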

skorch.dataset.uses_placeholder_y(ds)[source]

If ds is a skorch.dataset.Dataset or a skorch.dataset.Dataset nested inside a torch.utils.data.Subset and uses y as a placeholder, return True.

skorch.exceptions

Contains skorch-specific exceptions and warnings.

exception skorch.exceptions.DeviceWarning[source]

A problem with a device (e.g. CUDA) was detected.

exception skorch.exceptions.NotInitializedError[source]

Module is not initialized, please call the .initialize method or train the model by calling .fit(...).

exception skorch.exceptions.SkorchException[source]

Base skorch exception.

exception skorch.exceptions.SkorchWarning[source]

Base skorch warning.

skorch.helper

Helper functions and classes for users.

They should not be used in skorch directly.

class skorch.helper.SliceDict(**kwargs)[source]

Wrapper for Python dict that makes it sliceable across values.

Use this if your input data is a dictionary and you have problems with sklearn not being able to slice it. Wrap your dict with SliceDict and it should usually work.

Note: SliceDict cannot be indexed by integers; if you want one row, say row 3, use [3:4].

Examples

>>> X = {'key0': val0, 'key1': val1}
>>> search = GridSearchCV(net, params, ...)
>>> search.fit(X, y)  # raises error
>>> Xs = SliceDict(key0=val0, key1=val1)  # or Xs = SliceDict(**X)
>>> search.fit(Xs, y)  # works
Attributes:
shape

Methods

clear()
copy()
fromkeys(iterable[, value]) Returns a new dict with keys from iterable and values equal to value.
get(k[,d])
items()
keys()
pop(k[,d]) If key is not found, d is returned if given, otherwise KeyError is raised
popitem() Remove and return a (key, value) pair as a 2-tuple; raises KeyError if D is empty.
setdefault(k[,d])
update([E, ]**F) If E is present and has a .keys() method, then does: for k in E: D[k] = E[k] If E is present and lacks a .keys() method, then does: for k, v in E: D[k] = v In either case, this is followed by: for k in F: D[k] = F[k]
values()
update([E, ]**F) → None. Update D from dict/iterable E and F.[source]

If E is present and has a .keys() method, then does: for k in E: D[k] = E[k] If E is present and lacks a .keys() method, then does: for k, v in E: D[k] = v In either case, this is followed by: for k in F: D[k] = F[k]

skorch.helper.filter_requires_grad(pgroups)[source]

Returns parameter groups where parameters that don’t require a gradient are filtered out.

Parameters:
pgroups : dict

Parameter groups to be filtered

skorch.helper.filtered_optimizer(optimizer, filter_fn)[source]

Wraps an optimizer that filters out parameters where filter_fn over pgroups returns False. This function can be used, for example, to filter parameters that do not require a gradient:

>>> from skorch.helper import filtered_optimizer, filter_requires_grad
>>> optimizer = filtered_optimizer(torch.optim.SGD, filter_requires_grad)
>>> net = NeuralNetClassifier(module, optimizer=optimizer)
Parameters:
optimizer : torch optim (class)

The uninitialized optimizer that is wrapped

filter_fn : function

Use this function to filter parameter groups before passing it to optimizer.

skorch.helper.predefined_split(dataset)[source]

Uses dataset for validation in NeuralNet.

Parameters:
dataset: torch Dataset

Validation dataset

Examples

>>> valid_ds = skorch.dataset.Dataset(X, y)
>>> net = NeuralNet(..., train_split=predefined_split(valid_ds))

skorch.history

Contains history class and helper functions.

class skorch.history.History[source]

History contains the information about the training history of a NeuralNet, facilitating some of the more common tasks that occur during training.

When you want to log certain information during training (say, a particular score or the norm of the gradients), you should write them to the net’s history object.

It is basically a list of dicts, one per epoch, where each epoch dict in turn contains a list of dicts, one per batch. For convenience, it has enhanced slicing notation and some methods to write new items.

To access items from history, you may pass a tuple of up to four items:

  1. Slices along the epochs.
  2. Selects columns from history epochs, may be a single one or a tuple of column names.
  3. Slices along the batches.
  4. Selects columns from history batches, may be a single one or a tuple of column names.

You may use a combination of the four items.

If you select columns that are not present in all epochs/batches, only those epochs/batches are chosen that contain said columns. If this set is empty, a KeyError is raised.

Examples

>>> # ACCESSING ITEMS
>>> # history of a fitted neural net
>>> history = net.history
>>> # get current epoch, a dict
>>> history[-1]
>>> # get train losses from all epochs, a list of floats
>>> history[:, 'train_loss']
>>> # get train and valid losses from all epochs, a list of tuples
>>> history[:, ('train_loss', 'valid_loss')]
>>> # get current batches, a list of dicts
>>> history[-1, 'batches']
>>> # get latest batch, a dict
>>> history[-1, 'batches', -1]
>>> # get train losses from current batch, a list of floats
>>> history[-1, 'batches', :, 'train_loss']
>>> # get train and valid losses from current batch, a list of tuples
>>> history[-1, 'batches', :, ('train_loss', 'valid_loss')]
>>> # WRITING ITEMS
>>> # add new epoch row
>>> history.new_epoch()
>>> # add an entry to current epoch
>>> history.record('my-score', 123)
>>> # add a batch row to the current epoch
>>> history.new_batch()
>>> # add an entry to the current batch
>>> history.record_batch('my-batch-score', 456)
>>> # overwrite entry of current batch
>>> history.record_batch('my-batch-score', 789)

Methods

append(object)
clear()
copy()
count(value)
extend(iterable)
from_file(f) Load the history of a NeuralNet from a json file.
index(value, [start, [stop]]) Raises ValueError if the value is not present.
insert L.insert(index, object) – insert object before index
new_batch() Register a new batch row for the current epoch.
new_epoch() Register a new epoch row.
pop([index]) Raises IndexError if list is empty or index is out of range.
record(attr, value) Add a new value to the given column for the current epoch.
record_batch(attr, value) Add a new value to the given column for the current batch.
remove(value) Raises ValueError if the value is not present.
reverse L.reverse() – reverse IN PLACE
sort([key, reverse])
to_file(f) Saves the history as a json file.
to_list() Return history object as a list.
classmethod from_file(f)[source]

Load the history of a NeuralNet from a json file.

Parameters:
f : file-like object or str
new_batch()[source]

Register a new batch row for the current epoch.

new_epoch()[source]

Register a new epoch row.

record(attr, value)[source]

Add a new value to the given column for the current epoch.

record_batch(attr, value)[source]

Add a new value to the given column for the current batch.

to_file(f)[source]

Saves the history as a json file. In order to use this feature, the history must only contain JSON encodable Python data structures. Numpy and PyTorch types should not be in the history.

Parameters:
f : file-like object or str
to_list()[source]

Return history object as a list.

skorch.net

Neural net classes.

class skorch.net.NeuralNet(module, criterion, optimizer=<class 'torch.optim.sgd.SGD'>, lr=0.01, max_epochs=10, batch_size=128, iterator_train=<class 'torch.utils.data.dataloader.DataLoader'>, iterator_valid=<class 'torch.utils.data.dataloader.DataLoader'>, dataset=<class 'skorch.dataset.Dataset'>, train_split=<skorch.dataset.CVSplit object>, callbacks=None, warm_start=False, verbose=1, device='cpu', **kwargs)[source]

NeuralNet base class.

The base class covers more generic cases. Depending on your use case, you might want to use NeuralNetClassifier or NeuralNetRegressor.

In addition to the parameters listed below, there are parameters with specific prefixes that are handled separately. To illustrate this, here is an example:

>>> net = NeuralNet(
...    ...,
...    optimizer=torch.optim.SGD,
...    optimizer__momentum=0.95,
...)

This way, when optimizer is initialized, NeuralNet will take care of setting the momentum parameter to 0.95.

(Note that the double underscore notation in optimizer__momentum means that the parameter momentum should be set on the object optimizer. This is the same semantic as used by sklearn.)

Furthermore, this allows you to change those parameters later:

net.set_params(optimizer__momentum=0.99)

This can be useful when you want to change certain parameters using a callback, when using the net in an sklearn grid search, etc.

By default an EpochTimer, BatchScoring (for both training and validation datasets), and PrintLog callbacks are installed for the user’s convenience.

Parameters:
module : torch module (class or instance)

A PyTorch Module. In general, the uninstantiated class should be passed, although instantiated modules will also work.

criterion : torch criterion (class)

The uninitialized criterion (loss) used to optimize the module.

optimizer : torch optim (class, default=torch.optim.SGD)

The uninitialized optimizer (update rule) used to optimize the module

lr : float (default=0.01)

Learning rate passed to the optimizer. You may use lr instead of using optimizer__lr, which would result in the same outcome.

max_epochs : int (default=10)

The number of epochs to train for each fit call. Note that you may keyboard-interrupt training at any time.

batch_size : int (default=128)

Mini-batch size. Use this instead of setting iterator_train__batch_size and iterator_valid__batch_size, which would result in the same outcome. If batch_size is -1, a single batch with all the data will be used during training and validation.

iterator_train : torch DataLoader

The default PyTorch DataLoader used for training data.

iterator_valid : torch DataLoader

The default PyTorch DataLoader used for validation and test data, i.e. during inference.

dataset : torch Dataset (default=skorch.dataset.Dataset)

The dataset is necessary for the incoming data to work with PyTorch's DataLoader. It has to implement the __len__ and __getitem__ methods. The provided dataset should be capable of dealing with a lot of data types out of the box, so only change this if your data is not supported. You should generally pass the uninitialized Dataset class and define additional arguments to X and y by prefixing them with dataset__. It is also possible to pass an initialized Dataset, in which case no additional arguments may be passed.

train_split : None or callable (default=skorch.dataset.CVSplit(5))

If None, there is no train/validation split. Else, train_split should be a function or callable that is called with X and y data and should return the tuple dataset_train, dataset_valid. The validation data may be None.

callbacks : None or list of Callback instances (default=None)

More callbacks, in addition to those returned by get_default_callbacks. Each callback should inherit from Callback. If not None, a list of callbacks is expected, where the callback names are inferred from the class name. Name conflicts are resolved by appending a count suffix starting with 1, e.g. EpochScoring_1. Alternatively, a tuple (name, callback) can be passed, where name should be unique. Callbacks may or may not be instantiated. The callback name can be used to set parameters on specific callbacks (e.g., for the callback with name 'print_log', use net.set_params(callbacks__print_log__keys_ignored=['epoch', 'train_loss'])). See the sketch after this parameter list for an example.

warm_start : bool (default=False)

Whether each fit call should lead to a re-initialization of the module (cold start) or whether the module should be trained further (warm start).

verbose : int (default=1)

Control the verbosity level.

device : str, torch.device (default=’cpu’)

The compute device to be used. If set to ‘cuda’, data in torch tensors will be pushed to cuda tensors before being sent to the module.
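As a sketch of the callbacks parameter described above (MyModule is a placeholder, and the scoring choices are illustrative only):

>>> import torch
>>> from skorch.callbacks import EpochScoring
>>> net = NeuralNet(
...     MyModule,
...     criterion=torch.nn.NLLLoss,
...     callbacks=[
...         EpochScoring('roc_auc', lower_is_better=False),
...         ('my_f1', EpochScoring('f1')),
...     ],
... )
>>> # parameters of a named callback can be changed later
>>> net.set_params(callbacks__my_f1__scoring='accuracy')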

Attributes:
prefixes_ : list of str

Contains the prefixes to special parameters. E.g., since there is the 'module' prefix, it is possible to set parameters like so: NeuralNet(..., optimizer__momentum=0.95).

cuda_dependent_attributes_ : list of str

Contains a list of all attributes whose values depend on a CUDA device. If a NeuralNet trained with a CUDA-enabled device is unpickled on a machine without CUDA or with CUDA disabled, the listed attributes are mapped to CPU. Expand this list if you want to add other cuda-dependent attributes.

initialized_ : bool

Whether the NeuralNet was initialized.

module_ : torch module (instance)

The instantiated module.

criterion_ : torch criterion (instance)

The instantiated criterion.

callbacks_ : list of tuples

The complete (i.e. default and other), initialized callbacks, in a tuple with unique names.

Methods

evaluation_step(Xi[, training]) Perform a forward step to produce the output used for prediction and scoring.
fit(X[, y]) Initialize and fit the module.
fit_loop(X[, y, epochs]) The proper fit loop.
forward(X[, training, device]) Gather and concatenate the output from forward call with input data.
forward_iter(X[, training, device]) Yield outputs of module forward calls on each batch of data.
get_dataset(X[, y]) Get a dataset that contains the input data and is passed to the iterator.
get_iterator(dataset[, training]) Get an iterator that allows to loop over the batches of the given data.
get_loss(y_pred, y_true[, X, training]) Return the loss for this batch.
get_split_datasets(X[, y]) Get internal train and validation datasets.
get_train_step_accumulator() Return the train step accumulator.
infer(x, **fit_params) Perform a single inference step on a batch of data.
initialize() Initializes all components of the NeuralNet and returns self.
initialize_callbacks() Initializes all callbacks and save the result in the callbacks_ attribute.
initialize_criterion() Initializes the criterion.
initialize_history() Initializes the history.
initialize_module() Initializes the module.
initialize_optimizer() Initialize the model optimizer.
load_history(f) Load the history of a NeuralNet from a json file.
load_params([f, f_params, f_optimizer, …]) Loads the module’s parameters, history, and optimizer, not the whole object.
notify(method_name, **cb_kwargs) Call the callback method specified in method_name with parameters specified in cb_kwargs.
on_batch_begin(net[, Xi, yi, training])
on_epoch_begin(net[, dataset_train, …])
on_epoch_end(net[, dataset_train, dataset_valid])
on_train_begin(net[, X, y])
on_train_end(net[, X, y])
partial_fit(X[, y, classes]) Fit the module.
predict(X) Where applicable, return class labels for samples in X.
predict_proba(X) Return the output of the module’s forward method as a numpy array.
save_history(f) Saves the history of NeuralNet as a json file.
save_params([f, f_params, f_optimizer, …]) Saves the module’s parameters, history, and optimizer, not the whole object.
set_params(**kwargs) Set the parameters of this class.
train_step(Xi, yi, **fit_params) Prepares a loss function callable and passes it to the optimizer, hence performing one optimization step.
train_step_single(Xi, yi, **fit_params) Compute y_pred, loss value, and update net’s gradients.
validation_step(Xi, yi, **fit_params) Perform a forward step using batched data and return the resulting loss.
check_data  
get_default_callbacks  
get_params  
on_batch_end  
on_grad_computed  
evaluation_step(Xi, training=False)[source]

Perform a forward step to produce the output used for prediction and scoring.

Therefore, the module is set to evaluation mode beforehand by default; this can be overridden by setting training=True, e.g. to re-enable features like dropout.

fit(X, y=None, **fit_params)[source]

Initialize and fit the module.

If the module was already initialized, calling fit will re-initialize it (unless warm_start is True).

Parameters:
X : input data, compatible with skorch.dataset.Dataset

By default, you should be able to pass:

  • numpy arrays
  • torch tensors
  • pandas DataFrame or Series
  • a dictionary of the former three
  • a list/tuple of the former three
  • a Dataset

If this doesn’t work with your data, you have to pass a Dataset that can deal with the data.

y : target data, compatible with skorch.dataset.Dataset

The same data types as for X are supported. If your X is a Dataset that contains the target, y may be set to None.

**fit_params : dict

Additional parameters passed to the forward method of the module and to the self.train_split call.

fit_loop(X, y=None, epochs=None, **fit_params)[source]

The proper fit loop.

Contains the logic of what actually happens during the fit loop.

Parameters:
X : input data, compatible with skorch.dataset.Dataset

By default, you should be able to pass:

  • numpy arrays
  • torch tensors
  • pandas DataFrame or Series
  • a dictionary of the former three
  • a list/tuple of the former three
  • a Dataset

If this doesn’t work with your data, you have to pass a Dataset that can deal with the data.

y : target data, compatible with skorch.dataset.Dataset

The same data types as for X are supported. If your X is a Dataset that contains the target, y may be set to None.

epochs : int or None (default=None)

If int, train for this number of epochs; if None, use self.max_epochs.

**fit_params : dict

Additional parameters passed to the forward method of the module and to the self.train_split call.

forward(X, training=False, device='cpu')[source]

Gather and concatenate the output from forward call with input data.

The outputs from self.module_.forward are gathered on the compute device specified by device and then concatenated using PyTorch cat(). If multiple outputs are returned by self.module_.forward, each one of them must be able to be concatenated this way.

Parameters:
X : input data, compatible with skorch.dataset.Dataset

By default, you should be able to pass:

  • numpy arrays
  • torch tensors
  • pandas DataFrame or Series
  • a dictionary of the former three
  • a list/tuple of the former three
  • a Dataset

If this doesn’t work with your data, you have to pass a Dataset that can deal with the data.

training : bool (default=False)

Whether to set the module to train mode or not.

device : string (default=’cpu’)

The device to store each inference result on. This defaults to CPU memory since there is generally more memory available there. For performance reasons this might be changed to a specific CUDA device, e.g. ‘cuda:0’.

Returns:
y_infer : torch tensor

The result from the forward step.

forward_iter(X, training=False, device='cpu')[source]

Yield outputs of module forward calls on each batch of data. The storage device of the yielded tensors is determined by the device parameter.

Parameters:
X : input data, compatible with skorch.dataset.Dataset

By default, you should be able to pass:

  • numpy arrays
  • torch tensors
  • pandas DataFrame or Series
  • a dictionary of the former three
  • a list/tuple of the former three
  • a Dataset

If this doesn’t work with your data, you have to pass a Dataset that can deal with the data.

training : bool (default=False)

Whether to set the module to train mode or not.

device : string (default=’cpu’)

The device to store each inference result on. This defaults to CPU memory since there is generally more memory available there. For performance reasons this might be changed to a specific CUDA device, e.g. ‘cuda:0’.

Yields:
yp : torch tensor

Result from a forward call on an individual batch.
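A short, hypothetical usage sketch; net is assumed to be a fitted NeuralNet:

>>> for yp in net.forward_iter(X, device='cpu'):
...     print(yp.shape)  # one tensor of outputs per batch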

get_dataset(X, y=None)[source]

Get a dataset that contains the input data and is passed to the iterator.

Override this if you want to initialize your dataset differently.

Parameters:
X : input data, compatible with skorch.dataset.Dataset

By default, you should be able to pass:

  • numpy arrays
  • torch tensors
  • pandas DataFrame or Series
  • a dictionary of the former three
  • a list/tuple of the former three
  • a Dataset

If this doesn’t work with your data, you have to pass a Dataset that can deal with the data.

y : target data, compatible with skorch.dataset.Dataset

The same data types as for X are supported. If your X is a Dataset that contains the target, y may be set to None.

Returns:
dataset

The initialized dataset.

get_iterator(dataset, training=False)[source]

Get an iterator that allows looping over the batches of the given data.

If self.iterator_train__batch_size and/or self.iterator_valid__batch_size are not set, use self.batch_size instead.

Parameters:
dataset : torch Dataset (default=skorch.dataset.Dataset)

Usually, self.dataset, initialized with the corresponding data, is passed to get_iterator.

training : bool (default=False)

Whether to use iterator_train or iterator_valid.

Returns:
iterator

An instantiated iterator that allows looping over the mini-batches.

get_loss(y_pred, y_true, X=None, training=False)[source]

Return the loss for this batch.

Parameters:
y_pred : torch tensor

Predicted target values

y_true : torch tensor

True target values.

X : input data, compatible with skorch.dataset.Dataset

By default, you should be able to pass:

  • numpy arrays
  • torch tensors
  • pandas DataFrame or Series
  • a dictionary of the former three
  • a list/tuple of the former three
  • a Dataset

If this doesn’t work with your data, you have to pass a Dataset that can deal with the data.

training : bool (default=False)

Whether train mode should be used or not.
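A common reason to override get_loss is to add a penalty term to the criterion's loss. A hypothetical sketch (the L1 penalty and its weight 1e-4 are illustrative only):

>>> class RegularizedNet(NeuralNet):
...     def get_loss(self, y_pred, y_true, X=None, training=False):
...         loss = super().get_loss(y_pred, y_true, X=X, training=training)
...         # add an L1 penalty on all module parameters
...         loss += 1e-4 * sum(p.abs().sum() for p in self.module_.parameters())
...         return loss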

get_split_datasets(X, y=None, **fit_params)[source]

Get internal train and validation datasets.

The validation dataset can be None if self.train_split is set to None; then internal validation will be skipped.

Override this if you want to change how the net splits incoming data into train and validation part.

Parameters:
X : input data, compatible with skorch.dataset.Dataset

By default, you should be able to pass:

  • numpy arrays
  • torch tensors
  • pandas DataFrame or Series
  • a dictionary of the former three
  • a list/tuple of the former three
  • a Dataset

If this doesn’t work with your data, you have to pass a Dataset that can deal with the data.

y : target data, compatible with skorch.dataset.Dataset

The same data types as for X are supported. If your X is a Dataset that contains the target, y may be set to None.

**fit_params : dict

Additional parameters passed to the self.train_split call.

Returns:
dataset_train

The initialized training dataset.

dataset_valid

The initialized validation dataset or None

get_train_step_accumulator()[source]

Return the train step accumulator.

By default, the accumulator stores and retrieves the first value from the optimizer call. Most optimizers make only one call, so the first value is also the only value.

In the case of some optimizers, e.g. LBFGS, train_step_calc_gradient is called multiple times, as the loss function is evaluated multiple times per optimizer call. If you don’t want the first value to be returned in that case, override this method to return your custom accumulator.

infer(x, **fit_params)[source]

Perform a single inference step on a batch of data.

Parameters:
x : input data

A batch of the input data.

**fit_params : dict

Additional parameters passed to the forward method of the module and to the self.train_split call.

initialize()[source]

Initializes all components of the NeuralNet and returns self.

initialize_callbacks()[source]

Initializes all callbacks and save the result in the callbacks_ attribute.

Both default_callbacks and callbacks are used (in that order). Callbacks may either be initialized or not, and if they don’t have a name, the name is inferred from the class name. The initialize method is called on all callbacks.

The final result will be a list of tuples, where each tuple consists of a name and an initialized callback. If names are not unique, a ValueError is raised.

initialize_criterion()[source]

Initializes the criterion.

initialize_history()[source]

Initializes the history.

initialize_module()[source]

Initializes the module.

Note that if the module has learned parameters, those will be reset.

initialize_optimizer()[source]

Initialize the model optimizer. If self.optimizer__lr is not set, use self.lr instead.

load_history(f)[source]

Load the history of a NeuralNet from a json file. See save_history for examples.

Parameters:
f : file-like object or str
load_params(f=None, f_params=None, f_optimizer=None, f_history=None, checkpoint=None)[source]

Loads the module’s parameters, history, and optimizer, not the whole object.

To save and load the whole object, use pickle.

f_params and f_optimizer use PyTorch’s save().

Parameters:
f_params : file-like object, str, None (default=None)

Path of module parameters. Pass None to not load.

f_optimizer : file-like object, str, None (default=None)

Path of optimizer. Pass None to not load.

f_history : file-like object, str, None (default=None)

Path to history. Pass None to not load.

checkpoint : Checkpoint, None (default=None)

Checkpoint to load params from. If both a checkpoint and an f_* path are passed in, the f_* path will be used. Pass None to not load.

f : deprecated

Examples

>>> before = NeuralNetClassifier(mymodule)
>>> before.save_params(f_params='model.pkl',
>>>                    f_optimizer='optimizer.pkl',
>>>                    f_history='history.json')
>>> after = NeuralNetClassifier(mymodule).initialize()
>>> after.load_params(f_params='model.pkl',
>>>                   f_optimizer='optimizer.pkl',
>>>                   f_history='history.json')
notify(method_name, **cb_kwargs)[source]

Call the callback method specified in method_name with parameters specified in cb_kwargs.

Method names can be one of:

  • on_train_begin
  • on_train_end
  • on_epoch_begin
  • on_epoch_end
  • on_batch_begin
  • on_batch_end

partial_fit(X, y=None, classes=None, **fit_params)[source]

Fit the module.

If the module is initialized, it is not re-initialized, which means that this method should be used if you want to continue training a model (warm start).
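A short sketch contrasting fit and partial_fit:

>>> net.fit(X, y)          # re-initializes the module (unless warm_start=True), then trains
>>> net.partial_fit(X, y)  # continues training the already initialized module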

Parameters:
X : input data, compatible with skorch.dataset.Dataset

By default, you should be able to pass:

  • numpy arrays
  • torch tensors
  • pandas DataFrame or Series
  • a dictionary of the former three
  • a list/tuple of the former three
  • a Dataset

If this doesn’t work with your data, you have to pass a Dataset that can deal with the data.

y : target data, compatible with skorch.dataset.Dataset

The same data types as for X are supported. If your X is a Dataset that contains the target, y may be set to None.

classes : array, shape (n_classes,)

Solely for sklearn compatibility, currently unused.

**fit_params : dict

Additional parameters passed to the forward method of the module and to the self.train_split call.

predict(X)[source]

Where applicable, return class labels for samples in X.

If the module’s forward method returns multiple outputs as a tuple, it is assumed that the first output contains the relevant information and the other values are ignored. If all values are relevant, consider using forward() instead.

Parameters:
X : input data, compatible with skorch.dataset.Dataset

By default, you should be able to pass:

  • numpy arrays
  • torch tensors
  • pandas DataFrame or Series
  • a dictionary of the former three
  • a list/tuple of the former three
  • a Dataset

If this doesn’t work with your data, you have to pass a Dataset that can deal with the data.

Returns:
y_pred : numpy ndarray
predict_proba(X)[source]

Return the output of the module’s forward method as a numpy array.

If the module’s forward method returns multiple outputs as a tuple, it is assumed that the first output contains the relevant information and the other values are ignored. If all values are relevant, consider using forward() instead.

Parameters:
X : input data, compatible with skorch.dataset.Dataset

By default, you should be able to pass:

  • numpy arrays
  • torch tensors
  • pandas DataFrame or Series
  • a dictionary of the former three
  • a list/tuple of the former three
  • a Dataset

If this doesn’t work with your data, you have to pass a Dataset that can deal with the data.

Returns:
y_proba : numpy ndarray
save_history(f)[source]

Saves the history of NeuralNet as a json file. In order to use this feature, the history must only contain JSON encodable Python data structures. Numpy and PyTorch types should not be in the history.

Parameters:
f : file-like object or str

Examples

>>> before = NeuralNetClassifier(mymodule)
>>> before.fit(X, y, epochs=2) # Train for 2 epochs
>>> before.save_params('path/to/params')
>>> before.save_history('path/to/history.json')
>>> after = NeuralNetClassifier(mymodule).initialize()
>>> after.load_params('path/to/params')
>>> after.load_history('path/to/history.json')
>>> after.fit(X, y, epochs=2) # Train for another 2 epochs
save_params(f=None, f_params=None, f_optimizer=None, f_history=None)[source]

Saves the module’s parameters, history, and optimizer, not the whole object.

To save the whole object, use pickle.

f_params and f_optimizer use PyTorch’s save().

Parameters:
f_params : file-like object, str, None (default=None)

Path of module parameters. Pass None to not save

f_optimizer : file-like object, str, None (default=None)

Path of optimizer. Pass None to not save

f_history : file-like object, str, None (default=None)

Path to history. Pass None to not save

f : deprecated

Examples

>>> before = NeuralNetClassifier(mymodule)
>>> before.save_params(f_params='model.pkl',
>>>                    f_optimizer='optimizer.pkl',
>>>                    f_history='history.json')
>>> after = NeuralNetClassifier(mymodule).initialize()
>>> after.load_params(f_params='model.pkl',
>>>                   f_optimizer='optimizer.pkl',
>>>                   f_history='history.json')
set_params(**kwargs)[source]

Set the parameters of this class.

Valid parameter keys can be listed with get_params().

Returns:
self
train_step(Xi, yi, **fit_params)[source]

Prepares a loss function callable and passes it to the optimizer, hence performing one optimization step.

Loss function callable as required by some optimizers (and accepted by all of them): https://pytorch.org/docs/master/optim.html#optimizer-step-closure

The module is set to be in train mode (e.g. dropout is applied).

Parameters:
Xi : input data

A batch of the input data.

yi : target data

A batch of the target data.

**fit_params : dict

Additional parameters passed to the forward method of the module and to the train_split call.

train_step_single(Xi, yi, **fit_params)[source]

Compute y_pred, loss value, and update net’s gradients.

The module is set to be in train mode (e.g. dropout is applied).

Parameters:
Xi : input data

A batch of the input data.

yi : target data

A batch of the target data.

**fit_params : dict

Additional parameters passed to the forward method of the module and to the self.train_split call.

validation_step(Xi, yi, **fit_params)[source]

Perform a forward step using batched data and return the resulting loss.

The module is set to be in evaluation mode (e.g. dropout is not applied).

Parameters:
Xi : input data

A batch of the input data.

yi : target data

A batch of the target data.

**fit_params : dict

Additional parameters passed to the forward method of the module and to the self.train_split call.

skorch.regressor

NeuralNet subclasses for regression tasks.

class skorch.regressor.NeuralNetRegressor(module, *args, criterion=<class 'torch.nn.modules.loss.MSELoss'>, **kwargs)[source]

NeuralNet for regression tasks

Use this specifically if you have a standard regression task, with input data X and target y. y must be 2d.
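For illustration, a minimal, hypothetical sketch of reshaping a 1-dimensional target into the required 2d shape (MyRegressorModule is a placeholder for your own module):

>>> import numpy as np
>>> X = X.astype(np.float32)
>>> y = y.reshape(-1, 1).astype(np.float32)
>>> net = NeuralNetRegressor(MyRegressorModule)
>>> net.fit(X, y)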

In addition to the parameters listed below, there are parameters with specific prefixes that are handled separately. To illustrate this, here is an example:

>>> net = NeuralNet(
...    ...,
...    optimizer=torch.optim.SGD,
...    optimizer__momentum=0.95,
...)

This way, when optimizer is initialized, NeuralNet will take care of setting the momentum parameter to 0.95.

(Note that the double underscore notation in optimizer__momentum means that the parameter momentum should be set on the object optimizer. This is the same semantic as used by sklearn.)

Furthermore, this allows you to change those parameters later:

net.set_params(optimizer__momentum=0.99)

This can be useful when you want to change certain parameters using a callback, when using the net in an sklearn grid search, etc.

By default an EpochTimer, BatchScoring (for both training and validation datasets), and PrintLog callbacks are installed for the user’s convenience.

Parameters:
module : torch module (class or instance)

A PyTorch Module. In general, the uninstantiated class should be passed, although instantiated modules will also work.

criterion : torch criterion (class, default=torch.nn.MSELoss)

Mean squared error loss.

optimizer : torch optim (class, default=torch.optim.SGD)

The uninitialized optimizer (update rule) used to optimize the module

lr : float (default=0.01)

Learning rate passed to the optimizer. You may use lr instead of using optimizer__lr, which would result in the same outcome.

max_epochs : int (default=10)

The number of epochs to train for each fit call. Note that you may keyboard-interrupt training at any time.

batch_size : int (default=128)

Mini-batch size. Use this instead of setting iterator_train__batch_size and iterator_valid__batch_size, which would result in the same outcome. If batch_size is -1, a single batch with all the data will be used during training and validation.

iterator_train : torch DataLoader

The default PyTorch DataLoader used for training data.

iterator_valid : torch DataLoader

The default PyTorch DataLoader used for validation and test data, i.e. during inference.

dataset : torch Dataset (default=skorch.dataset.Dataset)

The dataset is necessary for the incoming data to work with PyTorch's DataLoader. It has to implement the __len__ and __getitem__ methods. The provided dataset should be capable of dealing with a lot of data types out of the box, so only change this if your data is not supported. You should generally pass the uninitialized Dataset class and define additional arguments to X and y by prefixing them with dataset__. It is also possible to pass an initialized Dataset, in which case no additional arguments may be passed.

train_split : None or callable (default=skorch.dataset.CVSplit(5))

If None, there is no train/validation split. Else, train_split should be a function or callable that is called with X and y data and should return the tuple dataset_train, dataset_valid. The validation data may be None.

callbacks : None or list of Callback instances (default=None)

More callbacks, in addition to those returned by get_default_callbacks. Each callback should inherit from Callback. If not None, a list of callbacks is expected where the callback names are inferred from the class name. Name conflicts are resolved by appending a count suffix starting with 1, e.g. EpochScoring_1. Alternatively, a tuple (name, callback) can be passed, where name should be unique. Callbacks may or may not be instantiated. The callback name can be used to set parameters on specific callbacks (e.g., for the callback with name 'print_log', use net.set_params(callbacks__print_log__keys_ignored=['epoch', 'train_loss'])).

warm_start : bool (default=False)

Whether each fit call should lead to a re-initialization of the module (cold start) or whether the module should be trained further (warm start).

verbose : int (default=1)

Control the verbosity level.

device : str, torch.device (default=’cpu’)

The compute device to be used. If set to ‘cuda’, data in torch tensors will be pushed to cuda tensors before being sent to the module.

Attributes:
prefixes_ : list of str

Contains the prefixes to special parameters. E.g., since there is the 'module' prefix, it is possible to set parameters like so: NeuralNet(..., optimizer__momentum=0.95).

cuda_dependent_attributes_ : list of str

Contains a list of all attributes whose values depend on a CUDA device. If a NeuralNet trained with a CUDA-enabled device is unpickled on a machine without CUDA or with CUDA disabled, the listed attributes are mapped to CPU. Expand this list if you want to add other cuda-dependent attributes.

initialized_ : bool

Whether the NeuralNet was initialized.

module_ : torch module (instance)

The instantiated module.

criterion_ : torch criterion (instance)

The instantiated criterion.

callbacks_ : list of tuples

The complete (i.e. default and other), initialized callbacks, in a tuple with unique names.

Methods

check_data(X, y)
evaluation_step(Xi[, training]) Perform a forward step to produce the output used for prediction and scoring.
fit(X, y, **fit_params) See NeuralNet.fit.
fit_loop(X[, y, epochs]) The proper fit loop.
forward(X[, training, device]) Gather and concatenate the output from forward call with input data.
forward_iter(X[, training, device]) Yield outputs of module forward calls on each batch of data.
get_dataset(X[, y]) Get a dataset that contains the input data and is passed to the iterator.
get_iterator(dataset[, training]) Get an iterator that allows to loop over the batches of the given data.
get_loss(y_pred, y_true[, X, training]) Return the loss for this batch.
get_split_datasets(X[, y]) Get internal train and validation datasets.
get_train_step_accumulator() Return the train step accumulator.
infer(x, **fit_params) Perform a single inference step on a batch of data.
initialize() Initializes all components of the NeuralNet and returns self.
initialize_callbacks() Initializes all callbacks and save the result in the callbacks_ attribute.
initialize_criterion() Initializes the criterion.
initialize_history() Initializes the history.
initialize_module() Initializes the module.
initialize_optimizer() Initialize the model optimizer.
load_history(f) Load the history of a NeuralNet from a json file.
load_params([f, f_params, f_optimizer, …]) Loads the module’s parameters, history, and optimizer, not the whole object.
notify(method_name, **cb_kwargs) Call the callback method specified in method_name with parameters specified in cb_kwargs.
on_batch_begin(net[, Xi, yi, training])
on_epoch_begin(net[, dataset_train, …])
on_epoch_end(net[, dataset_train, dataset_valid])
on_train_begin(net[, X, y])
on_train_end(net[, X, y])
partial_fit(X[, y, classes]) Fit the module.
predict(X) Where applicable, return class labels for samples in X.
predict_proba(X) Return the output of the module’s forward method as a numpy array.
save_history(f) Saves the history of NeuralNet as a json file.
save_params([f, f_params, f_optimizer, …]) Saves the module’s parameters, history, and optimizer, not the whole object.
set_params(**kwargs) Set the parameters of this class.
train_step(Xi, yi, **fit_params) Prepares a loss function callable and passes it to the optimizer, hence performing one optimization step.
train_step_single(Xi, yi, **fit_params) Compute y_pred, loss value, and update net’s gradients.
validation_step(Xi, yi, **fit_params) Perform a forward step using batched data and return the resulting loss.
get_default_callbacks  
get_params  
on_batch_end  
on_grad_computed  
fit(X, y, **fit_params)[source]

See NeuralNet.fit.

In contrast to NeuralNet.fit, y is non-optional to avoid mistakenly forgetting about y. However, y can be set to None in case it is derived dynamically from X.

skorch.toy

Contains toy functions and classes for quick prototyping and testing.

class skorch.toy.MLPModule(input_units=20, output_units=2, hidden_units=10, num_hidden=1, nonlin=ReLU(), output_nonlin=None, dropout=0, squeeze_output=False)[source]

A simple multi-layer perceptron module.

This can be adapted for usage in different contexts, e.g. binary and multi-class classification, regression, etc.

Parameters:
input_units : int (default=20)

Number of input units.

output_units : int (default=2)

Number of output units.

hidden_units : int (default=10)

Number of units in hidden layers.

num_hidden : int (default=1)

Number of hidden layers.

nonlin : torch.nn.Module

Non-linearity to apply after hidden layers.

output_nonlin : torch.nn.Module

Non-linearity to apply after last layer.

dropout : float (default=0)

Dropout rate. Dropout is applied between layers.

squeeze_output : bool (default=False)

Whether to squeeze output. Squeezing can be helpful if you wish your output to be 1-dimensional (e.g. for NeuralNetBinaryClassifier).
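A brief, hypothetical usage sketch; the layer sizes are illustrative only:

>>> from torch import nn
>>> from skorch.toy import MLPModule
>>> module = MLPModule(
...     input_units=20,
...     output_units=2,
...     hidden_units=50,
...     num_hidden=2,
...     dropout=0.5,
...     output_nonlin=nn.Softmax(dim=-1),
... )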

Methods

__call__(*input, **kwargs) Call self as a function.
add_module(name, module) Adds a child module to the current module.
apply(fn) Applies fn recursively to every submodule (as returned by .children()) as well as self.
children() Returns an iterator over immediate children modules.
cpu() Moves all model parameters and buffers to the CPU.
cuda([device]) Moves all model parameters and buffers to the GPU.
double() Casts all floating point parameters and buffers to double datatype.
eval() Sets the module in evaluation mode.
extra_repr() Set the extra representation of the module
float() Casts all floating point parameters and buffers to float datatype.
forward(X) Defines the computation performed at every call.
half() Casts all floating point parameters and buffers to half datatype.
load_state_dict(state_dict[, strict]) Copies parameters and buffers from state_dict into this module and its descendants.
modules() Returns an iterator over all modules in the network.
named_children() Returns an iterator over immediate children modules, yielding both the name of the module as well as the module itself.
named_modules([memo, prefix]) Returns an iterator over all modules in the network, yielding both the name of the module as well as the module itself.
named_parameters([memo, prefix]) Returns an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself
parameters() Returns an iterator over module parameters.
register_backward_hook(hook) Registers a backward hook on the module.
register_buffer(name, tensor) Adds a persistent buffer to the module.
register_forward_hook(hook) Registers a forward hook on the module.
register_forward_pre_hook(hook) Registers a forward pre-hook on the module.
register_parameter(name, param) Adds a parameter to the module.
reset_params() (Re)set all parameters.
state_dict([destination, prefix, keep_vars]) Returns a dictionary containing a whole state of the module.
to(*args, **kwargs) Moves and/or casts the parameters and buffers.
train([mode]) Sets the module in training mode.
type(dst_type) Casts all parameters and buffers to dst_type.
zero_grad() Sets gradients of all model parameters to zero.
share_memory  
forward(X)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

reset_params()[source]

(Re)set all parameters.

skorch.toy.make_binary_classifier(squeeze_output=True, **kwargs)[source]

Return a multi-layer perceptron to be used with NeuralNetBinaryClassifier.

skorch.toy.make_classifier(output_nonlin=Softmax(), **kwargs)[source]

Return a multi-layer perceptron to be used with NeuralNetClassifier.

skorch.toy.make_regressor(output_units=1, **kwargs)[source]

Return a multi-layer perceptron to be used with NeuralNetRegressor.

skorch.utils

skorch utilities.

Should not have any dependency on other skorch packages.

class skorch.utils.Ansi[source]

An enumeration.

class skorch.utils.FirstStepAccumulator[source]

Store and retrieve the train step data.

This class simply stores the first step value and returns it.

For most uses, skorch.utils.FirstStepAccumulator is what you want, since the optimizer calls the train step exactly once. However, some optimizers, such as LBFGS, make more than one call. If, in that case, you don’t want the first value to be returned (but instead, say, the last value), implement your own accumulator and make sure it is returned by the NeuralNet.get_train_step_accumulator method.
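A hypothetical sketch of an accumulator that returns the last step instead of the first, wired up via get_train_step_accumulator:

>>> class LastStepAccumulator:
...     def __init__(self):
...         self.step = None
...     def store_step(self, step):
...         # always overwrite, so the last stored step is returned
...         self.step = step
...     def get_step(self):
...         return self.step
>>> class MyNet(NeuralNet):
...     def get_train_step_accumulator(self):
...         return LastStepAccumulator()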

Methods

get_step() Return the stored step.
store_step(step) Store the first step.
get_step()[source]

Return the stored step.

store_step(step)[source]

Store the first step.

skorch.utils.check_indexing(data)[source]

Check how incoming data should be indexed and return an appropriate indexing function with signature f(data, index).

This is useful for determining upfront how data should be indexed instead of doing it repeatedly for each batch, thus saving some time.
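A tiny, hypothetical usage sketch:

>>> import numpy as np
>>> from skorch.utils import check_indexing
>>> X = np.arange(12).reshape(4, 3)
>>> indexing_fn = check_indexing(X)  # function with signature f(data, index)
>>> first_row = indexing_fn(X, 0)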

skorch.utils.data_from_dataset(dataset, X_indexing=None, y_indexing=None)[source]

Try to access X and y attribute from dataset.

Also works when dataset is a subset.

Parameters:
dataset : skorch.dataset.Dataset or torch.utils.data.Subset

The incoming dataset should be a skorch.dataset.Dataset or a torch.utils.data.Subset of a skorch.dataset.Dataset.

X_indexing : function/callable or None (default=None)

If not None, use this function for indexing into the X data. If None, try to automatically determine how to index data.

y_indexing : function/callable or None (default=None)

If not None, use this function for indexing into the y data. If None, try to automatically determine how to index data.
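
A hedged sketch, assuming a skorch.dataset.Dataset built from numpy arrays:

import numpy as np
from skorch.dataset import Dataset
from skorch.utils import data_from_dataset

X = np.zeros((10, 3), dtype=np.float32)
y = np.ones(10, dtype=np.int64)
ds = Dataset(X, y)

X_out, y_out = data_from_dataset(ds)  # recover the underlying X and y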

skorch.utils.duplicate_items(*collections)[source]

Search for duplicate items in all collections.

Examples

>>> duplicate_items([1, 2], [3])
set()
>>> duplicate_items({1: 'a', 2: 'a'})
set()
>>> duplicate_items(['a', 'b', 'a'])
{'a'}
>>> duplicate_items([1, 2], {3: 'hi', 4: 'ha'}, (2, 3))
{2, 3}
skorch.utils.freeze_parameter(param)[source]

Convenience function to freeze a passed torch parameter. Used by skorch.callbacks.Freezer.

skorch.utils.get_dim(y)[source]

Return the number of dimensions of a torch tensor or numpy array-like object.

skorch.utils.get_map_location(target_device, fallback_device='cpu')[source]

Determine the location to map loaded data (e.g., weights) for a given target device (e.g. ‘cuda’).

skorch.utils.is_skorch_dataset(ds)[source]

Checks whether the supplied dataset is an instance of skorch.dataset.Dataset, even when it is nested inside a torch.utils.data.Subset.

skorch.utils.multi_indexing(data, i, indexing=None)[source]

Perform indexing on multiple data structures.

Currently supported data types:

  • numpy arrays
  • torch tensors
  • pandas NDFrame
  • a dictionary of the former three
  • a list/tuple of the former three

i can be an integer or a slice.

Parameters:
data

Data of a type mentioned above.

i : int or slice

Slicing index.

indexing : function/callable or None (default=None)

If not None, use this function for indexing into the data. If None, try to automatically determine how to index data.

Examples

>>> multi_indexing(np.asarray([1, 2, 3]), 0)
1
>>> multi_indexing(np.asarray([1, 2, 3]), np.s_[:2])
array([1, 2])
>>> multi_indexing(torch.arange(0, 4), np.s_[1:3])
tensor([ 1.,  2.])
>>> multi_indexing([[1, 2, 3], [4, 5, 6]], np.s_[:2])
[[1, 2], [4, 5]]
>>> multi_indexing({'a': [1, 2, 3], 'b': [4, 5, 6]}, np.s_[-2:])
{'a': [2, 3], 'b': [5, 6]}
>>> multi_indexing(pd.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6]}), [1, 2])
   a  b
1  2  5
2  3  6
skorch.utils.noop(*args, **kwargs)[source]

No-op function that does nothing and returns None.

This is useful for defining scoring callbacks that do not need a target extractor.
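
A hedged sketch: pass noop as the target extractor of a scoring callback whose score function does not use the target at all. The target_extractor argument name and the custom mean_prediction scorer are assumptions here.

from skorch.callbacks import EpochScoring
from skorch.utils import noop

def mean_prediction(net, X, y=None):
    # A "score" that ignores the target entirely.
    return net.predict_proba(X).mean()

# Assumes EpochScoring accepts a target_extractor argument.
cb = EpochScoring(mean_prediction, target_extractor=noop, name='mean_pred')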

skorch.utils.open_file_like(f, mode)[source]

Wrapper for opening a file.

skorch.utils.params_for(prefix, kwargs)[source]

Extract parameters that belong to a given sklearn module prefix from kwargs. This is useful to obtain parameters that belong to a submodule.

Examples

>>> kwargs = {'encoder__a': 3, 'encoder__b': 4, 'decoder__a': 5}
>>> params_for('encoder', kwargs)
{'a': 3, 'b': 4}
skorch.utils.to_numpy(X)[source]

Generic function to convert a PyTorch tensor to a numpy array.

Returns X when it already is a numpy array.

skorch.utils.to_tensor(X, device)[source]

Turn input data into a torch tensor.

Parameters:
X : input data
Handles the cases:
  • PackedSequence
  • numpy array
  • torch Tensor
  • list or tuple of one of the former
  • dict with values of one of the former
device : str, torch.device

The compute device to be used. If set to ‘cuda’, data in torch tensors will be pushed to cuda tensors before being sent to the module.

Returns:
output : torch Tensor
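
A hedged round-trip sketch using to_tensor together with to_numpy (documented above):

import numpy as np
from skorch.utils import to_numpy, to_tensor

X = np.asarray([[0.0, 1.0], [2.0, 3.0]], dtype=np.float32)
Xt = to_tensor(X, device='cpu')  # numpy array -> torch tensor on the CPU
X_back = to_numpy(Xt)            # torch tensor -> numpy array
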
skorch.utils.unfreeze_parameter(param)[source]

Convenience function to unfreeze a passed torch parameter. Used by skorch.callbacks.Unfreezer.
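
A hedged sketch of freezing and later unfreezing all parameters of a module by hand (freeze_parameter is documented above); normally the Freezer and Unfreezer callbacks do this for you:

from torch import nn
from skorch.utils import freeze_parameter, unfreeze_parameter

module = nn.Linear(10, 2)

for _, param in module.named_parameters():
    freeze_parameter(param)    # exclude these parameters from training

# ... later, make them trainable again ...
for _, param in module.named_parameters():
    unfreeze_parameter(param)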

Indices and tables