skorch documentation¶
A scikit-learn compatible neural network library that wraps PyTorch.
Introduction¶
The goal of skorch is to make it possible to use PyTorch with sklearn. This is achieved by providing a wrapper around PyTorch that has an sklearn interface.
skorch does not re-invent the wheel, instead getting as much out of your way as possible. If you are familiar with sklearn and PyTorch, you don’t have to learn any new concepts, and the syntax should be well known. (If you’re not familiar with those libraries, it is worth getting familiarized.)
Additionally, skorch abstracts away the training loop, making a
lot of boilerplate code obsolete. A simple net.fit(X, y)
is
enough. Out of the box, skorch works with many types of data, be
it PyTorch Tensors, NumPy arrays, Python dicts, and so
on. However, if you have other data, extending skorch is easy to
allow for that.
Overall, skorch aims at being as flexible as PyTorch while having as clean an interface as sklearn.
If you use skorch, please use this BibTeX entry:
@manual{skorch,
author = {Marian Tietz and Thomas J. Fan and Daniel Nouri and Benjamin Bossan and {skorch Developers}},
title = {skorch: A scikit-learn compatible neural network library that wraps PyTorch},
month = jul,
year = 2017,
url = {https://skorch.readthedocs.io/en/stable/}
}
User’s Guide¶
Installation¶
pip installation¶
To install with pip, run:
python -m pip install -U skorch
We recommend using a virtual environment for this.
From source¶
If you would like to use the most recent additions to skorch or help with development, you should install skorch from source.
Using conda¶
You need a working conda installation. Get the correct miniconda for your system from here.
If you just want to use skorch, use:
git clone https://github.com/skorch-dev/skorch.git
cd skorch
conda env create
source activate skorch
python -m pip install .
If you want to help developing, run:
git clone https://github.com/skorch-dev/skorch.git
cd skorch
conda env create
source activate skorch
python -m pip install -e .
py.test # unit tests
pylint skorch # static code checks
Using pip¶
If you just want to use skorch, use:
git clone https://github.com/skorch-dev/skorch.git
cd skorch
# create and activate a virtual environment
python -m pip install -r requirements.txt
# install pytorch version for your system (see below)
python -m pip install .
If you want to help developing, run:
git clone https://github.com/skorch-dev/skorch.git
cd skorch
# create and activate a virtual environment
python -m pip install -r requirements.txt
# install pytorch version for your system (see below)
python -m pip install -r requirements-dev.txt
python -m pip install -e .
py.test # unit tests
pylint skorch # static code checks
PyTorch¶
PyTorch is not covered by the dependencies, since the PyTorch version you need is dependent on your OS and device. For installation instructions for PyTorch, visit the PyTorch website. skorch officially supports the last four minor PyTorch versions, which currently are:
- 1.9.1
- 1.10.2
- 1.11.0
- 1.12.0
However, that doesn’t mean that older versions don’t work, just that they aren’t tested. Since skorch mostly relies on the stable part of the PyTorch API, older PyTorch versions should work fine.
In general, running this to install PyTorch should work (assuming CUDA 11.1):
# using conda:
conda install pytorch cudatoolkit==11.1 -c pytorch
# using pip
python -m pip install torch
Quickstart¶
Training a model¶
Below, we define our own PyTorch Module
and train
it on a toy classification dataset using skorch
NeuralNetClassifier
:
import numpy as np
from sklearn.datasets import make_classification
from torch import nn
from skorch import NeuralNetClassifier
X, y = make_classification(1000, 20, n_informative=10, random_state=0)
X = X.astype(np.float32)
y = y.astype(np.int64)
class MyModule(nn.Module):
def __init__(self, num_units=10, nonlin=nn.ReLU()):
super(MyModule, self).__init__()
self.dense0 = nn.Linear(20, num_units)
self.nonlin = nonlin
self.dropout = nn.Dropout(0.5)
self.dense1 = nn.Linear(num_units, 10)
self.output = nn.Linear(10, 2)
def forward(self, X, **kwargs):
X = self.nonlin(self.dense0(X))
X = self.dropout(X)
X = self.nonlin(self.dense1(X))
X = self.output(X)
return X
net = NeuralNetClassifier(
MyModule,
max_epochs=10,
criterion=nn.CrossEntropyLoss(),
lr=0.1,
# Shuffle training data on each epoch
iterator_train__shuffle=True,
)
net.fit(X, y)
y_proba = net.predict_proba(X)
Note
In this example, instead of applying the standard softmax non-linearity to the output and using NLLLoss as the criterion, the module has no output non-linearity and CrossEntropyLoss is used as the criterion. The reason is that the use of softmax can lead to numerical instability in some cases.
In an sklearn Pipeline¶
Since NeuralNetClassifier
provides an sklearn-compatible
interface, it is possible to put it into an sklearn
Pipeline
:
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
pipe = Pipeline([
('scale', StandardScaler()),
('net', net),
])
pipe.fit(X, y)
y_proba = pipe.predict_proba(X)
Grid search¶
Another advantage of skorch is that you can perform an sklearn
GridSearchCV
or
RandomizedSearchCV
:
from sklearn.model_selection import GridSearchCV
params = {
'lr': [0.01, 0.02],
'max_epochs': [10, 20],
'module__num_units': [10, 20],
}
gs = GridSearchCV(net, params, refit=False, cv=3, scoring='accuracy')
gs.fit(X, y)
print(gs.best_score_, gs.best_params_)
Tutorials¶
The following are examples and notebooks on how to use skorch.
- Basic Usage - Explores the basics of the skorch API. Run in Google Colab 💻
- MNIST with scikit-learn and skorch - Define and train a simple neural network with PyTorch and use it with skorch. Run in Google Colab 💻
- Benchmarks skorch vs pure PyTorch - Compares the performance of skorch and using pure PyTorch on MNIST.
- Transfer Learning with skorch - Train a neural network using transfer learning with skorch. Run in Google Colab 💻
- Image Segmentation with UNets - Use transfer learning to train a UNet model for image segmentation.
- Using skorch with Dask - Using Dask to parallelize grid search across GPUs.
- World level language modeling RNN - Uses skorch to train a language model.
- Seq2Seq Translation using skorch - Translation with a sequence to sequence network.
- Advanced Usage - Dives deep into the inner works of skorch. Run in Google Colab 💻
- Gaussian Processes - Train Gaussian Processes with the help of GPyTorch. Run in Google Colab 💻
- Huggingface Finetuning - Fine-tune a BERT model for text classification with the huggingface transformers library and skorch. Run in Google Colab 💻
NeuralNet¶
Using NeuralNet¶
NeuralNet
and the derived classes are the main touch point
for the user. They wrap the PyTorch Module
while
providing an interface that should be familiar for sklearn users.
Define your Module
the same way as you always do.
Then pass it to NeuralNet
, in conjunction with a PyTorch
criterion. Finally, you can call fit()
and predict()
, as with an sklearn
estimator. The finished code could look something like this:
class MyModule(torch.nn.Module):
...
net = NeuralNet(
module=MyModule,
criterion=torch.nn.NLLLoss,
)
net.fit(X, y)
y_pred = net.predict(X_valid)
Let’s see what skorch did for us here:
- wraps the PyTorch Module in an sklearn interface
- converts numpy.ndarrays to PyTorch Tensors
- abstracts away the fit loop
- takes care of batching the data
You therefore have a lot less boilerplate code, letting you focus on what matters. At the same time, skorch is very flexible and can be extended with ease, getting out of your way as much as possible.
Initialization¶
In general, when you instantiate the NeuralNet
instance,
only the given arguments are stored. They are stored exactly as you
pass them to NeuralNet
. For instance, the module
will
remain uninstantiated. This is to make sure that the arguments you
pass are not touched afterwards, which, among other things, makes it possible to clone the NeuralNet instance.
Only when the fit()
or
initialize()
method are called, are the
different attributes of the net, such as the module
, initialized.
An initialized attribute’s name always ends with an underscore; e.g., the initialized module is called module_. (This is the same nomenclature as sklearn uses.) Therefore, you always know which attributes you set and which ones were created by NeuralNet.
The only exception is the history attribute, which is not set by the user.
Most important arguments and methods¶
A complete explanation of all arguments and methods of
NeuralNet
are found in the skorch API documentation. Here we
focus on the main ones.
module¶
This is where you pass your PyTorch Module
.
Ideally, it should not be instantiated. Instead, the init arguments
for your module should be passed to NeuralNet
with the
module__
prefix. E.g., if your module takes the arguments
num_units
and dropout
, the code would look like this:
class MyModule(torch.nn.Module):
def __init__(self, num_units, dropout):
...
net = NeuralNet(
module=MyModule,
module__num_units=100,
module__dropout=0.5,
criterion=torch.nn.NLLLoss,
)
It is, however, also possible to pass an instantiated module, e.g. a
PyTorch Sequential
instance.
Note that skorch does not automatically apply any nonlinearities to
the outputs (except internally when determining the PyTorch
NLLLoss
, see below). That means that if you have a
classification task, you should make sure that the final output
nonlinearity is a softmax. Otherwise, when you call
predict_proba()
, you won’t get actual
probabilities.
criterion¶
This should be a PyTorch (-compatible) criterion.
When you use the NeuralNetClassifier
, the criterion is set
to PyTorch NLLLoss
by default. Furthermore, if you
don’t change the criterion to something else,
NeuralNetClassifier
assumes that the module returns
probabilities and will automatically apply a logarithm on them (which
is what NLLLoss
expects).
For NeuralNetRegressor
, the default criterion is PyTorch
MSELoss
.
After initializing the NeuralNet
, the initialized criterion
will be stored in the criterion_
attribute.
optimizer¶
This should be a PyTorch optimizer, e.g. SGD
.
After initializing the NeuralNet
, the initialized optimizer
will be stored in the optimizer_
attribute. During initialization
you can define param groups, for example to set different learning
rates for certain parameters. The parameters are selected by name with
support for wildcards (globbing):
optimizer__param_groups=[
('embedding.*', {'lr': 0.0}),
('linear0.bias', {'lr': 1}),
]
Your use case may require an optimizer whose signature differs from a
default PyTorch optimizer’s signature. In that case, you can define a
custom function that reroutes the arguments as needed and pass it to
the optimizer
parameter:
# custom optimizer to encapsulate Adam
def make_lookahead(parameters, optimizer_cls, k, alpha, **kwargs):
optimizer = optimizer_cls(parameters, **kwargs)
return Lookahead(optimizer=optimizer, k=k, alpha=alpha)
net = NeuralNetClassifier(
...,
optimizer=make_lookahead,
optimizer__optimizer_cls=torch.optim.Adam,
optimizer__weight_decay=1e-2,
optimizer__k=5,
optimizer__alpha=0.5,
lr=1e-3)
lr¶
The learning rate. This argument exists for convenience, since it
could also be set by optimizer__lr
instead. However, it is used so
often that we provided this shortcut. If you set both lr
and
optimizer__lr
, the latter takes precedence.
max_epochs¶
The maximum number of epochs to train with each
fit()
call. When you call
fit()
, the net will train for this many
epochs, except if you interrupt training before the end (e.g. by using
an early stopping callback or interrupt manually with ctrl+c).
If you want to change the number of epochs to train, you can either
set a different value for max_epochs
, or you call
fit_loop()
instead of
fit()
and pass the desired number of
epochs explicitly:
net.fit_loop(X, y, epochs=20)
batch_size¶
This argument controls the batch size for iterator_train
and
iterator_valid
at the same time. batch_size=128
is thus a
convenient shortcut for explicitly typing
iterator_train__batch_size=128
and
iterator_valid__batch_size=128
. If you set all three arguments,
the latter two will have precedence.
train_split¶
This determines the NeuralNet
’s internal train/validation
split. By default, 20% of the incoming data is reserved for
validation. If you set this value to None
, all the data is used
for training.
For more details, please look at dataset.
callbacks¶
For more details on the callback classes, please look at callbacks.
By default, NeuralNet
and its subclasses start with a couple
of useful callbacks. Those are defined in the
get_default_callbacks()
method and
include, for instance, callbacks for measuring and printing model
performance.
In addition to the default callbacks, you may provide your own
callbacks. There are a couple of ways to pass callbacks to the
NeuralNet
instance. The easiest way is to pass a list of all
your callbacks to the callbacks
argument:
net = NeuralNet(
module=MyModule,
callbacks=[
MyCallback1(...),
MyCallback2(...),
],
)
Inside the NeuralNet
instance, each callback will receive a
separate name. Since we provide no name in the example above, the
class name will be taken, which would lead to a name collision in case of
two or more callbacks of the same class. This is why it is better to
initialize the callbacks with a list of tuples of name and callback
instance, like this:
net = NeuralNet(
module=MyModule,
callbacks=[
('cb1', MyCallback1(...)),
('cb2', MyCallback2(...)),
],
)
This approach of passing a list of (name, instance) tuples should be familiar to users of sklearn Pipelines and FeatureUnions.
An additional advantage of this way of passing callbacks is that it allows you to pass arguments to the callbacks by name (using the double-underscore notation):
net.set_params(callbacks__cb1__foo=123, callbacks__cb2__bar=456)
Use this, for instance, when trying out different callback parameters in a grid search.
Note: The user-defined callbacks are always called in the same order
as they appeared in the list. If there are dependencies between the
callbacks, the user has to make sure that the order respects them.
Also note that the user-defined callbacks will be called after the
default callbacks so that they can make use of the things provided by
the default callbacks. The only exception is the default callback
PrintLog
, which is always called last.
warm_start¶
This argument determines whether each
fit()
call leads to a re-initialization of
the NeuralNet
or not. By default, when calling
fit()
, the parameters of the net are
initialized, so your previous training progress is lost (consistent
with the sklearn fit()
calls). In contrast, with
warm_start=True
, each fit()
call will
continue from the most recent state.
device¶
As the name suggests, this determines which computation device should
be used. If set to cuda
, the incoming data will be transferred to
CUDA before being passed to the PyTorch Module
. The
device parameter adheres to the general syntax of the PyTorch device
parameter.
initialize()¶
As mentioned earlier, upon instantiating the NeuralNet
instance, the net’s components are not yet initialized. That means,
e.g., that the weights and biases of the layers are not yet set. This
only happens after the initialize()
call.
However, when you call fit()
and the net
is not yet initialized, initialize()
is
called automatically. You thus rarely need to call it manually.
The initialize()
method itself calls a
couple of other initialization methods that are specific to each
component. E.g., initialize_module()
is
responsible for initializing the PyTorch module. Therefore, if you
have special needs for initializing the module, it is enough to
override initialize_module()
, you don’t
need to override the whole initialize()
method.
fit(X, y)¶
This is one of the main methods you will use. It contains everything required to train the model, be it batching of the data, triggering the callbacks, or handling the internal validation set.
In general, we assume there to be an X
and a y
. If you have
more input data than just one array, it is possible for X
to be a
list or dictionary of data (see dataset). And if your
task does not have an actual y
, you may pass y=None
.
If you fit with a PyTorch Dataset
and don’t
explicitly pass y
, several components down the line might not work
anymore, since sklearn sometimes requires an explicit y
(e.g. for
scoring). In general, PyTorch Dataset
s
should work, though.
In addition to fit()
, there is also the
partial_fit()
method, known from some
sklearn estimators. partial_fit()
allows
you to continue training from your current status, even if you set
warm_start=False
. A further use case for
partial_fit()
is when your data does not
fit into memory and you thus need to have several training steps.
Tip :
skorch gracefully catches the KeyboardInterrupt
exception. Therefore, during a training run, you can send a
KeyboardInterrupt
signal without the Python process exiting
(typically, KeyboardInterrupt
can be triggered by ctrl+c or, in
a Jupyter notebook, by clicking Kernel -> Interrupt). This way, when
your model has reached a good score before max_epochs
have been
reached, you can dynamically stop training.
predict(X) and predict_proba(X)¶
These methods perform an inference step on the input data and return
numpy.ndarray
s. By default,
predict_proba()
will return whatever it is
that the module
’s forward()
method
returns, cast to a numpy.ndarray
. If
forward()
returns multiple outputs as a tuple,
only the first output is used, the rest is discarded.
If the forward() output cannot be cast to a numpy.ndarray, or if you need access to all outputs in the multiple-outputs case, consider using the forward() or forward_iter() methods to generate outputs
from the module
. Alternatively, you may directly call
net.module_(X)
.
In case of NeuralNetClassifier
, the
predict()
method tries to
return the class labels by applying the argmax over the last axis of
the result of
predict_proba()
.
Obviously, this only makes sense if
predict_proba()
returns
class probabilities. If this is not true, you should just use
predict_proba()
.
score(X, y)¶
This method returns the mean accuracy on the given data and labels for
classifiers and the coefficient of determination R^2 of the prediction for
regressors. NeuralNet
itself has no score method. If you need one,
you have to implement it yourself.
model persistence¶
In general there are different ways of saving and loading models, each with their own advantages and disadvantages. More details and usage examples can be found here: Saving and Loading.
If you would like to use pickle (the default way when using scikit-learn models), this is possible with skorch nets. This saves the whole net including hyperparameters etc. The advantage is that you can restore everything to exactly the state it was before. The disadvantage is it’s easier for code changes to break your old saves.
Additionally, it is possible to save and load specific attributes of
the net, such as the module
, optimizer
, or history
, by
calling save_params()
and
load_params()
. This is useful if you’re
only interested in saving a particular part of your model, and is more
robust to code changes.
Finally, it is also possible to use callbacks to save and load models,
e.g. Checkpoint
. Those should be used if you need to have
your model saved or loaded at specific times, e.g. at the start or end
of the training process.
Input data¶
Regular data¶
skorch supports numerous input types for data. Regular input types
that should just work are numpy arrays, torch tensors, scipy sparse
CSR matrices, and pandas DataFrames (see also
DataFrameTransformer
).
Typically, your task should involve an X
and a y
. If you’re
dealing with a task that doesn’t require a target (say, training an
autoencoder), you can just pass y=None
. Make sure your loss
function deals with this appropriately.
Datasets¶
Dataset
s are also supported, with the
requirement that they should return exactly two items (X
and
y
). For more information on that, take a look at the
Dataset documentation.
Many PyTorch libraries, like torchvision, implement their own
Dataset
s. These usually work seamlessly with skorch, as long as
their __getitem__
methods return two outputs. In case they don’t,
consider overriding the __getitem__ method and re-arranging the outputs so that __getitem__ returns exactly two elements. If the
original implementation returns more than two elements, take a look at
the next section to get an idea how to deal with that.
Multiple input arguments¶
In some cases, the input actually consists of multiple inputs. E.g., in a text classification task, you might have an array that contains the integers representing the tokens for each sample, and another array containing the number of tokens of each sample. skorch has you covered here as well.
You could supply a list or tuple with all your inputs (net.fit([tokens,
num_tokens], y)
), but we actually recommend another approach. The best
way is to pass the different arguments as a dictionary. Then the keys
of that dictionary have to correspond to the argument names of your
module’s forward
method. Below is an example:
X_dict = {'tokens': tokens, 'num_tokens': num_tokens}
class MyModule(nn.Module):
def forward(self, tokens, num_tokens): # <- same names as in your dict
...
net = NeuralNet(MyModule, ...)
net.fit(X_dict, y)
As you can see, the forward
method takes arguments with exactly
the same name as the keys in the dictionary. This is how the different
inputs are matched. To make this work with
GridSearchCV
, please use
SliceDict
.
Using a dict should cover most use cases that involve multiple inputs. However, it will fail if your inputs have different sizes. E.g., if your array of tokens has 1000 elements but your array of number of tokens has 2000 elements, this would fail. The main reason for this is batching: How can we know which elements of the two arrays belong in the same batch?
If your input consists of multiple inputs with different sizes, your best bet is to implement your own dataset class. That class should know how it deals with the different inputs, i.e. which elements belong to the same sample. Again, please refer to the Dataset section for more details.
Special arguments¶
In addition to the arguments explicitly listed for
NeuralNet
, there are some arguments with special prefixes,
as shown below:
class MyModule(torch.nn.Module):
def __init__(self, num_units, dropout):
...
net = NeuralNet(
module=MyModule,
module__num_units=100,
module__dropout=0.5,
criterion=torch.nn.NLLLoss,
criterion__weight=weight,
optimizer=torch.optim.SGD,
optimizer__momentum=0.9,
)
Those arguments are used to initialize your module
, criterion
,
etc. They are not fixed because we cannot know them in advance; in
fact, you can define any parameter for your module
or other
components.
Callbacks¶
Callbacks provide a flexible way to customize the behavior of your
NeuralNet
training without the need to write subclasses.
You will often find callbacks writing to or reading from the
history attribute. Therefore, if you would like to
log the net’s behavior or do something based on the past behavior,
consider using net.history
.
This page will not explain all existing callbacks. For that, please
look at skorch.callbacks
.
Callback base class¶
The base class for each callback is Callback
. If you would
like to write your own callbacks, you should inherit from this class.
A guide and practical example on how to write your own callbacks is
shown in this notebook.
In general, remember this:
- They should inherit from the base class.
- They should implement at least one of the on_ methods provided by the parent class (see below).
- As argument, the methods first get the NeuralNet instance, and, where appropriate, the local data (e.g. the data from the current batch). The method should also have **kwargs in the signature for potentially unused arguments.
Callback methods to override¶
The following methods can potentially be overridden when implementing your own callbacks.
initialize()¶
If you have attributes that should be reset when the model is re-initialized, those attributes should be set in this method.
on_train_begin(net, X, y)¶
Called once at the start of the training process (e.g. when calling fit).
on_train_end(net, X, y)¶
Called once at the end of the training process.
on_epoch_begin(net, dataset_train, dataset_valid)¶
Called once at the start of the epoch, i.e. possibly several times per fit call. Gets training and validation data as additional input.
on_epoch_end(net, dataset_train, dataset_valid)¶
Called once at the end of the epoch, i.e. possibly several times per fit call. Gets training and validation data as additional input.
on_batch_begin(net, batch, training)¶
Called once before each batch of data is processed, i.e. possibly several times per epoch. Gets batch data as additional input. Also includes a bool indicating if this is a training batch or not.
on_batch_end(net, batch, training, loss, y_pred)¶
Called once after each batch of data is processed, i.e. possibly several times per epoch. Gets batch data as additional input.
on_grad_computed(net, named_parameters, Xi, yi)¶
Called once per batch after gradients have been computed but before an update step was performed. Gets the module parameters as additional input as well as the batch data. Useful if you want to tinker with gradients.
Setting callback parameters¶
You can set specific callback parameters using the usual set_params interface on the network by using the callbacks__ prefix and the callback’s name. For example, to change the scoring order of the train loss, you can write this:
net = NeuralNet()
net.set_params(callbacks__train_loss__lower_is_better=False)
Changes will be applied on initialization and callbacks that are changed using set_params will be re-initialized.
The name you use to address the callback can be chosen during
initialization of the network and defaults to the class name.
If there is a conflict, the conflicting names will be made unique
by appending a count suffix starting at 1, e.g.
EpochScoring_1
, EpochScoring_2
, etc.
Deactivating callbacks¶
If you would like to (temporarily) deactivate a callback, you can do so by setting its parameter to None. E.g., if you have a callback called ‘my_callback’, you can deactivate it like this:
net = NeuralNet(
module=MyModule,
callbacks=[('my_callback', MyCallback())],
)
# now deactivate 'my_callback':
net.set_params(callbacks__my_callback=None)
This also works with default callbacks.
Deactivating callbacks can be especially useful when you do a
parameter search (say with sklearn
GridSearchCV
). If, for instance, you
use a callback for learning rate scheduling (e.g. via
LRScheduler
) and want to test its usefulness, you can
compare the performance once with and once without the callback.
To completely disable all callbacks, including default callbacks,
set callbacks="disable"
.
Scoring¶
skorch provides two callbacks that calculate scores by default,
EpochScoring
and BatchScoring
. They work basically
in the same way, except that EpochScoring
calculates scores
after each epoch and BatchScoring
after each batch. Use the
former if averaging of batch-wise scores is imprecise (say for AUC
score) and the latter if you are very tight for memory.
In general, these scoring callbacks are useful when the default scores
determined by the NeuralNet
are not enough. They allow you
to easily add new metrics to be logged during training. For an example
of how to add a new score to your model, look at this notebook.
The first argument to both callbacks is name
and should be a
string. This determines the column name of the score shown by the
PrintLog
after each epoch.
Next comes the scoring
parameter. For eager sklearn users, this
should be familiar, since it works exactly the same as in sklearn
GridSearchCV
,
RandomizedSearchCV
,
cross_val_score()
, etc. For those who
are unfamiliar, here is a short explanation:
- If you pass a string, sklearn makes a look-up for a score with that name. Examples would be 'f1' and 'roc_auc'.
- If you pass None, the model’s score method is used. By default, NeuralNet and its subclasses don’t provide a score method, but you can easily implement your own. If you do, it should take X and y (the target) as input and return a scalar as output.
- Finally, you can pass a function/callable. In that case, this function should have the signature func(net, X, y) and return a scalar.
More on sklearn’s model evaluation can be found in this notebook.
The lower_is_better
parameter determines whether lower scores
should be considered as better (e.g. log loss) or worse
(e.g. accuracy). This information is used to write a <name>_best
value to the net’s history
. E.g., if your score is f1 score and is
called 'f1'
, you should set lower_is_better=False
. The
history
will then contain an entry for 'f1'
, which is the
score itself, and an entry for 'f1_best'
, which says whether this is the best f1 score so far.
on_train
is used to indicate whether training or validation data
should be used to determine the score. By default, it is set to
validation.
Finally, you may have to provide your own target_extractor
. This
should be a function or callable that is applied to the target before
it is passed to the scoring function. The main reason why we need this
is that sometimes, the target is not of a form expected by sklearn and
we need to process it before passing it on.
On top of the two described scoring callbacks, skorch also provides
PassthroughScoring
. This callback does not actually
calculate any new scores. Instead it uses an existing score that is
calculated for each batch (the train loss, for example) and determines
the average of this score, which is then written to the epoch level of
the net’s history
. This is very useful if the score was already
calculated and logged on the batch level and you’re only interested in seeing the averaged score on the epoch level.
For this callback, you only need to provide the name
of the score
in the history
. Moreover, you may again specify if
lower_is_better
and if the score should be calculated on_train
or not.
Note
Both BatchScoring
and PassthroughScoring
honor the batch size when calculating the average. This can
make a difference when not all batch sizes are equal, which
is typically the case because the last batch of an epoch
contains fewer samples than the rest.
Checkpoint¶
The Checkpoint
callback creates a checkpoint of your model
after each epoch that meets certain criteria. By default, the condition is that the validation loss has improved; however, you may change this
by specifying the monitor
parameter. It can take three types of
arguments:
- None: The model is saved after each epoch;
- string: The model checks whether the last entry in the model history for that key is truthy. This is useful in conjunction with scores determined by a scoring callback. They write a <score>_best entry to the history, which can be used for checkpointing. By default, the Checkpoint callback looks at 'valid_loss_best';
- function or callable: In that case, the function should take the NeuralNet instance as sole input and return a bool as output.
To specify where and how your model is saved, change the arguments
starting with f_
:
- f_params: to save model parameters
- f_optimizer: to save optimizer state
- f_history: to save training history
- f_pickle: to pickle the entire model object.
Please refer to Saving and Loading for more information about restoring your network from a checkpoint.
Learning rate schedulers¶
The LRScheduler
callback allows the use of the various
learning rate schedulers defined in torch.optim.lr_scheduler
,
such as ReduceLROnPlateau
, which
allows dynamic learning rate reducing based on a given value to
monitor, or CyclicLR
, which cycles
the learning rate between two boundaries with a constant frequency.
Here’s a network that uses a callback to set a cyclic learning rate:
from skorch.callbacks import LRScheduler
from torch.optim.lr_scheduler import CyclicLR
net = NeuralNet(
module=MyModule,
callbacks=[
('lr_scheduler',
LRScheduler(policy=CyclicLR,
base_lr=0.001,
max_lr=0.01)),
],
)
As with other callbacks, you can use set_params to set parameters,
and thus search learning rate scheduler parameters using
GridSearchCV
or similar. An
example:
from sklearn.model_selection import GridSearchCV
search = GridSearchCV(
net,
param_grid={'callbacks__lr_scheduler__max_lr': [0.01, 0.1, 1.0]},
)
Dataset¶
This module contains classes and functions related to data handling.
ValidSplit¶
This class is responsible for performing the NeuralNet
’s
internal cross validation. For this, it sticks closely to the sklearn
standards. For more information on how sklearn handles cross
validation, look here.
The first argument that ValidSplit
takes is cv
. It works
analogously to the cv
argument from sklearn
GridSearchCV
,
cross_val_score()
, etc. For those not
familiar, here is a short explanation of what you may pass:
- None: Use the default 3-fold cross validation.
- integer: Specifies the number of folds in a (Stratified)KFold,
- float: Represents the proportion of the dataset to include in the validation split (e.g. 0.2 for 20%).
- An object to be used as a cross-validation generator.
- An iterable yielding train, validation splits.
Furthermore, ValidSplit
takes a stratified
argument that
determines whether a stratified split should be made (only makes sense
for discrete targets), and a random_state
argument, which is used
in case the cross validation split has a random component.
One difference to sklearn’s cross validation is that skorch makes only a single split. In sklearn, you would expect that in a 5-fold cross validation, the model is trained 5 times on the different combination of folds. This is often not desirable for neural networks, since training takes a lot of time. Therefore, skorch only ever makes one split.
If you would like to have all splits, you can still use skorch in
conjunction with the sklearn functions, as you would do with any
other sklearn-compatible estimator. Just remember to set
train_split=None
, so that the whole dataset is used for
training. Below is an example of making out-of-fold predictions
with skorch and sklearn:
net = NeuralNetClassifier(
module=MyModule,
train_split=None,
)
from sklearn.model_selection import cross_val_predict
y_pred = cross_val_predict(net, X, y, cv=5)
Dataset¶
In PyTorch, we have the concept of a
Dataset
and a
DataLoader
. The former is purely the
container of the data and only needs to implement __len__()
and
__getitem__(<int>)
. The latter does the heavy lifting, such as
sampling, shuffling, and distributed processing.
skorch uses the PyTorch DataLoader
s by default.
skorch supports PyTorch’s Dataset
when calling
fit()
or
partial_fit()
. Details on how to use PyTorch’s
Dataset
with skorch can be found in
How do I use a PyTorch Dataset with skorch?.
In order to support other data formats, we provide our own
Dataset
class that is compatible with:
- numpy.ndarrays
- PyTorch Tensors
- scipy sparse CSR matrices
- pandas DataFrames or Series
Note that currently, sparse matrices are cast to dense arrays during
batching, given that PyTorch support for sparse matrices is still very
incomplete. If you would like to prevent that, you need to override
the transform
method of Dataset
.
In addition to the types above, you can pass dictionaries or lists of
one of those data types, e.g. a dictionary of
numpy.ndarray
s. When you pass dictionaries, the keys of the
dictionaries are used as the argument name for the
forward()
method of the net’s
module
. Similarly, the column names of pandas DataFrame
s are
used as argument names. The example below should illustrate how to use
this feature:
import numpy as np
import torch
import torch.nn.functional as F
class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.dense_a = torch.nn.Linear(10, 100)
self.dense_b = torch.nn.Linear(20, 100)
self.output = torch.nn.Linear(200, 2)
def forward(self, key_a, key_b):
hid_a = F.relu(self.dense_a(key_a))
hid_b = F.relu(self.dense_b(key_b))
concat = torch.cat((hid_a, hid_b), dim=1)
        out = F.softmax(self.output(concat), dim=-1)
return out
net = NeuralNetClassifier(MyModule)
X = {
'key_a': np.random.random((1000, 10)).astype(np.float32),
'key_b': np.random.random((1000, 20)).astype(np.float32),
}
y = np.random.randint(0, 2, size=1000)
net.fit(X, y)
Note that the keys in the dictionary X
exactly match the argument
names in the forward()
method. This way, you
can easily work with several different types of input features.
The Dataset
from skorch makes the assumption that you always
have an X
and a y
, where X
represents the input data and
y
the target. However, you may leave y=None
, in which case
Dataset
returns a dummy variable.
Dataset
applies a final transform to the data
before passing it on to the PyTorch
DataLoader
. By default, it replaces y
by a dummy variable in case it is None
. If you would like to
apply your own transformation on the data, you should subclass
Dataset
and override the
transform()
method, then pass your
custom class to NeuralNet
as the dataset
argument.
Saving and Loading¶
General approach¶
skorch provides several ways to persist your model. First it is
possible to store the model using Python’s pickle
module. This saves the whole model, including hyperparameters. This
is useful when you don’t want to initialize your model before loading
its parameters, or when your NeuralNet
is part of an sklearn
Pipeline
:
net = NeuralNet(
module=MyModule,
criterion=torch.nn.NLLLoss,
)
model = Pipeline([
('my-features', get_features()),
('net', net),
])
model.fit(X, y)
# saving
with open('some-file.pkl', 'wb') as f:
pickle.dump(model, f)
# loading
with open('some-file.pkl', 'rb') as f:
model = pickle.load(f)
The disadvantage of pickling is that if your underlying code changes, unpickling might raise errors. Also, some Python code (e.g. lambda functions) cannot be pickled.
For this reason, we provide a second method for persisting your model.
To use it, call the save_params()
and
load_params()
method on
NeuralNet
. Under the hood, this saves the module
’s
state_dict
, i.e. only the weights and biases of the module
.
This is more robust to changes in the code but requires you to
initialize a NeuralNet
to load the parameters again:
net = NeuralNet(
module=MyModule,
criterion=torch.nn.NLLLoss,
)
model = Pipeline([
('my-features', get_features()),
('net', net),
])
model.fit(X, y)
net.save_params(f_params='some-file.pkl')
new_net = NeuralNet(
module=MyModule,
criterion=torch.nn.NLLLoss,
)
new_net.initialize() # This is important!
new_net.load_params(f_params='some-file.pkl')
In addition to saving the model parameters, the history and optimizer
state can be saved by including the f_history and f_optimizer
keywords to save_params()
and
load_params()
on NeuralNet
. This
feature can be used to continue training:
net = NeuralNet(
    module=MyModule,
criterion=torch.nn.NLLLoss,
)
net.fit(X, y, epochs=2) # Train for 2 epochs
net.save_params(
f_params='model.pkl', f_optimizer='opt.pkl', f_history='history.json')
new_net = NeuralNet(
    module=MyModule,
criterion=torch.nn.NLLLoss,
)
new_net.initialize() # This is important!
new_net.load_params(
f_params='model.pkl', f_optimizer='opt.pkl', f_history='history.json')
new_net.fit(X, y, epochs=2) # Train for another 2 epochs
Note
In order to use this feature, the history must only contain JSON encodable Python data structures. Numpy and PyTorch types should not be in the history.
Note
save_params()
does not store
learned attributes on the net. E.g.,
skorch.classifier.NeuralNetClassifier
remembers the
classes it encountered during training in the classes_
attribute. This attribute will be missing after
load_params()
. Therefore, if you need
it, you should pickle.dump()
the whole net.
Using callbacks¶
skorch provides Checkpoint
, TrainEndCheckpoint
,
and LoadInitState
callbacks to handle saving and loading
models during training. To demonstrate these features, we generate a
dataset and create a simple module:
import numpy as np
from sklearn.datasets import make_classification
from torch import nn
X, y = make_classification(1000, 10, n_informative=5, random_state=0)
X = X.astype(np.float32)
y = y.astype(np.int64)
class MyModule(nn.Sequential):
def __init__(self, num_units=10):
super().__init__(
nn.Linear(10, num_units),
nn.ReLU(inplace=True),
nn.Dropout(0.2),
nn.Linear(num_units, 10),
nn.Linear(10, 2),
nn.Softmax(dim=-1)
)
Then we create two different checkpoint callbacks and configure them
to save the model parameters, optimizer, and history into a directory
named 'exp1'
:
# First run
from skorch.callbacks import Checkpoint, TrainEndCheckpoint
from skorch import NeuralNetClassifier
cp = Checkpoint(dirname='exp1')
train_end_cp = TrainEndCheckpoint(dirname='exp1')
net = NeuralNetClassifier(
MyModule, lr=0.5, callbacks=[cp, train_end_cp]
)
_ = net.fit(X, y)
# prints
epoch train_loss valid_acc valid_loss cp dur
------- ------------ ----------- ------------ ---- ------
1 0.6200 0.8209 0.4765 + 0.0232
2 0.3644 0.8557 0.3474 + 0.0238
3 0.2875 0.8806 0.3201 + 0.0214
4 0.2514 0.8905 0.3080 + 0.0237
5 0.2333 0.9154 0.2844 + 0.0203
6 0.2177 0.9403 0.2164 + 0.0215
7 0.2194 0.9403 0.2159 + 0.0220
8 0.2027 0.9403 0.2299 0.0202
9 0.1864 0.9254 0.2313 0.0196
10 0.2024 0.9353 0.2333 0.0221
By default, Checkpoint
observes the valid_loss metric and
saves the model when the metric improves. This is indicated by the
+
mark in the cp
column of the logs.
On our first run, the validation loss did not improve after the 7th
epoch. We can lower the learning rate and continue training from this
checkpoint by using LoadInitState
:
from skorch.callbacks import LoadInitState
cp = Checkpoint(dirname='exp1')
load_state = LoadInitState(cp)
net = NeuralNetClassifier(
MyModule, lr=0.1, callbacks=[cp, load_state]
)
_ = net.fit(X, y)
# prints
epoch train_loss valid_acc valid_loss cp dur
------- ------------ ----------- ------------ ---- ------
8 0.1939 0.9055 0.2626 + 0.0238
9 0.2055 0.9353 0.2031 + 0.0239
10 0.1992 0.9453 0.2101 0.0182
11 0.2033 0.9453 0.1947 + 0.0211
12 0.1825 0.9104 0.2515 0.0185
13 0.2010 0.9453 0.1927 + 0.0187
14 0.1508 0.9453 0.1952 0.0198
15 0.1679 0.9502 0.1905 + 0.0181
16 0.1516 0.9453 0.1864 + 0.0192
17 0.1576 0.9453 0.1804 + 0.0184
The LoadInitState
callback is executed once in the beginning
of the training procedure and initializes model, history, and
optimizer parameters from a specified checkpoint (if it exists). In
our case, the previous checkpoint was created at the end of epoch 7,
so the second run resumes from epoch 8. With a lower learning rate,
the validation loss was able to improve!
Notice that in the first run we included a TrainEndCheckpoint
in the list of callbacks. As its name suggests, this callback creates
a checkpoint at the end of training. As before, we can pass it to
LoadInitState
to continue training:
cp_from_final = Checkpoint(dirname='exp1', fn_prefix='from_train_end_')
load_state = LoadInitState(train_end_cp)
net = NeuralNetClassifier(
MyModule, lr=0.1, callbacks=[cp_from_final, load_state]
)
_ = net.fit(X, y)
# prints
epoch train_loss valid_acc valid_loss cp dur
------- ------------ ----------- ------------ ---- ------
11 0.1663 0.9453 0.2166 + 0.0282
12 0.1880 0.9403 0.2237 0.0178
13 0.1813 0.9353 0.1993 + 0.0161
14 0.1744 0.9353 0.1955 + 0.0150
15 0.1538 0.9303 0.2053 0.0077
16 0.1473 0.9403 0.1947 + 0.0078
17 0.1563 0.9254 0.1989 0.0074
18 0.1558 0.9403 0.1877 + 0.0075
19 0.1534 0.9254 0.2318 0.0074
20 0.1779 0.9453 0.1814 + 0.0074
In this run, training started at epoch 11, continuing from the end of
the first run which ended at epoch 10. We created a new
Checkpoint
callback with fn_prefix
set to
'from_train_end_'
to prefix the saved filenames with
'from_train_end_'
to make sure this checkpoint does not overwrite
the checkpoint from the previous run.
Since our MyModule
class allows num_units
to be adjusted, we
can start a new experiment by changing the dirname
:
cp = Checkpoint(dirname='exp2')
load_state = LoadInitState(cp)
net = NeuralNetClassifier(
MyModule, lr=0.5,
callbacks=[cp, load_state],
module__num_units=20,
)
_ = net.fit(X, y)
# prints
epoch train_loss valid_acc valid_loss cp dur
------- ------------ ----------- ------------ ---- ------
1 0.5256 0.8856 0.3624 + 0.0181
2 0.2956 0.8756 0.3416 + 0.0222
3 0.2280 0.9453 0.2299 + 0.0211
4 0.1948 0.9303 0.2136 + 0.0232
5 0.1800 0.9055 0.2696 0.0223
6 0.1605 0.9403 0.1906 + 0.0190
7 0.1594 0.9403 0.2027 0.0184
8 0.1319 0.9303 0.1910 0.0220
9 0.1558 0.9254 0.1923 0.0189
10 0.1432 0.9303 0.2219 0.0192
This stores the model into the 'exp2'
directory. Since this is the
first run, the LoadInitState
callback does not do anything.
If we were to run the above script again, the LoadInitState
callback would load the model from the checkpoint.
In the run above, the last checkpoint was created at epoch 6. We can load this checkpoint and predict with it:
net = NeuralNetClassifier(
MyModule, lr=0.5, module__num_units=20,
)
net.initialize()
net.load_params(checkpoint=cp)
y_pred = net.predict(X)
In this case, it is important to initialize the neural net before
running NeuralNet.load_params()
.
Gaussian Processes¶
skorch integrates with GPyTorch to make it easy to train Gaussian Process (GP) models. You should already know how Gaussian Processes work. Please refer to other resources if you want to learn about them; this section assumes familiarity with the concept.
GPyTorch adopts many patterns from PyTorch, thus making it easy to pick up for seasoned PyTorch users. Similarly, the skorch GPyTorch integration should look familiar to seasoned skorch users. However, GPs are a different beast than the more common, non-probabilistic machine learning techniques. It is important to understand the basic concepts before using them in practice.
Installation¶
In addition to the normal skorch dependencies and PyTorch, you need to install GPyTorch as well. It wasn’t added as a normal dependency since most users probably are not interested in using skorch for GPs. To install GPyTorch, use either pip or conda:
# using pip
python -m pip install -U gpytorch
# using conda
conda install gpytorch -c gpytorch
When to use GPyTorch with skorch¶
Here we want to quickly explain when it would be a good idea for you to use GPyTorch with skorch. There are a couple of offerings in the Python ecosystem when it comes to Gaussian Processes. We cannot provide an exhaustive list of pros and cons of each possibility. There are, however, two obvious alternatives that are worth discussing: using the sklearn implementation and using GPyTorch without skorch.
When to use skorch + GPyTorch over sklearn:
- When you are more familiar with PyTorch than with sklearn
- When the kernels provided by sklearn are not sufficient for your use case and you would like to implement custom kernels with PyTorch
- When you want to use the rich set of optimizers available in PyTorch
- When sklearn is too slow and you want to use the GPU or scale across machines
- When you like to use the skorch extras, e.g. callbacks
When to use skorch + GPyTorch over pure GPyTorch
- When you’re already familiar with skorch and want an easy entry into GPs
- When you like to use the skorch extras, e.g. callbacks and grid search
- When you don’t want to bother with writing your own training loop
However, if you are researching GPs and would like to have control over every detail, using all the rich but very specific features that GPyTorch has on offer, it is better to use it directly without skorch.
Examples¶
Exact Gaussian Processes¶
Like GPyTorch, skorch supports exact and approximate Gaussian Process regression. For exact GPs, use the
ExactGPRegressor
. The likelihood has to be a
GaussianLikelihood
and the criterion
ExactMarginalLogLikelihood
, but those are the defaults
and thus don’t need to be specified. For exact GPs, the module needs to be an
ExactGP
. For this example, we use a simple RBF kernel.
import gpytorch
from skorch.probabilistic import ExactGPRegressor
class RbfModule(gpytorch.models.ExactGP):
    def __init__(self, likelihood):
        # detail: we don't set train_inputs and train_targets here because
        # skorch will take care of that
        super().__init__(train_inputs=None, train_targets=None, likelihood=likelihood)
self.mean_module = gpytorch.means.ConstantMean()
self.covar_module = gpytorch.kernels.RBFKernel()
def forward(self, x):
mean_x = self.mean_module(x)
covar_x = self.covar_module(x)
return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)
gpr = ExactGPRegressor(RbfModule)
gpr.fit(X_train, y_train)
y_pred = gpr.predict(X_test)
As you can see, this almost looks like a normal skorch regressor with a normal
PyTorch module. We can fit as normal using the fit
method and predict using
the predict
method.
Inside the module, we determine the mean by using a mean function (just constant
in this case) and the covariance matrix using the RBF kernel function. You
should know about mean and kernel functions already. Having the mean and
covariance matrix, we assume that the output distribution is a multivariate
normal distribution, since exact GPs rely on this assumption. We could send the
x
through an MLP for Deep Kernel Learning
but left it out to keep the example simple.
One major difference to usual deep learning models is that we actually predict a distribution, not just a point estimate. That means that if we choose an appropriate model that fits the data well, we can express the uncertainty of the model:
y_pred, y_std = gpr.predict(X, return_std=True)
lower_conf_region = y_pred - y_std
upper_conf_region = y_pred + y_std
Here we not only returned the mean of the prediction, y_pred
, but also its
standard deviation, y_std
. This tells us how uncertain the model is about
its prediction. E.g., it could be the case that the model is fairly certain when
interpolating between data points but uncertain about extrapolating. This is
not possible to know when models only learn point predictions.
To obtain the confidence region, you can also use the confidence_region
method:
# 1 standard deviation
lower, upper = gpr.confidence_region(X, sigmas=1)
# 2 standard deviations, the default
lower, upper = gpr.confidence_region(X, sigmas=2)
Furthermore, a GP allows you to sample from the distribution even before fitting it. The GP needs to be initialized, however:
gpr = ExactGPRegressor(...)
gpr.initialize()
samples = gpr.sample(X, n_samples=100)
By visualizing the samples and comparing them to the true underlying distribution of the target, you can already get a feel for whether the model you built is capable of generating the distribution of the target. If fitting takes a long time, it is therefore recommended to check the distribution first; otherwise, you may try to fit a model that is incapable of generating the true distribution and waste a lot of time.
Approximate Gaussian Processes¶
For some situations, fitting an exact GP might be infeasible, e.g. because the
distribution is not Gaussian or because you want to perform stochastic
optimization with mini-batches. For this, GPyTorch provides facilities to train
variational and approximate GPs. The module should inherit from
ApproximateGP
and should define a variational
strategy. From the skorch side of things, use
GPRegressor
.
import gpytorch
from gpytorch.models import ApproximateGP
from gpytorch.variational import CholeskyVariationalDistribution
from gpytorch.variational import VariationalStrategy
from skorch.probabilistic import GPRegressor
class VariationalModule(ApproximateGP):
def __init__(self, inducing_points):
variational_distribution = CholeskyVariationalDistribution(inducing_points.size(0))
variational_strategy = VariationalStrategy(
self, inducing_points, variational_distribution, learn_inducing_locations=True,
)
super().__init__(variational_strategy)
self.mean_module = gpytorch.means.ConstantMean()
self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())
def forward(self, x):
mean_x = self.mean_module(x)
covar_x = self.covar_module(x)
return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)
X, y = get_data(...)
X_inducing = X[:100]
X_train, y_train = X[100:], y[100:]
num_training_samples = len(X_train)
gpr = GPRegressor(
VariationalModule,
module__inducing_points=X_inducing,
criterion__num_data=num_training_samples,
)
gpr.fit(X_train, y_train)
y_pred = gpr.predict(X_train)
As you can see, the variational strategy requires us to use inducing points. We split off 100 of our training data samples to use as inducing points, assuming that they are representative of the whole distribution. Apart from this, there is basically no difference to using exact GP regression.
Finally, skorch also provides GPBinaryClassifier
for binary classification with GPs. It uses a Bernoulli likelihood by default.
However, using GPs for classification is not very common; GPs are most commonly
used for regression tasks where data points have a known relationship to each
other (e.g. in time series forecasts).
Multiclass classification is not currently provided, but you can use
GPBinaryClassifier
in conjunction with
OneVsRestClassifier
to achieve the same result.
Further examples¶
To see all of this in action, we provide a notebook that shows using skorch with GPs on real world data: Gaussian Processes notebook.
History¶
A NeuralNet
object logs training progress internally using a
History
object, stored in the history
attribute. Among
other use cases, history
is used to print the training progress
after each epoch:
net.fit(X, y)
# prints
epoch train_loss valid_acc valid_loss dur
------- ------------ ----------- ------------ ------
1 0.7111 0.5100 0.6894 0.1345
2 0.6928 0.5500 0.6803 0.0608
3 0.6833 0.5650 0.6741 0.0620
4 0.6763 0.5850 0.6674 0.0594
All this information (and more) is stored in and can be accessed
through net.history
. It is thus best practice to make use of
history
for storing training-related data.
In general, History
works like a list of dictionaries, where
each item in the list corresponds to one epoch, and each key of the
dictionary to one column. Thus, if you would like to access the
'train_loss'
of the last epoch, you can call
net.history[-1]['train_loss']
. To make the history
more
accessible, though, it is possible to just pass the indices separated
by a comma: net.history[-1, 'train_loss']
.
Moreover, History
stores the results from each individual
batch under the batches
key during each epoch. So to get the train
loss of the 3rd batch of the 7th epoch, use net.history[6, 'batches', 2, 'train_loss']
.
Here are some examples showing how to index history
:
# history of a fitted neural net
history = net.history
# get current epoch, a dict
history[-1]
# get train losses from all epochs, a list of floats
history[:, 'train_loss']
# get train and valid losses from all epochs, a list of tuples
history[:, ('train_loss', 'valid_loss')]
# get current batches, a list of dicts
history[-1, 'batches']
# get latest batch, a dict
history[-1, 'batches', -1]
# get train losses from current batch, a list of floats
history[-1, 'batches', :, 'train_loss']
# get train and valid losses from current batch, a list of tuples
history[-1, 'batches', :, ('train_loss', 'valid_loss')]
As History
essentially is a list of dictionaries, you can
also write to it as if it were a list of dictionaries. Here too,
skorch provides some convenience functions to make life easier. First
there is new_epoch()
, which will add a
new epoch dictionary to the end of the list. Also, there is
new_batch()
for adding new batches to
the current epoch.
To add a new item to the current epoch, use history.record('foo',
123)
. This will set the value 123
for the key foo
of the
current epoch. To write a value to the current batch, use
history.record_batch('bar', 456)
. Below are some more examples:
# history of a fitted neural net
history = net.history
# add new epoch row
history.new_epoch()
# add an entry to current epoch
history.record('my-score', 123)
# add a batch row to the current epoch
history.new_batch()
# add an entry to the current batch
history.record_batch('my-batch-score', 456)
# overwrite entry of current batch
history.record_batch('my-batch-score', 789)
Toy¶
This module contains helper functions and classes that allow you to prototype quickly or that can be used for writing tests.
MLPModule¶
MLPModule
is a simple PyTorch Module
that
implements a multi-layer perceptron. It allows you to set the number
of input, hidden, and output units, as well as the non-linearity and
use of dropout. You can use this module directly in conjunction with
NeuralNet
.
Additionally, the functions make_classifier()
,
make_binary_classifier()
, and
make_regressor()
can be used to return a
MLPModule
with the defaults adjusted for use in multi-class
classification, binary classification, and regression, respectively.
Helper¶
This module provides helper functions and classes for the user. They make working with skorch easier but are not used by skorch itself.
SliceDict¶
A SliceDict
is a wrapper for Python dictionaries that makes
them behave a little bit like numpy.ndarray
s. That way, you
can slice your dictionary across values, len()
will show the
length of the arrays and not the number of keys, and you get a
shape
attribute. This is useful because if your data is in a
dict
, you would normally not be able to use sklearn
GridSearchCV
and similar things;
with SliceDict
, this works.
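As a minimal sketch (assuming net is an already defined NeuralNet, X_num and X_cat are equal-length numpy arrays, and the module's forward method accepts the keys num and cat as keyword arguments):
from sklearn.model_selection import GridSearchCV
from skorch.helper import SliceDict

# wrap the dict of arrays so that sklearn can slice it like an ndarray
X_dict = SliceDict(num=X_num, cat=X_cat)

gs = GridSearchCV(net, {'lr': [0.01, 0.1]}, cv=3, scoring='accuracy')
gs.fit(X_dict, y)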
SliceDataset¶
A SliceDataset
is a wrapper for
PyTorch Dataset
s that makes them behave a little
bit like numpy.ndarray
s. That way, you can slice your
dataset with lists and arrays, and you get a shape
attribute.
These properties are useful because if your data is in a dataset, you
would normally not be able to use sklearn
GridSearchCV
and similar things;
with SliceDataset
, this works.
Note that SliceDataset
can only ever return one of the
values returned by the dataset. Typically, this will be either the X
or the y value. Therefore, if you want to wrap both X and y, you
should create two instances of SliceDataset
, one for X (with
argument idx=0
, the default) and one for y (with argument
idx=1
):
ds = MyCustomDataset()
X_sl = SliceDataset(ds, idx=0) # idx=0 is the default
y_sl = SliceDataset(ds, idx=1)
gs.fit(X_sl, y_sl)
Command line interface helpers¶
Often you want to wrap up your experiments by writing a small script that allows others to reproduce your work. With the help of skorch and the fire library, it becomes very easy to write command line interfaces without boilerplate. All arguments pertaining to skorch or its PyTorch module are immediately available as command line arguments, without the need to write a custom parser. If docstrings in the numpydoc specification are available, there is also comprehensive help for the user. Overall, this allows you to make your work reproducible without the usual hassle.
There is an example in the skorch repository that shows how to use the CLI tools. Below is a snippet that shows the output created by the help function without writing a single line of argument parsing:
$ python examples/cli/train.py pipeline --help
<SelectKBest> options:
--select__score_func : callable
Function taking two arrays X and y, and returning a pair of arrays
(scores, pvalues) or a single array with scores.
Default is f_classif (see below "See also"). The default function only
works with classification tasks.
--select__k : int or "all", optional, default=10
Number of top features to select.
The "all" option bypasses selection, for use in a parameter search.
...
<NeuralNetClassifier> options:
--net__module : torch module (class or instance)
A PyTorch :class:`~torch.nn.Module`. In general, the
uninstantiated class should be passed, although instantiated
modules will also work.
--net__criterion : torch criterion (class, default=torch.nn.NLLLoss)
Negative log likelihood loss. Note that the module should return
probabilities, the log is applied during ``get_loss``.
--net__optimizer : torch optim (class, default=torch.optim.SGD)
The uninitialized optimizer (update rule) used to optimize the
module
--net__lr : float (default=0.01)
Learning rate passed to the optimizer. You may use ``lr`` instead
of using ``optimizer__lr``, which would result in the same outcome.
--net__max_epochs : int (default=10)
The number of epochs to train for each ``fit`` call. Note that you
may keyboard-interrupt training at any time.
--net__batch_size : int (default=128)
...
--net__verbose : int (default=1)
Control the verbosity level.
--net__device : str, torch.device (default='cpu')
The compute device to be used. If set to 'cuda', data in torch
tensors will be pushed to cuda tensors before being sent to the
module.
<MLPClassifier> options:
--net__module__hidden_units : int (default=10)
Number of units in hidden layers.
--net__module__num_hidden : int (default=1)
Number of hidden layers.
--net__module__nonlin : torch.nn.Module instance (default=torch.nn.ReLU())
Non-linearity to apply after hidden layers.
--net__module__dropout : float (default=0)
Dropout rate. Dropout is applied between layers.
Installation¶
To use this functionality, you need some further libraries that are not part of skorch, namely fire and numpydoc. You can install them as follows:
python -m pip install fire numpydoc
Usage¶
When you write your own script, only the following bits need to be added:
import fire
from skorch.helper import parse_args
# your model definition and data fetching code below
...
def main(**kwargs):
    X, y = get_data()
    my_model = get_model()

    # important: wrap the model with the parsed arguments
    parsed = parse_args(kwargs)
    my_model = parsed(my_model)

    my_model.fit(X, y)


if __name__ == '__main__':
    fire.Fire(main)
This even works if your neural net is part of an sklearn pipeline, in which case the help extends to all other estimators of your pipeline.
In case you would like to change some defaults for the net (e.g. using
a batch_size
of 256 instead of 128), this is also possible. You
should have a dictionary containing your new defaults and pass it as
an additional argument to parse_args
:
my_defaults = {'batch_size': 256, 'module__hidden_units': 30}

def main(**kwargs):
    ...
    parsed = parse_args(kwargs, defaults=my_defaults)
    my_model = parsed(my_model)
This will update the displayed help to your new defaults, as well as
set the parameters on the net or pipeline for you. However, the
arguments passed via the commandline have precedence. Thus, if you
additionally pass --batch_size 512
to the script, batch size will
be 512.
Restrictions¶
Almost all arguments should work out of the box. Therefore, you get
command line arguments for the number of epochs, learning rate, batch
size, etc. for free. Moreover, you can access the module parameters
with the double-underscore notation as usual with skorch
(e.g. --module__num_units 100
). This should cover almost all
common cases.
Parsing command line arguments that are non-primitive Python objects
is more difficult, though. skorch’s custom parsing should support
normal Python types and simple custom objects, e.g. this works:
--module__nonlin 'torch.nn.RReLU(0.1, upper=0.4)'
. More complex
parsing might not work. E.g., it is currently not possible to add new
callbacks through the command line (but you can modify existing ones
as usual).
REST Service¶
In this section we’ll take the RNN sentiment classifier from the example Predicting sentiment on the IMDB dataset and use it to demonstrate how to easily expose your PyTorch module on the web using skorch and another library called Palladium.
With Palladium, you define the Palladium dataset, the model, and Palladium provides the framework to fit, test, and serve your model on the web. Palladium comes with its own documentation and a tutorial, which you may want to check out to learn more about what you can do with it.
The way to make the dataset and model known to Palladium is through its configuration file. Here’s the part of the configuration that defines the dataset and model:
{
    'dataset_loader_train': {
        '__factory__': 'model.DatasetLoader',
        'path': 'aclImdb/train/',
    },

    'dataset_loader_test': {
        '__factory__': 'model.DatasetLoader',
        'path': 'aclImdb/test/',
    },

    'model': {
        '__factory__': 'model.create_pipeline',
        'use_cuda': True,
    },

    'model_persister': {
        '__factory__': 'palladium.persistence.File',
        'path': 'rnn-model-{version}',
    },

    'scoring': 'accuracy',
}
You can save this configuration as palladium-config.py
.
The dataset_loader_train
and dataset_loader_test
entries
define where the data comes from. They refer to a Python class
defined inside the model
module. Let’s create a file and call it
model.py
, and put it in the same directory as the configuration file.
We’ll start off with defining the dataset loader:
import os
from urllib.request import urlretrieve
import tarfile
import numpy as np
from palladium.interfaces import DatasetLoader as IDatasetLoader
from sklearn.datasets import load_files
DATA_URL = 'http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz'
DATA_FN = DATA_URL.rsplit('/', 1)[1]
def download():
    if not os.path.exists('aclImdb'):
        # unzip data if it does not exist
        if not os.path.exists(DATA_FN):
            urlretrieve(DATA_URL, DATA_FN)
        with tarfile.open(DATA_FN, 'r:gz') as f:
            f.extractall()


class DatasetLoader(IDatasetLoader):
    def __init__(self, path='aclImdb/train/'):
        self.path = path

    def __call__(self):
        download()
        dataset = load_files(self.path, categories=['pos', 'neg'])
        X, y = dataset['data'], dataset['target']
        X = np.asarray([x.decode() for x in X])  # decode from bytes
        return X, y
The most interesting bit here is that our Palladium DatasetLoader
defines a __call__
method that will return the data and the target
(X and y). Easy. Note that in the configuration file, we refer to
our DatasetLoader
twice, once for the training set and once for
the test set.
Our configuration also refers to a function create_pipeline
which
we’ll create next:
from dstoolbox.transformers import Padder2d
from dstoolbox.transformers import TextFeaturizer
from sklearn.pipeline import Pipeline
from skorch import NeuralNetClassifier
import torch
def create_pipeline(
    vocab_size=1000,
    max_len=50,
    use_cuda=False,
    **kwargs
):
    return Pipeline([
        ('to_idx', TextFeaturizer(max_features=vocab_size)),
        ('pad', Padder2d(max_len=max_len, pad_value=vocab_size, dtype=int)),
        ('net', NeuralNetClassifier(
            RNNClassifier,
            device=('cuda' if use_cuda else 'cpu'),
            max_epochs=5,
            lr=0.01,
            optimizer=torch.optim.RMSprop,
            module__vocab_size=vocab_size,
            **kwargs,
        ))
    ])
You’ve noticed that this function’s job is to create the model and
return it. Here, we’re defining a pipeline that wraps skorch’s
NeuralNetClassifier
, which in turn is a wrapper around our PyTorch
module, as it’s defined in the predicting sentiment tutorial.
We’ll also add the RNNClassifier to model.py
:
from torch import nn
F = nn.functional
class RNNClassifier(nn.Module):
    def __init__(
        self,
        embedding_dim=128,
        rec_layer_type='lstm',
        num_units=128,
        num_layers=2,
        dropout=0,
        vocab_size=1000,
    ):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.rec_layer_type = rec_layer_type.lower()
        self.num_units = num_units
        self.num_layers = num_layers
        self.dropout = dropout

        self.emb = nn.Embedding(
            vocab_size + 1, embedding_dim=self.embedding_dim)

        rec_layer = {'lstm': nn.LSTM, 'gru': nn.GRU}[self.rec_layer_type]
        # We have to make sure that the recurrent layer is batch_first,
        # since sklearn assumes the batch dimension to be the first
        self.rec = rec_layer(
            self.embedding_dim, self.num_units,
            num_layers=num_layers, batch_first=True,
        )

        self.output = nn.Linear(self.num_units, 2)

    def forward(self, X):
        embeddings = self.emb(X)
        # from the recurrent layer, only take the activities from the
        # last sequence step
        if self.rec_layer_type == 'gru':
            _, rec_out = self.rec(embeddings)
        else:
            _, (rec_out, _) = self.rec(embeddings)
        rec_out = rec_out[-1]  # take output of last RNN layer
        drop = F.dropout(rec_out, p=self.dropout)
        # Remember that the final non-linearity should be softmax, so
        # that our predict_proba method outputs actual probabilities!
        out = F.softmax(self.output(drop), dim=-1)
        return out
You can find the full contents of the model.py
file in the
skorch/examples/rnn_classifer
folder of skorch’s source code.
Now with dataset and model in place, it’s time to try Palladium out.
You can install Palladium and another dependency we use with pip
install palladium dstoolbox
.
From within the directory that contains model.py
and
palladium-config.py
now run the following command:
PALLADIUM_CONFIG=palladium-config.py pld-fit --evaluate
You should see output similar to this:
INFO:palladium:Loading data...
INFO:palladium:Loading data done in 0.607 sec.
INFO:palladium:Fitting model...
epoch train_loss valid_acc valid_loss dur
------- ------------ ----------- ------------ ------
1 0.7679 0.5008 0.7617 3.1300
2 0.6385 0.7100 0.5840 3.1247
3 0.5430 0.7438 0.5518 3.1317
4 0.4736 0.7480 0.5424 3.1373
5 0.4253 0.7448 0.5832 3.1433
INFO:palladium:Fitting model done in 29.060 sec.
DEBUG:palladium:Evaluating model on train set...
INFO:palladium:Train score: 0.83068
DEBUG:palladium:Evaluating model on train set done in 6.743 sec.
DEBUG:palladium:Evaluating model on test set...
INFO:palladium:Test score: 0.75428
DEBUG:palladium:Evaluating model on test set done in 6.476 sec.
INFO:palladium:Writing model...
INFO:palladium:Writing model done in 0.694 sec.
INFO:palladium:Wrote model with version 1.
Congratulations, you’ve trained your first model with Palladium! Note
that in the output you see a train score (accuracy) of 0.83 and a test
score of about 0.75. These refer to how well your model did on the
training set (defined by dataset_loader_train
in the
configuration) and on the test set (dataset_loader_test
).
You’re now ready to serve the model on the web. Add this piece of configuration to the palladium-config.py configuration file (and make sure it lives within the outermost brackets):
{
    # ...
    'predict_service': {
        '__factory__': 'palladium.server.PredictService',
        'mapping': [
            ('text', 'str'),
        ],
        'predict_proba': True,
        'unwrap_sample': True,
    },
    # ...
}
With this piece of information inside the configuration, we’re ready to launch the web server using:
PALLADIUM_CONFIG=palladium-config.py pld-devserver
You can now try out the web service at this address: http://localhost:5000/predict?text=this+movie+was+brilliant
You should see a JSON string returned that looks something like this:
{
"metadata": {"error_code": 0, "status": "OK"},
"result": [0.326442807912827, 0.673557221889496],
}
The result
entry has the probabilities. Our model assigns a 67% probability of being positive to the sentence “this movie was brilliant”.
By the way, the skorch tutorial itself has tips on how to improve this
model.
The takeaway is that Palladium helps you reduce the boilerplate code that’s needed to get your machine learning project started. Palladium has routines to fit, test, and serve models so you don’t have to worry about that, and you can concentrate on the actual machine learning part. Configuration and code are separated with Palladium, which helps organize your experiments and work on ideas in parallel. Check out the Palladium documentation for more.
Parallelism¶
skorch supports distributing work among a cluster of workers via dask.distributed. In this section we’ll describe how to use Dask to efficiently distribute a grid search or a randomized search on hyperparameters across multiple GPUs and potentially multiple hosts.
Let’s assume that you have two GPUs that you want to run a hyperparameter search on.
The key here is using the CUDA environment variable
CUDA_VISIBLE_DEVICES
to limit which devices are visible to our CUDA application. We’ll set
up Dask workers that, using this environment variable, each see one
GPU only. On the PyTorch side, we’ll have to make sure to set the
device to cuda
when we initialize the NeuralNet
class.
Let’s run through the steps. First, install Dask and dask.distributed:
python -m pip install dask distributed
Next, assuming you have two GPUs on your machine, let’s start up a Dask scheduler and two Dask workers. Make sure the Dask workers are started up in the right environment, that is, with access to all packages required to do the work:
dask-scheduler
CUDA_VISIBLE_DEVICES=0 dask-worker 127.0.0.1:8786 --nthreads 1
CUDA_VISIBLE_DEVICES=1 dask-worker 127.0.0.1:8786 --nthreads 1
In your code, use joblib’s parallel_backend()
context
manager to activate the Dask backend when you run grid searches and
the like. Also instantiate a dask.distributed.Client
to
point to the Dask scheduler that you want to use. Let’s see what this could look like:
from dask.distributed import Client
from joblib import parallel_backend
from sklearn.model_selection import GridSearchCV

client = Client('127.0.0.1:8786')

X, y = load_my_data()
net = get_that_net()

gs = GridSearchCV(
    net,
    param_grid={'lr': [0.01, 0.03]},
    scoring='accuracy',
)

with parallel_backend('dask'):
    gs.fit(X, y)
print(gs.cv_results_)
You can also use Palladium to do
the job. An example is included in the source in the
examples/rnn_classifier
folder. Change in there and run the
following command, after having set up your Dask workers:
PALLADIUM_CONFIG=palladium-config.py,dask-config.py pld-grid-search
Customization¶
Customizing NeuralNet¶
Apart from the NeuralNet
base class, we provide
NeuralNetClassifier
, NeuralNetBinaryClassifier
,
and NeuralNetRegressor
for typical classification, binary
classification, and regressions tasks. They should work as drop-in
replacements for sklearn classifiers and regressors.
The NeuralNet
class is a little less opinionated about the
incoming data, e.g. it does not determine a loss function by default.
Therefore, if you want to write your own subclass for a special use
case, you would typically subclass from NeuralNet
. The
predict()
method returns the same output
as predict_proba()
by default, which is
the module output (or the first module output, in case it returns
multiple values).
NeuralNet
and its subclasses are already very flexible as they are and
should cover many use cases by adjusting the provided parameters or by using
callbacks. However, this may not always be sufficient for your use cases. If you
thus find yourself wanting to customize NeuralNet
, please follow the
guidelines in this section.
Methods starting with get_*¶
The net provides a few get_*
methods, most notably
get_loss()
,
get_dataset()
, and
get_iterator()
. The intent of these
methods should be pretty self-explanatory, and if you are still not
quite sure, consult their documentations. In general, these methods
are fairly safe to override as long as you make sure to conform to the
same signature as the original.
A short example should serve to illustrate this.
get_loss()
is called when the loss is determined.
Below we show an example of overriding get_loss()
to
add L1 regularization to our total loss:
class RegularizedNet(NeuralNet):
    def __init__(self, *args, lambda1=0.01, **kwargs):
        super().__init__(*args, **kwargs)
        self.lambda1 = lambda1

    def get_loss(self, y_pred, y_true, X=None, training=False):
        loss = super().get_loss(y_pred, y_true, X=X, training=training)
        loss += self.lambda1 * sum([w.abs().sum() for w in self.module_.parameters()])
        return loss
Note
This example also regularizes the biases, which you typically don’t need to do.
It is often a good idea to call super
of the method you override, to make
sure that everything that needs to happen inside that method does happen. If you
don’t, you should make sure to take care of everything that needs to happen by
following the original implementation.
Training and validation¶
If you would like to customize training and validation, there are several possibilities. Below are the methods that you most likely want to customize:
The method train_step_single()
performs a
single training step. It accepts the current batch of data as input
(as well as the fit_params
) and should return a dictionary
containing the loss
and the prediction y_pred
. E.g. you should
override this if your dataset returns some non-standard data that
needs custom handling, and/or if your module has to be called in a
very specific way. If you want to, you can still make use of
infer()
and
get_loss()
but it’s not strictly
necessary. Don’t call the optimizer in this method, this is handled by
the next method.
The method train_step()
defines the
complete training procedure performed for each batch. It accepts the
same arguments as train_step_single()
but
it differs in that it defines the training closure passed to the
optimizer, which for instance could be called more than once (e.g. in
case of L-BFGS). You might override this if you deal with non-standard
training procedures, as e.g. gradient accumulation.
The method validation_step()
is
responsible for calculating the prediction and loss on the validation
data (remember that skorch uses an internal validation set for
reporting, early stopping, etc.). Similar to
train_step_single()
, it receives the batch
and fit_params
as input and should return a dictionary containing
loss
and y_pred
. Most likely, when you need to customize
train_step_single()
, you’ll need to
customize validation_step()
accordingly.
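To make this more concrete, below is a minimal sketch of overriding train_step_single() and validation_step() for an iterator that yields (X, y) batches. It re-uses infer() and get_loss() as described above and leaves the optimizer call to train_step():
import torch
from skorch import NeuralNet

class CustomStepNet(NeuralNet):
    # minimal sketch, assuming each batch is an (X, y) pair
    def train_step_single(self, batch, **fit_params):
        self.module_.train()                  # put the module into train mode
        Xi, yi = batch                        # unpack the batch ourselves
        y_pred = self.infer(Xi, **fit_params)
        loss = self.get_loss(y_pred, yi, X=Xi, training=True)
        loss.backward()                       # the optimizer step happens in train_step()
        return {'loss': loss, 'y_pred': y_pred}

    def validation_step(self, batch, **fit_params):
        self.module_.eval()                   # put the module into eval mode
        Xi, yi = batch
        with torch.no_grad():
            y_pred = self.infer(Xi, **fit_params)
            loss = self.get_loss(y_pred, yi, X=Xi, training=False)
        return {'loss': loss, 'y_pred': y_pred}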
Finally, the method evaluation_step()
is
called when you use the net for inference, e.g. when calling
forward()
or
predict()
. You may want to modify this if,
e.g., you want your model to behave differently during training and
during prediction.
You should also be aware that some methods are better left untouched and, in most cases, should not be overridden. The reason is that these methods perform some bookkeeping, like making sure that callbacks are handled or writing logs to the history. If you do need to override them, make sure that you perform the same bookkeeping as the original methods.
Initialization and custom modules¶
The method initialize()
is responsible for
initializing all the components needed by the net, e.g. the module and
the optimizer. For this, it calls specific initialization methods,
such as initialize_module()
and
initialize_optimizer()
. If you’d like to
customize the initialization behavior, you should override the
corresponding methods. Following sklearn conventions, the created
components should be set as an attribute with a trailing underscore as
the name, e.g. module_
for the initialized module.
A possible modification you may want to make is to add more modules, criteria, and optimizers to your net. This is possible in skorch by following the guidelines below. If you do this, your custom modules and optimizers will be treated as “first class citizens” in skorch land. This means:
- The parameters of your custom modules are automatically passed to the optimizer (but you can modify this behavior).
- skorch takes care of moving your modules to the correct device.
- skorch takes care of setting the training/eval mode correctly.
- When a module needs to be re-initialized because set_params was called, all modules and optimizers that may depend on it are also re-initialized. This is for instance important for the optimizer, which must know about the parameters of the newly initialized module.
- You can pass arguments to the custom modules and optimizers using the now familiar double-underscore notation. E.g., you can initialize your net like this:
net = MyNet(
module=MyModule,
module__num_units=100,
othermodule=MyOtherModule,
othermodule__num_units=200,
)
net.fit(X, y)
A word about the distinction between modules and criteria made by skorch:
Typically, criteria are also just subclasses of PyTorch
Module
. As such, skorch moves them to CUDA if that is the
indicated device and will even pass parameters of criteria to the optimizers, if
there are any. This can be useful when e.g. training GANs, where you might
implement the discriminator as the criterion (and the generator as the module).
A difference between module and criterion is that the output of modules are used
for generating the predictions and are thus returned by
predict()
etc. In contrast, the output of the
criterion is used for calculating the loss and should therefore be a scalar.
skorch assumes that criteria may depend on the modules. Therefore, if a module is re-initialized, all criteria are also re-initialized, but not vice-versa. On top of that, the optimizer is re-initialized when either modules or criteria are changed.
So after all this talk, what are the aforementioned guidelines to add your own modules, criteria, and optimizers? You have to follow these rules:
- Initialize them during their respective initialize_ methods, e.g. modules should be set inside initialize_module().
- If they have learnable parameters, they should be instances of Module. Optimizers should be instances of Optimizer.
- Their names should end on an underscore. This is true for all attributes that are created during initialize and distinguishes them from arguments passed to __init__. So a name for a custom module could be mymodule_.
- Inside the initialization method, use skorch.net.NeuralNet.get_params_for() (or, if dealing with an optimizer, skorch.net.NeuralNet.get_params_for_optimizer()) to retrieve the arguments for the constructor of the instance.
Here is an example of how this could look like in practice:
class MyNet(NeuralNet):
    def initialize_module(self):
        super().initialize_module()
        # add an additional module called 'module2_'
        params = self.get_params_for('module2')
        self.module2_ = Module2(**params)
        return self

    def initialize_criterion(self):
        super().initialize_criterion()
        # add an additional criterion called 'other_criterion_'
        params = self.get_params_for('other_criterion')
        self.other_criterion_ = nn.BCELoss(**params)
        return self

    def initialize_optimizer(self):
        # first initialize the normal optimizer
        named_params = self.module_.named_parameters()
        args, kwargs = self.get_params_for_optimizer('optimizer', named_params)
        self.optimizer_ = self.optimizer(*args, **kwargs)

        # next add another optimizer called 'optimizer2_' that is
        # only responsible for training 'module2_'
        named_params = self.module2_.named_parameters()
        args, kwargs = self.get_params_for_optimizer('optimizer2', named_params)
        self.optimizer2_ = torch.optim.SGD(*args, **kwargs)
        return self

    ...  # additional changes
net = MyNet(
...,
module2__num_units=123,
other_criterion__reduction='sum',
optimizer2__lr=0.1,
)
net.fit(X, y)
# set_params works
net.set_params(optimizer2__lr=0.05)
net.partial_fit(X, y)
# grid search et al. works
search = GridSearchCV(net, {'module2__num_units': [10, 50, 100]}, ...)
search.fit(X, y)
In this example, a new criterion, a new module, and a new optimizer were added. Of course, additional changes should be made to the net so that those new components are actually being used for something, but this example should illustrate how to start. Since the rules outlined above are being followed, we can use grid search on our custom components.
Note
In the example above, the parameters of module_
are trained by
optimizer_
and the parameters of module2_
are trained by
optimizer2_
. To conveniently obtain the parameters of all modules,
call the method get_all_learnable_params()
.
Performance¶
Since skorch provides extra functionality on top of a pure PyTorch training code, it is expected that it will add an overhead to the total runtime. For typical workloads, this overhead should be unnoticeable.
In a few situations, skorch’s extra functionality may add significant overhead.
This is especially the case when the amount of data and the neural net are
relatively small. The reason is that typically, most time is spent on the
forward
, backward
, and parameter update calls. When those are really
fast, the skorch overhead will get noticed.
There are, however, a few things that can be done to reduce the skorch overhead. We will focus on accelerating the training process, where the overhead should be largest. Below, some mitigations are described, including the potential downsides.
First make sure that there is any significant slowdown¶
Neural nets are notoriously slow to train. Therefore, if your training takes a lot of time, that doesn’t automatically mean that the skorch overhead is at fault. Maybe the training would take the same time without skorch. If you have some measurements about training the same model without skorch, first make sure that this points to skorch being the culprit before trying to optimize using the mitigations described below. If it turns out skorch is not the culprit, look into optimizing the performance of PyTorch code in general.
Many people use skorch for hyper-parameter search. Remember that this implies
fitting the model repeatedly, thus a long run time is expected. E.g. if you run
a grid search on two hyper-parameters, each with 10 variants, and 5 splits,
there will actually be 10 x 10 x 5 fit calls, so expect the process to take
approximately 500 times as long as a single model fit. Increase the verbosity on
the grid search to get a better idea of the progress (e.g. GridSearchCV(...,
verbose=3)
).
Turning off verbosity¶
By default, skorch produces a print log of the training progress. This is useful
for checking the training progress, monitoring overfitting, etc. If you don’t need
these diagnostics, you can turn them off via the verbose
parameter. This
way, printing is deactivated, saving time on i/o. You can still access the
diagnostics through the history
attribute after training has finished.
net = NeuralNet(..., verbose=0) # turn off verbosity
net.fit(X, y)
train_loss = net.history[..., 'train_loss'] # access history as usual
Disabling callbacks altogether¶
If you don’t need any callbacks at all, turning them off can be a potential time
saver. Callbacks present the most significant “extra” that skorch provides over
pure PyTorch, hence they might add a lot of overhead for small workloads. By
turning them off, you lose their functionality, though. It’s up to you to
determine if that’s a worthwhile trade-off or not. For instance, in contrast to
just turning down verbosity, you will no longer have access to useful
diagnostics in the history
attribute.
# skorch version 0.10 or later:
net = NeuralNet(..., callbacks='disable')
net.fit(X, y)
print(net.history) # no longer contains useful diagnostics
# skorch version 0.9 or earlier
net = NeuralNet(...)
net.initialize()
net.callbacks_ = [] # manually remove all callbacks
net.fit(X, y)
Instead of turning off all callbacks, you can also turn off specific callbacks, including default callbacks. This way, you can decide which ones to keep and which ones to get rid of. Typically, callbacks that calculate some kind of metric tend to be slow.
# deactivate callbacks that determine train and valid loss after each epoch
net = NeuralNet(..., callbacks__train_loss=None, callbacks__valid_loss=None)
net.fit(X, y)
print(net.history) # no longer contains 'train_loss' and 'valid_loss' entries
Prepare the Dataset¶
skorch can deal with a number of different input data types. This is very
convenient, as it removes the necessity for the user to deal with them, but it
also adds a small overhead. Therefore, if you can prepare your data so that it’s
already contained in an appropriate torch.utils.data.Dataset
, this
check can be skipped.
X, y = ... # let's assume that X and y are numpy arrays
net = NeuralNet(...)
# normal way: let skorch figure out how to create the Dataset
net.fit(X, y)
# faster way: prepare Dataset yourself
from torch.utils.data import TensorDataset
Xt = torch.from_numpy(X)
yt = torch.from_numpy(y)
tensor_ds = TensorDataset(Xt, yt)
net.fit(tensor_ds, None)
Still too slow¶
Your skorch code is still slow despite trying all of these tips, and you have made sure that the slowdown is indeed caused by skorch. What can you do now? In this case, please search our issue tracker for solutions or open a new issue. Provide as much context as possible and, if available, a minimal code example. We will try to help you figure out what the problem is.
Hugging Face Integration¶
skorch integrates with some libraries from the Hugging Face ecosystem. Take a look at the sections below to learn more.
Accelerate¶
The AccelerateMixin
class can be used to add support for huggingface
accelerate to skorch. E.g., this allows you to use mixed precision training
(AMP), multi-GPU training, training with a TPU, or gradient accumulation. For the
time being, this feature should be considered experimental.
To use this feature, create a new subclass of the neural net class you want to
use and inherit from the mixin class. E.g., if you want to use a
NeuralNet
, it would look like this:
from skorch import NeuralNet
from skorch.hf import AccelerateMixin
class AcceleratedNet(AccelerateMixin, NeuralNet):
    """NeuralNet with accelerate support"""
The same would work for NeuralNetClassifier
,
NeuralNetRegressor
, etc. Then pass an instance of Accelerator with
the desired parameters and you’re good to go:
from accelerate import Accelerator
accelerator = Accelerator(...)
net = AcceleratedNet(
MyModule,
accelerator=accelerator,
)
net.fit(X, y)
accelerate recommends leaving the device handling to the Accelerator, which is why device defaults to None (thus telling skorch not to change the device).
Models using AccelerateMixin
cannot be pickled. If you need to save and load the net, use skorch.net.NeuralNet.save_params() and skorch.net.NeuralNet.load_params() instead.
To install accelerate, run the following command inside your Python environment:
python -m pip install accelerate
Note
Under the hood, accelerate uses GradScaler
,
which does not support passing the training step as a closure.
Therefore, if your optimizer requires that (e.g.
torch.optim.LBFGS
), you cannot use accelerate.
Tokenizers¶
skorch also provides sklearn-like transformers that work with Hugging Face
tokenizers. The transform
methods of these transformers return data in a dict-like data structure, which
makes them easy to use in conjunction with skorch’s NeuralNet
. Below
is an example of how to use a pretrained tokenizer with the help of
skorch.hf.HuggingfacePretrainedTokenizer
:
from skorch.hf import HuggingfacePretrainedTokenizer
# pass the model name to be downloaded
hf_tokenizer = HuggingfacePretrainedTokenizer('bert-base-uncased')
data = ['hello there', 'this is a text']
hf_tokenizer.fit(data) # only loads the model
hf_tokenizer.transform(data)
# use hyper params from pretrained tokenizer to fit on own data
hf_tokenizer = HuggingfacePretrainedTokenizer(
'bert-base-uncased', train=True, vocab_size=12345)
data = ...
hf_tokenizer.fit(data) # fits new tokenizer on data
hf_tokenizer.transform(data)
We also provide skorch.hf.HuggingfaceTokenizer
if you don’t want to use a
pretrained tokenizer but instead want to train your own tokenizer with
fine-grained control over each component, like which tokenization method to use.
Of course, since both transformers are scikit-learn compatible, you can use them in a grid search.
Transformers¶
The Hugging Face transformers library gives you access to many pretrained deep learning models. There is no special skorch integration for those, since they’re just normal models and can thus be used without further adjustments (as long as they’re PyTorch models).
If you want to see how using transformers
with skorch looks in
practice, take a look at the Hugging Face fine-tuning notebook.
FAQ¶
How do I apply L2 regularization?¶
To apply L2 regularization (aka weight decay), PyTorch supplies
the weight_decay
parameter, which must be supplied to the
optimizer. To pass this variable in skorch, use the
double-underscore notation for the optimizer:
net = NeuralNet(
...,
optimizer__weight_decay=0.01,
)
How can I continue training my model?¶
By default, when you call fit()
more than
once, the training starts from scratch instead of resuming from where it left off.
This is in line with sklearn’s behavior but not always desired. If
you would like to continue training, use
partial_fit()
instead of
fit()
. Alternatively, there is the
warm_start
argument, which is False
by default. Set it to
True
instead and you should be fine.
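A minimal sketch of both options, with MyModule standing in for your PyTorch module:
# 1) use partial_fit instead of fit
net = NeuralNetClassifier(MyModule, max_epochs=10)
net.fit(X, y)          # trains for 10 epochs from scratch
net.partial_fit(X, y)  # trains for another 10 epochs, keeping the weights

# 2) set warm_start=True, then repeated fit calls continue training
net = NeuralNetClassifier(MyModule, max_epochs=10, warm_start=True)
net.fit(X, y)  # epochs 1-10
net.fit(X, y)  # epochs 11-20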
How do I shuffle my train batches?¶
skorch uses DataLoader
from PyTorch under
the hood. This class takes a couple of arguments, for instance
shuffle
. We therefore need to pass the shuffle
argument to
DataLoader
, which we achieve by using the
double-underscore notation (as known from sklearn):
net = NeuralNet(
...,
iterator_train__shuffle=True,
)
Note that we have an iterator_train
for the training data and an
iterator_valid
for validation and test data. In general, you only
want to shuffle the train data, which is what the code above does.
How do I use sklearn GridSeachCV when my data is in a dictionary?¶
skorch supports dicts as input but sklearn doesn’t. To get around
that, try to wrap your dictionary into a SliceDict
. This is
a data container that partly behaves like a dict, partly like an
ndarray. For more details on how to do this, have a look at the
corresponding data section
in the notebook.
How do I use sklearn GridSeachCV when my data is in a dataset?¶
skorch supports datasets as input but sklearn doesn’t. If it’s possible, you should provide your data in a non-dataset format, e.g. as a numpy array or torch tensor, extracted from your original dataset.
Sometimes, this is not possible, e.g. when your data doesn’t fit into
memory. To get around that, try to wrap your dataset into a
SliceDataset
. This is a data container that partly behaves
like a dataset, partly like an ndarray. Further information can be
found here: SliceDataset.
I want to use sample_weight, how can I do this?¶
Some scikit-learn models support passing a sample_weight
argument
to fit
calls as part of the fit_params
. This allows you to
give different samples different weights in the final loss
calculation.
In general, skorch supports fit_params
, but unfortunately just
calling net.fit(X, y, sample_weight=sample_weight)
is not enough,
because the fit_params
are not split into train and valid, and are
not batched, resulting in a mismatch with the training batches.
Fortunately, skorch supports passing dictionaries as arguments, which
are actually split into train and valid and then batched. Therefore,
the best solution is to pass the sample_weight
with X
as a
dictionary. Below, there is example code on how to achieve this:
X, y = get_data()
# put your X into a dict if not already a dict
X = {'data': X}
# add sample_weight to the X dict
X['sample_weight'] = sample_weight
class MyModule(nn.Module):
    ...

    def forward(self, data, sample_weight):
        # when X is a dict, its keys are passed as kwargs to forward, thus
        # our forward has to have the arguments 'data' and 'sample_weight';
        # usually, sample_weight can be ignored here
        ...

class MyNet(NeuralNet):
    def __init__(self, *args, criterion__reduce=False, **kwargs):
        # make sure to set reduce=False in your criterion, since we need the loss
        # for each sample so that it can be weighted
        super().__init__(*args, criterion__reduce=criterion__reduce, **kwargs)

    def get_loss(self, y_pred, y_true, X, *args, **kwargs):
        # override get_loss to use the sample_weight from X
        loss_unreduced = super().get_loss(y_pred, y_true, X, *args, **kwargs)
        sample_weight = skorch.utils.to_tensor(X['sample_weight'], device=self.device)
        loss_reduced = (sample_weight * loss_unreduced).mean()
        return loss_reduced
net = MyNet(MyModule, ...)
net.fit(X, y)
I already split my data into training and validation sets, how can I use them?¶
If you have predefined training and validation datasets that are
subclasses of PyTorch Dataset
, you can use
predefined_split()
to wrap your validation dataset and
pass it to NeuralNet
’s train_split
parameter:
from skorch.helper import predefined_split
net = NeuralNet(
...,
train_split=predefined_split(valid_ds)
)
net.fit(train_ds)
If you split your data by using train_test_split()
,
you can create your own skorch Dataset
, and then pass
it to predefined_split()
:
from sklearn.model_selection import train_test_split
from skorch.helper import predefined_split
from skorch.dataset import Dataset
X_train, X_test, y_train, y_test = train_test_split(X, y)
valid_ds = Dataset(X_test, y_test)
net = NeuralNet(
...,
train_split=predefined_split(valid_ds)
)
net.fit(X_train, y_train)
What happens when NeuralNet is passed an initialized Pytorch module?¶
When NeuralNet
is passed an initialized Pytorch module,
skorch will usually leave the module alone. In the following example, the
resulting module will be trained for 20 epochs:
class MyModule(nn.Module):
    def __init__(self, hidden=10):
        ...

module = MyModule()
net1 = NeuralNet(module, max_epochs=10, ...)
net1.fit(X, y)
net2 = NeuralNet(module, max_epochs=10, ...)
net2.fit(X, y)
When the module is passed to the second NeuralNet
, it
will not be re-initialized and will keep its parameters from the first 10
epochs.
When the module parameters are set through keyword arguments, NeuralNet will re-initialize the module:
net = NeuralNet(module, module__hidden=10, ...)
net.fit(X, y)
Although it is possible to pass an initialized Pytorch module to
NeuralNet
, it is recommended to pass the module class
instead:
net = NeuralNet(MyModule, ...)
net.fit(X, y)
In this case, fit() will always re-initialize the module, whereas partial_fit() will not re-initialize it once the network has been initialized.
How do I use a PyTorch Dataset with skorch?¶
skorch supports PyTorch’s Dataset
as arguments to
fit()
or
partial_fit()
. We create a dataset by
subclassing PyTorch’s Dataset
:
import torch.utils.data
class RandomDataset(torch.utils.data.Dataset):
    def __init__(self):
        self.X = torch.randn(128, 10)
        self.Y = torch.randn(128, 10)

    def __getitem__(self, idx):
        return self.X[idx], self.Y[idx]

    def __len__(self):
        return 128
skorch expects the output of __getitem__
to be a tuple of two values.
The RandomDataset
can be passed directly to
fit()
:
from skorch import NeuralNet
import torch.nn as nn
train_ds = RandomDataset()
class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(10, 10)

    def forward(self, X):
        return self.layer(X)
net = NeuralNet(MyModule, criterion=torch.nn.MSELoss)
net.fit(train_ds)
How can I deal with multiple return values from forward?¶
skorch supports modules that return multiple values. To do this,
simply return a tuple of all values that you want to return from the
forward
method. However, this tuple will also be passed to the
criterion. If the criterion cannot deal with multiple values, this
will result in an error.
To remedy this, you need to either implement your own criterion that
can deal with the output or you need to override
get_loss()
and handle the unpacking of the
tuple.
To inspect all output values, you can use either the
forward()
method (eager) or the
forward_iter()
method (lazy).
For an example of how this works, have a look at this notebook.
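A minimal sketch of the get_loss() approach, assuming the module returns a tuple whose first element is the actual prediction:
from skorch import NeuralNet

class MultiOutputNet(NeuralNet):
    def get_loss(self, y_pred, y_true, *args, **kwargs):
        # the module returns a tuple; only the first element is the
        # prediction that the criterion should see
        if isinstance(y_pred, tuple):
            y_pred = y_pred[0]
        return super().get_loss(y_pred, y_true, *args, **kwargs)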
How can I perform gradient accumulation with skorch?¶
There is no direct option to turn on gradient accumulation (at least for now). However, with a few modifications, you can implement gradient accumulation yourself:
ACC_STEPS = 2 # number of steps to accumulate before updating weights
class GradAccNet(NeuralNetClassifier):
    """Net that accumulates gradients"""
    def __init__(self, *args, acc_steps=ACC_STEPS, **kwargs):
        super().__init__(*args, **kwargs)
        self.acc_steps = acc_steps

    def get_loss(self, *args, **kwargs):
        loss = super().get_loss(*args, **kwargs)
        return loss / self.acc_steps  # normalize loss

    def train_step(self, batch, **fit_params):
        """Perform gradient accumulation

        Only optimize every nth batch.

        """
        # note that n_train_batches starts at 1 for each epoch
        n_train_batches = len(self.history[-1, 'batches'])
        step = self.train_step_single(batch, **fit_params)

        if n_train_batches % self.acc_steps == 0:
            self.optimizer_.step()
            self.optimizer_.zero_grad()
        return step
This is not a complete recipe. For example, if you optimize every 2nd step, and the number of training batches is uneven, you should make sure that there is an optimization step after the last batch of each epoch. However, this example can serve as a starting point to implement your own version of gradient accumulation.
Alternatively, make use of skorch’s accelerate integration provided by
AccelerateMixin
and use the gradient accumulation feature
from that library.
How can I dynamically set the input size of the PyTorch module based on the data?¶
Typically, it’s up to the user to determine the shape of the input
data when defining the PyTorch module. This can sometimes be
inconvenient, e.g. when the shape is only known at runtime. E.g., when
using sklearn.feature_selection.VarianceThreshold
, you cannot
know the number of features in advance. The best solution would be to
set the input size dynamically.
In most circumstances, this can be achieved with a few lines of code in skorch. Here is an example:
class InputShapeSetter(skorch.callbacks.Callback):
    def on_train_begin(self, net, X, y):
        net.set_params(module__input_units=X.shape[1])

net = skorch.NeuralNetClassifier(
    ClassifierModule,
    callbacks=[InputShapeSetter()],
)
This assumes that your module accepts an argument called
input_units
, which determines the number of units of the input
layer, and that the number of features can be determined by
X.shape[1]
. If those assumptions are not true for your case,
adjust the code accordingly. A fully working example can be found
on stackoverflow.
How do I implement a score method on the net that returns the loss?¶
Sometimes, it is useful to be able to compute the loss of a net from within
skorch
(e.g. when a net is part of an sklearn
pipeline). The function
skorch.scoring.loss_scoring()
achieves this. Two examples are provided
below. The first demonstrates how to use skorch.scoring.loss_scoring()
as
a function on a trained net
object.
import numpy as np
import skorch
from skorch.scoring import loss_scoring
from torch import nn

X = np.random.randn(250, 25).astype('float32')
y = (X.dot(np.ones(25)) > 0).astype(int)
module = nn.Sequential(
    nn.Linear(25, 25),
    nn.ReLU(),
    nn.Linear(25, 2),
    nn.Softmax(dim=1)
)
net = skorch.NeuralNetClassifier(module).fit(X, y)
print(loss_scoring(net, X, y))
The second example shows how to sub-class skorch.classifier.NeuralNetClassifier
to
implement a score
method. In this example, the score
method returns the
negative of the loss value, because sklearn.model_selection.GridSearchCV searches for the run with the greatest score, and we want the best run to be the one with the least loss.
class ScoredNet(skorch.NeuralNetClassifier):
    def score(self, X, y=None):
        loss_value = loss_scoring(self, X, y)
        return -loss_value

net = ScoredNet(module)
grid_searcher = GridSearchCV(
    net, {'lr': [1e-2, 1e-3], 'batch_size': [8, 16]},
)
grid_searcher.fit(X, y)
best_net = grid_searcher.best_estimator_
print(best_net.score(X, y))
How can I set the random seed of my model?¶
skorch does not provide a unified way for setting the seed of your model. You can utilize the numpy and torch seeding interfaces to set random seeds before model initialization and will get consistent results if you do not employ other libraries that introduce randomness.
Here’s an example:
seed = 42
numpy.random.seed(seed)
torch.manual_seed(seed)
net = NeuralNet(
module=...
)
net.fit(X, y)
Note: torch.manual_seed calls torch.cuda.manual_seed if the current process employs CUDA, so you don’t have to worry about seeding CUDA RNG explicitly.
There are cases where you want a fixed train/validation split (by default the split is not seeded). This can be done by passing the random_state parameter to ValidSplit:
from skorch.dataset import ValidSplit
seed = 42
net = NeuralNet(
module=...,
train_split=ValidSplit(random_state=seed),
)
net.fit(X, y)
Note that there are other places where randomness is introduced
that skorch does not control, such as the torch DataLoader
when
setting shuffle=True
. See the corresponding
documentation
on how to fix the random seed in this case.
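One way to do this (a sketch, assuming your PyTorch version's DataLoader accepts a generator argument) is to pass a seeded torch.Generator to the train iterator:
import torch
from skorch import NeuralNet

# a seeded generator makes the DataLoader's shuffling reproducible
gen = torch.Generator()
gen.manual_seed(42)

net = NeuralNet(
    MyModule,                        # placeholder for your PyTorch module
    iterator_train__shuffle=True,
    iterator_train__generator=gen,   # passed through to the DataLoader
)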
Migration guide¶
Migration from 0.10 to 0.11¶
With skorch 0.11, we pushed the tuple unpacking of values returned by the iterator to methods lower down the call chain. This way, it is much easier to work with iterators that don’t return exactly two values, as per the convention.
A consequence of this is a change in signature of these methods:
skorch.net.NeuralNet.train_step_single()
skorch.net.NeuralNet.validation_step()
skorch.callbacks.Callback.on_batch_begin()
skorch.callbacks.Callback.on_batch_end()
Instead of receiving the unpacked tuple of X
and y
, they just
receive a batch
, which is whatever is returned by the
iterator. The tuple unpacking needs to be performed inside these
methods.
If you have customized any of these methods, it is easy to retrieve
the previous behavior. E.g. if you wrote your own on_batch_begin
,
this is how to make the transition:
# before
def on_batch_begin(self, net, X, y, ...):
    ...

# after
def on_batch_begin(self, net, batch, ...):
    X, y = batch
    ...
The same goes for the other three methods.
Migration from 0.11 to 0.12¶
In skorch 0.12, we made a change regarding the training step. Now, we initialize
the torch.utils.data.DataLoader
only once per fit call instead of once
per epoch. This is accomplished by calling
skorch.net.NeuralNet.get_iterator()
only once at the beginning of the
training process. For the majority of the users, this should make no difference
in practice.
However, you might be affected if you wrote a custom
skorch.net.NeuralNet.run_single_epoch()
. The first argument to this
method is now the initialized DataLoader
instead of a Dataset
.
Therefore, this method should no longer call
skorch.net.NeuralNet.get_iterator()
. You only need to change a few
lines of code to accomplish this, as shown below:
# before
def run_single_epoch(self, dataset, ...):
    ...
    for batch in self.get_iterator(dataset, training=training):
        ...

# after
def run_single_epoch(self, iterator, ...):
    ...
    for batch in iterator:
        ...
Your old code should still work for the time being but will give a
DeprecationWarning
. Starting from skorch v0.13, old code will raise an error
instead.
If it is necessary to have access to the Dataset
inside of
run_single_epoch
, you can access it on the DataLoader
object using
iterator.dataset
.
API Reference¶
If you are looking for information on a specific function, class or method, this part of the documentation is for you.
skorch¶
skorch.callbacks¶
This module serves to elevate callbacks in submodules to the skorch.callback namespace. Remember to define __all__ in each submodule.
-
class
skorch.callbacks.
BatchScoring
(scoring, lower_is_better=True, on_train=False, name=None, target_extractor=<function to_numpy>, use_caching=True)[source]¶ Callback that performs generic scoring on batches.
This callback determines the score after each batch and stores it in the net’s history in the column given by
name
. At the end of the epoch, the average of the scores are determined and also stored in the history. Furthermore, it is determined whether this average score is the best score yet and that information is also stored in the history.In contrast to
EpochScoring
, this callback determines the score for each batch and then averages the score at the end of the epoch. This can be disadvantageous for some scores if the batch size is small – e.g. area under the ROC will return incorrect scores in this case. Therefore, it is recommnded to useEpochScoring
unless you really need the scores for each batch.If
y
is None, thescoring
function with signature (model, X, y) must be able to handleX
as aTensor
andy=None
.Parameters: - scoring : None, str, or callable
If None, use the
score
method of the model. If str, it should be a valid sklearn metric (e.g. “f1_score”, “accuracy_score”). If a callable, it should have the signature (model, X, y), and it should return a scalar. This works analogously to thescoring
parameter in sklearn’sGridSearchCV
et al.- lower_is_better : bool (default=True)
Whether lower (e.g. log loss) or higher (e.g. accuracy) scores are better.
- on_train : bool (default=False)
Whether this should be called during train or validation.
- name : str or None (default=None)
If not an explicit string, tries to infer the name from the
scoring
argument.- target_extractor : callable (default=to_numpy)
This is called on y before it is passed to scoring.
- use_caching : bool (default=True)
Re-use the model’s prediction for computing the loss to calculate the score. Turning this off will result in an additional inference step for each batch.
Methods
initialize
()(Re-)Set the initial state of the callback. on_batch_begin
(net[, batch, training])Called at the beginning of each batch. on_batch_end
(net, batch, training, **kwargs)Called at the end of each batch. on_epoch_begin
(net[, dataset_train, …])Called at the beginning of each epoch. on_epoch_end
(net, **kwargs)Called at the end of each epoch. on_grad_computed
(net, named_parameters[, X, …])Called once per batch after gradients have been computed but before an update step was performed. on_train_begin
(net, X, y, **kwargs)Called at the beginning of training. on_train_end
(net[, X, y])Called at the end of training. get_avg_score get_params set_params
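As a usage sketch (with MyModule standing in for your PyTorch module, and the scoring callable following the (model, X, y) signature described above):
from sklearn.metrics import accuracy_score
from skorch import NeuralNetClassifier
from skorch.callbacks import BatchScoring

def batch_accuracy(model, X, y):
    # callable with signature (model, X, y) returning a scalar
    return accuracy_score(y, model.predict(X))

net = NeuralNetClassifier(
    MyModule,  # placeholder for your PyTorch module
    callbacks=[BatchScoring(batch_accuracy, lower_is_better=False, name='batch_acc')],
)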
-
class
skorch.callbacks.
Callback
[source]¶ Base class for callbacks.
All custom callbacks should inherit from this class. The subclass may override any of the
on_...
methods. It is, however, not necessary to override all of them, since it’s okay if they don’t have any effect.Classes that inherit from this also gain the
get_params
andset_params
method.Methods
initialize
()(Re-)Set the initial state of the callback. on_batch_begin
(net[, batch, training])Called at the beginning of each batch. on_batch_end
(net[, batch, training])Called at the end of each batch. on_epoch_begin
(net[, dataset_train, …])Called at the beginning of each epoch. on_epoch_end
(net[, dataset_train, dataset_valid])Called at the end of each epoch. on_grad_computed
(net, named_parameters[, X, …])Called once per batch after gradients have been computed but before an update step was performed. on_train_begin
(net[, X, y])Called at the beginning of training. on_train_end
(net[, X, y])Called at the end of training. get_params set_params -
initialize
()[source]¶ (Re-)Set the initial state of the callback. Use this e.g. if the callback tracks some state that should be reset when the model is re-initialized.
This method should return self.
-
on_batch_begin
(net, batch=None, training=None, **kwargs)[source]¶ Called at the beginning of each batch.
-
on_epoch_begin
(net, dataset_train=None, dataset_valid=None, **kwargs)[source]¶ Called at the beginning of each epoch.
-
on_epoch_end
(net, dataset_train=None, dataset_valid=None, **kwargs)[source]¶ Called at the end of each epoch.
-
-
class
skorch.callbacks.
Checkpoint
(monitor='valid_loss_best', f_params='params.pt', f_optimizer='optimizer.pt', f_criterion='criterion.pt', f_history='history.json', f_pickle=None, fn_prefix='', dirname='', event_name='event_cp', sink=<function noop>, load_best=False, **kwargs)[source]¶ Save the model during training if the given metric improved.
This callback works by default in conjunction with the validation scoring callback since it creates a
valid_loss_best
value in the history which the callback uses to determine if this epoch is save-worthy.You can also specify your own metric to monitor or supply a callback that dynamically evaluates whether the model should be saved in this epoch.
As checkpointing is often used in conjunction with early stopping there is a need to restore the state of the model to the best checkpoint after training is done. The checkpoint callback will do this for you if you wish.
Some or all of the following can be saved:
- model parameters (see
f_params
parameter); - optimizer state (see
f_optimizer
parameter); - criterion state (see
f_criterion
parameter); - training history (see
f_history
parameter); - entire model object (see
f_pickle
parameter).
If you’ve created a custom module, e.g.
net.mymodule_
, you can save that as well by passingf_mymodule
.You can implement your own save protocol by subclassing
Checkpoint
and overridingsave_model()
.This callback writes a bool flag to the history column
event_cp
indicating whether a checkpoint was created or not.Example:
>>> net = MyNet(callbacks=[Checkpoint()]) >>> net.fit(X, y)
Example using a custom monitor where models are saved only in epochs where the validation and the train losses are best:
>>> monitor = lambda net: all(net.history[-1, ( ... 'train_loss_best', 'valid_loss_best')]) >>> net = MyNet(callbacks=[Checkpoint(monitor=monitor)]) >>> net.fit(X, y)
Parameters: - monitor : str, function, None
Value of the history to monitor or callback that determines whether this epoch should lead to a checkpoint. The callback takes the network instance as parameter.
In case
monitor
is set toNone
, the callback will save the network at every epoch.Note: If you supply a lambda expression as monitor, you cannot pickle the wrapper anymore as lambdas cannot be pickled. You can mitigate this problem by using importable functions instead.
- f_params : file-like object, str, None (default=’params.pt’)
File path to the file or file-like object where the model parameters should be saved. Pass
None
to disable saving model parameters.If the value is a string you can also use format specifiers to, for example, indicate the current epoch. Accessible format values are
net
,last_epoch
andlast_batch
. Example to include last epoch number in file name:>>> cb = Checkpoint(f_params="params_{last_epoch[epoch]}.pt")
- f_optimizer : file-like object, str, None (default=’optimizer.pt’)
File path to the file or file-like object where the optimizer state should be saved. Pass None to disable saving the optimizer state.
Supports the same format specifiers as f_params.
- f_criterion : file-like object, str, None (default=’criterion.pt’)
File path to the file or file-like object where the criterion state should be saved. Pass None to disable saving the criterion state.
Supports the same format specifiers as f_params.
- f_history : file-like object, str, None (default=’history.json’)
File path to the file or file-like object where the model training history should be saved. Pass None to disable saving history.
- f_pickle : file-like object, str, None (default=None)
File path to the file or file-like object where the entire model object should be pickled. Pass None to disable pickling.
Supports the same format specifiers as f_params.
- fn_prefix: str (default=’’)
Prefix for filenames. If f_params, f_optimizer, f_history, or f_pickle are strings, they will be prefixed by fn_prefix.
- dirname: str (default=’’)
Directory where files are stored.
- load_best: bool (default=False)
Load the best checkpoint automatically once training has ended. This can be particularly helpful in combination with early stopping, as it allows scoring with the best model even when early stopping ended training a number of epochs later. Note that this will only work when monitor != None.
- event_name: str, (default=’event_cp’)
Name of the event to be placed in the history when a checkpoint is triggered. Pass None to disable placing events in history.
- sink : callable (default=noop)
The target that the information about created checkpoints is sent to. This can be a logger or the print function (to send to stdout). By default the output is discarded.
Attributes: - f_history_
Methods
get_formatted_files
(net)Returns a dictionary of formatted filenames initialize
()(Re-)Set the initial state of the callback. on_batch_begin
(net[, batch, training])Called at the beginning of each batch. on_batch_end
(net[, batch, training])Called at the end of each batch. on_epoch_begin
(net[, dataset_train, …])Called at the beginning of each epoch. on_epoch_end
(net, **kwargs)Called at the end of each epoch. on_grad_computed
(net, named_parameters[, X, …])Called once per batch after gradients have been computed but before an update step was performed. on_train_begin
(net[, X, y])Called at the beginning of training. on_train_end
(net, **kwargs)Called at the end of training. save_model
(net)Save the model. get_params set_params - model parameters (see
-
class
skorch.callbacks.
EarlyStopping
(monitor='valid_loss', patience=5, threshold=0.0001, threshold_mode='rel', lower_is_better=True, sink=<built-in function print>, load_best=False)[source]¶ Callback for stopping training when scores don’t improve.
Stop training early if a specified monitor metric did not improve in patience number of epochs by at least threshold.
Parameters: - monitor : str (default=’valid_loss’)
Value of the history to monitor to decide whether to stop training or not. The value is expected to be double and is commonly provided by scoring callbacks such as
skorch.callbacks.EpochScoring
.- lower_is_better : bool (default=True)
Whether lower scores should be considered better (e.g. for a loss) or worse (e.g. for accuracy).
- patience : int (default=5)
Number of epochs to wait for improvement of the monitor value until the training process is stopped.
- threshold : float (default=1e-4)
Ignore score improvements smaller than threshold.
- threshold_mode : str (default=’rel’)
One of ‘rel’, ‘abs’. Decides whether the threshold value is interpreted in absolute terms or as a fraction of the best score so far (relative).
- sink : callable (default=print)
The target that the information about early stopping is sent to. By default, the output is printed to stdout, but the sink could also be a logger or noop().
- load_best: bool (default=False)
Whether to restore module weights from the epoch with the best value of the monitored quantity. If False, the module weights obtained at the last step of training are used. Note that only the module is restored. Use the Checkpoint callback with the load_best argument set to True if you need to restore the whole object.
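For illustration, a minimal sketch (assuming a NeuralNetClassifier wrapping a hypothetical MyModule and data X, y) that stops training once the validation loss has not improved by at least 1e-4 for 10 epochs:
>>> from skorch.callbacks import EarlyStopping
>>> early_stopping = EarlyStopping(monitor='valid_loss', patience=10, threshold=1e-4)
>>> net = NeuralNetClassifier(MyModule, callbacks=[early_stopping])
>>> net.fit(X, y)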
Methods
initialize
()(Re-)Set the initial state of the callback. on_batch_begin
(net[, batch, training])Called at the beginning of each batch. on_batch_end
(net[, batch, training])Called at the end of each batch. on_epoch_begin
(net[, dataset_train, …])Called at the beginning of each epoch. on_epoch_end
(net, **kwargs)Called at the end of each epoch. on_grad_computed
(net, named_parameters[, X, …])Called once per batch after gradients have been computed but before an update step was performed. on_train_begin
(net, **kwargs)Called at the beginning of training. on_train_end
(net, **kwargs)Called at the end of training. get_params set_params
-
class
skorch.callbacks.
EpochScoring
(scoring, lower_is_better=True, on_train=False, name=None, target_extractor=<function to_numpy>, use_caching=True)[source]¶ Callback that performs generic scoring on predictions.
At the end of each epoch, this callback makes a prediction on train or validation data, determines the score for that prediction and whether it is the best yet, and stores the result in the net’s history.
In case you already computed a score value for each batch, you can omit the score computation step by returning the value from the history. For example:
>>> def my_score(net, X=None, y=None):
...     losses = net.history[-1, 'batches', :, 'my_score']
...     batch_sizes = net.history[-1, 'batches', :, 'valid_batch_size']
...     return np.average(losses, weights=batch_sizes)
>>> net = MyNet(callbacks=[
...     ('my_score', EpochScoring(my_score, name='my_score'))])
If you fit with a custom dataset, this callback should work as expected as long as use_caching=True, which enables the collection of y values from the dataset. If you decide to disable the caching of predictions and y values, you need to write your own scoring function that is able to deal with the dataset and returns a scalar, for example:
>>> def ds_accuracy(net, ds, y=None):
...     # assume ds yields (X, y), e.g. torchvision.datasets.MNIST
...     y_true = [y for _, y in ds]
...     y_pred = net.predict(ds)
...     return sklearn.metrics.accuracy_score(y_true, y_pred)
>>> net = MyNet(callbacks=[
...     EpochScoring(ds_accuracy, use_caching=False)])
>>> ds = torchvision.datasets.MNIST(root=mnist_path)
>>> net.fit(ds)
Parameters: - scoring : None, str, or callable (default=None)
If None, use the score method of the model. If str, it should be a valid sklearn scorer (e.g. “f1”, “accuracy”). If a callable, it should have the signature (model, X, y), and it should return a scalar. This works analogously to the scoring parameter in sklearn’s GridSearchCV et al.
- lower_is_better : bool (default=True)
Whether lower scores should be considered better (e.g. for a loss) or worse (e.g. for accuracy).
- on_train : bool (default=False)
Whether this should be computed on training or validation data.
- name : str or None (default=None)
If not an explicit string, tries to infer the name from the scoring argument.
- target_extractor : callable (default=to_numpy)
This is called on y before it is passed to scoring.
- use_caching : bool (default=True)
Collect labels and predictions (y_true and y_pred) over the course of one epoch and use the cached values for computing the score. The cached values are shared between all EpochScoring instances. Disabling this will result in an additional inference step for each epoch and an inability to use arbitrary datasets as input (since we don’t know how to extract y_true from an arbitrary dataset).
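As a simple sketch (again assuming NeuralNetClassifier, a hypothetical MyModule, and data X, y), this monitors balanced accuracy on the validation set after each epoch using an sklearn scorer name and the default caching:
>>> from skorch.callbacks import EpochScoring
>>> balanced_acc = EpochScoring(scoring='balanced_accuracy', lower_is_better=False)
>>> net = NeuralNetClassifier(MyModule, callbacks=[balanced_acc])
>>> net.fit(X, y)  # the history now contains a 'balanced_accuracy' column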
Methods
get_test_data
(dataset_train, dataset_valid)Return data needed to perform scoring. initialize
()(Re-)Set the initial state of the callback. on_batch_begin
(net[, batch, training])Called at the beginning of each batch. on_batch_end
(net, batch, y_pred, training, …)Called at the end of each batch. on_epoch_begin
(net, dataset_train, …)Called at the beginning of each epoch. on_epoch_end
(net, dataset_train, …)Called at the end of each epoch. on_grad_computed
(net, named_parameters[, X, …])Called once per batch after gradients have been computed but before an update step was performed. on_train_begin
(net, X, y, **kwargs)Called at the beginning of training. on_train_end
(*args, **kwargs)Called at the end of training. get_params set_params -
get_test_data
(dataset_train, dataset_valid)[source]¶ Return data needed to perform scoring.
This is a convenience method that handles picking of train/valid, different types of input data, use of cache, etc. for you.
Parameters: - dataset_train
Incoming training data or dataset.
- dataset_valid
Incoming validation data or dataset.
Returns: - X_test
Input data used for making the prediction.
- y_test
Target ground truth. If caching was enabled, return cached y_test.
- y_pred : list
The predicted targets. If caching was disabled, the list is empty. If caching was enabled, the list contains the batches of the predictions. It may thus be necessary to concatenate the output before working with it:
y_pred = np.concatenate(y_pred)
-
initialize
()[source]¶ (Re-)Set the initial state of the callback. Use this e.g. if the callback tracks some state that should be reset when the model is re-initialized.
This method should return self.
-
class
skorch.callbacks.
EpochTimer
(**kwargs)[source]¶ Measures the duration of each epoch and writes it to the history with the name dur.
Methods
initialize
()(Re-)Set the initial state of the callback. on_batch_begin
(net[, batch, training])Called at the beginning of each batch. on_batch_end
(net[, batch, training])Called at the end of each batch. on_epoch_begin
(net, **kwargs)Called at the beginning of each epoch. on_epoch_end
(net, **kwargs)Called at the end of each epoch. on_grad_computed
(net, named_parameters[, X, …])Called once per batch after gradients have been computed but before an update step was performed. on_train_begin
(net[, X, y])Called at the beginning of training. on_train_end
(net[, X, y])Called at the end of training. get_params set_params
-
class
skorch.callbacks.
Freezer
(*args, **kwargs)[source]¶ Freeze matching parameters at the start of the first epoch. You may specify a specific point in time (either by epoch number or using a callable) when the parameters are frozen using the at parameter.
See ParamMapper for details.
Methods
initialize
()(Re-)Set the initial state of the callback. on_batch_begin
(net[, batch, training])Called at the beginning of each batch. on_batch_end
(net[, batch, training])Called at the end of each batch. on_epoch_begin
(net, **kwargs)Called at the beginning of each epoch. on_epoch_end
(net[, dataset_train, dataset_valid])Called at the end of each epoch. on_grad_computed
(net, named_parameters[, X, …])Called once per batch after gradients have been computed but before an update step was performed. on_train_begin
(net[, X, y])Called at the beginning of training. on_train_end
(net[, X, y])Called at the end of training. filter_parameters get_params named_parameters set_params
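A short usage sketch, assuming a module that contains a submodule named embedding (a hypothetical layout): freeze its weights from the first epoch, or only from epoch 2 onwards via the at parameter:
>>> from skorch.callbacks import Freezer
>>> freezer = Freezer('embedding.*')              # frozen from the first epoch
>>> freezer_later = Freezer('embedding.*', at=2)  # frozen from epoch 2 onwards
>>> net = Net(myModule, callbacks=[freezer])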
-
class
skorch.callbacks.
GradientNormClipping
(gradient_clip_value=None, gradient_clip_norm_type=2)[source]¶ Clips gradient norm of a module’s parameters.
The norm is computed over all gradients together, as if they were concatenated into a single vector. Gradients are modified in-place.
See torch.nn.utils.clip_grad_norm_() for more information.
Parameters: - gradient_clip_value : float (default=None)
If not None, clip the norm of all model parameter gradients to this value. The type of the norm is determined by the gradient_clip_norm_type parameter and defaults to L2.
- gradient_clip_norm_type : float (default=2)
Norm to use when gradient clipping is active. The default is to use the L2-norm. Can be ‘inf’ for infinity norm.
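For example, a minimal sketch (assuming NeuralNetClassifier and a hypothetical MyModule) that clips the total L2 gradient norm to 1.0 before each optimizer step:
>>> from skorch.callbacks import GradientNormClipping
>>> clipping = GradientNormClipping(gradient_clip_value=1.0)
>>> net = NeuralNetClassifier(MyModule, callbacks=[clipping])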
Methods
initialize
()(Re-)Set the initial state of the callback. on_batch_begin
(net[, batch, training])Called at the beginning of each batch. on_batch_end
(net[, batch, training])Called at the end of each batch. on_epoch_begin
(net[, dataset_train, …])Called at the beginning of each epoch. on_epoch_end
(net[, dataset_train, dataset_valid])Called at the end of each epoch. on_grad_computed
(_, named_parameters, **kwargs)Called once per batch after gradients have been computed but before an update step was performed. on_train_begin
(net[, X, y])Called at the beginning of training. on_train_end
(net[, X, y])Called at the end of training. get_params set_params
-
class
skorch.callbacks.
Initializer
(*args, **kwargs)[source]¶ Apply any function on matching parameters in the first epoch.
Examples
Use Initializer to initialize all dense layer weights with values sampled from a uniform distribution at the beginning of the first epoch:
>>> init_fn = partial(torch.nn.init.uniform_, a=-1e-3, b=1e-3)
>>> cb = Initializer('dense*.weight', fn=init_fn)
>>> net = Net(myModule, callbacks=[cb])
Methods
initialize
()(Re-)Set the initial state of the callback. on_batch_begin
(net[, batch, training])Called at the beginning of each batch. on_batch_end
(net[, batch, training])Called at the end of each batch. on_epoch_begin
(net, **kwargs)Called at the beginning of each epoch. on_epoch_end
(net[, dataset_train, dataset_valid])Called at the end of each epoch. on_grad_computed
(net, named_parameters[, X, …])Called once per batch after gradients have been computed but before an update step was performed. on_train_begin
(net[, X, y])Called at the beginning of training. on_train_end
(net[, X, y])Called at the end of training. filter_parameters get_params named_parameters set_params
-
class
skorch.callbacks.
InputShapeSetter
(param_name='input_dim', input_dim_fn=None, module_name='module')[source]¶ Sets the input dimension of the PyTorch module to the input dimension of the training data. By default the last dimension of X (X.shape[-1]) will be used.
This can be of use when the shape of X is not known beforehand, e.g. when using a skorch model within an sklearn pipeline and grid-searching feature transformers, or using feature selection methods.
Basic usage:
>>> class MyModule(torch.nn.Module):
...     def __init__(self, input_dim=1):
...         super().__init__()
...         self.layer = torch.nn.Linear(input_dim, 3)
...     # ...
>>> X1 = np.zeros((100, 5))
>>> X2 = np.zeros((100, 3))
>>> y = np.zeros(100)
>>> net = NeuralNetClassifier(MyModule, callbacks=[InputShapeSetter()])
>>> net.fit(X1, y)  # self.module_.layer.in_features == 5
>>> net.fit(X2, y)  # self.module_.layer.in_features == 3
Parameters: - param_name : str (default=’input_dim’)
The parameter name is the parameter your model uses to define the input dimension in its __init__ method.
- input_dim_fn : callable, None (default=None)
In case your X value is more complex and deriving the input dimension is not as easy as X.shape[-1], you can pass a callable to this parameter which takes X and returns the input dimension.
- module_name : str (default=’module’)
Only needs to be changed when you are using more than one module in your skorch model (e.g., in case of GANs).
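A sketch of the input_dim_fn parameter, assuming a hypothetical case where X is a dict of arrays and the feature matrix lives under the 'features' key:
>>> def get_input_dim(X):
...     # X is assumed to be a dict of arrays; use the feature matrix
...     return X['features'].shape[-1]
>>> cb = InputShapeSetter(param_name='input_dim', input_dim_fn=get_input_dim)
>>> net = NeuralNetClassifier(MyModule, callbacks=[cb])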
Methods
initialize
()(Re-)Set the initial state of the callback. on_batch_begin
(net[, batch, training])Called at the beginning of each batch. on_batch_end
(net[, batch, training])Called at the end of each batch. on_epoch_begin
(net[, dataset_train, …])Called at the beginning of each epoch. on_epoch_end
(net[, dataset_train, dataset_valid])Called at the end of each epoch. on_grad_computed
(net, named_parameters[, X, …])Called once per batch after gradients have been computed but before an update step was performed. on_train_begin
(net, X, y, **kwargs)Called at the beginning of training. on_train_end
(net[, X, y])Called at the end of training. get_input_dim get_params set_params
-
class
skorch.callbacks.
LRScheduler
(policy='WarmRestartLR', monitor='train_loss', event_name='event_lr', step_every='epoch', **kwargs)[source]¶ Callback that sets the learning rate of each parameter group according to some policy.
Parameters: - policy : str or _LRScheduler class (default=’WarmRestartLR’)
Learning rate policy name or scheduler to be used.
- monitor : str or callable (default=’train_loss’)
Value of the history to monitor or function/callable. In the latter case, the callable receives the net instance as argument and is expected to return the score (float) used to determine the learning rate adjustment.
- event_name: str, (default=’event_lr’)
Name of the event to be placed in the history when the scheduler takes a step. Pass None to disable placing events in history. Note: This feature works only for PyTorch version >= 1.4.
- step_every: str, (default=’epoch’)
Value for when to apply the learning rate scheduler step. Can be either ‘batch’ or ‘epoch’.
- kwargs
Additional arguments passed to the lr scheduler.
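A minimal sketch (assuming NeuralNetClassifier and a hypothetical MyModule) that steps a standard PyTorch scheduler once per epoch; the extra keyword arguments are forwarded to the scheduler:
>>> from torch.optim.lr_scheduler import StepLR
>>> from skorch.callbacks import LRScheduler
>>> lr_scheduler = LRScheduler(policy=StepLR, step_size=10, gamma=0.5)
>>> net = NeuralNetClassifier(MyModule, callbacks=[lr_scheduler])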
Attributes: - kwargs
Methods
initialize
()(Re-)Set the initial state of the callback. on_batch_begin
(net[, batch, training])Called at the beginning of each batch. on_batch_end
(net, training, **kwargs)Called at the end of each batch. on_epoch_begin
(net[, dataset_train, …])Called at the beginning of each epoch. on_epoch_end
(net, **kwargs)Called at the end of each epoch. on_grad_computed
(net, named_parameters[, X, …])Called once per batch after gradients have been computed but before an update step was performed. on_train_begin
(net, **kwargs)Called at the beginning of training. on_train_end
(net[, X, y])Called at the end of training. simulate
(steps, initial_lr)Simulates the learning rate scheduler. get_params set_params
-
class
skorch.callbacks.
LoadInitState
(checkpoint)[source]¶ Loads the model, optimizer, and history from a checkpoint into a NeuralNet when training begins.
Parameters: - checkpoint: Checkpoint
Checkpoint to get filenames from.
Examples
Consider running the following example multiple times:
>>> cp = Checkpoint(monitor='valid_loss_best') >>> load_state = LoadInitState(cp) >>> net = NeuralNet(..., callbacks=[cp, load_state]) >>> net.fit(X, y)
On the first run, the Checkpoint saves the model, optimizer, and history when the validation loss is minimized. During the first run, there are no files on disk, thus LoadInitState will not load anything. When running the example a second time, LoadInitState will load the best model from the first run and continue training from there.
Methods
initialize
()(Re-)Set the initial state of the callback. on_batch_begin
(net[, batch, training])Called at the beginning of each batch. on_batch_end
(net[, batch, training])Called at the end of each batch. on_epoch_begin
(net[, dataset_train, …])Called at the beginning of each epoch. on_epoch_end
(net[, dataset_train, dataset_valid])Called at the end of each epoch. on_grad_computed
(net, named_parameters[, X, …])Called once per batch after gradients have been computed but before an update step was performed. on_train_begin
(net[, X, y])Called at the beginning of training. on_train_end
(net[, X, y])Called at the end of training. get_params set_params
-
class
skorch.callbacks.
MlflowLogger
(run=None, client=None, create_artifact=True, terminate_after_train=True, log_on_batch_end=False, log_on_epoch_end=True, batch_suffix=None, epoch_suffix=None, keys_ignored=None)[source]¶ Logs results from history and artifact to Mlflow
“MLflow is an open source platform for managing the end-to-end machine learning lifecycle” (MLflow Documentation)
Use this callback to automatically log your metrics and create/log artifacts to mlflow.
The best way to log additional information is to log directly to the experiment object or subclass the on_* methods.
To use this logger, you first have to install Mlflow:
python -m pip install mlflow
Parameters: - run : mlflow.entities.Run (default=None)
Instantiated mlflow.entities.Run class. By default (if set to None), mlflow.active_run() is used to get the current run.
- client : mlflow.tracking.MlflowClient (default=None)
Instantiated mlflow.tracking.MlflowClient class. By default (if set to None), MlflowClient() is used, which by default has:
- the tracking URI set by mlflow.set_tracking_uri()
- the registry URI set by mlflow.set_registry_uri()
- create_artifact : bool (default=True)
Whether to create artifacts for the network’s params, optimizer, criterion and history. See Saving and Loading.
- terminate_after_train : bool (default=True)
Whether to terminate the Run object once training finishes.
- log_on_batch_end : bool (default=False)
Whether to log loss and other metrics on batch level.
- log_on_epoch_end : bool (default=True)
Whether to log loss and other metrics on epoch level.
- batch_suffix : str (default=None)
A string that will be appended to all logged keys. By default (if set to None), '_batch' is used if batch and epoch logging are both enabled and no suffix is used otherwise.
- epoch_suffix : str (default=None)
A string that will be appended to all logged keys. By default (if set to None), '_epoch' is used if batch and epoch logging are both enabled and no suffix is used otherwise.
- keys_ignored : str or list of str (default=None)
Key or list of keys that should not be logged to Mlflow. Note that in addition to the keys provided by the user, keys such as those starting with 'event_' or ending on '_best' are ignored by default.
Examples
Mlflow fluent API:
>>> import mlflow >>> net = NeuralNetClassifier(net, callbacks=[MLflowLogger()]) >>> with mlflow.start_run(): ... net.fit(X, y)
Custom
run
andclient
:>>> from mlflow.tracking import MlflowClient >>> client = MlflowClient() >>> experiment = client.get_experiment_by_name('Default') >>> run = client.create_run(experiment.experiment_id) >>> net = NeuralNetClassifier(..., callbacks=[MlflowLogger(run, client)]) >>> net.fit(X, y)
Methods
initialize
()(Re-)Set the initial state of the callback. on_batch_begin
(net[, batch, training])Called at the beginning of each batch. on_batch_end
(net, training, **kwargs)Called at the end of each batch. on_epoch_begin
(net[, dataset_train, …])Called at the beginning of each epoch. on_epoch_end
(net, **kwargs)Called at the end of each epoch. on_grad_computed
(net, named_parameters[, X, …])Called once per batch after gradients have been computed but before an update step was performed. on_train_begin
(net, **kwargs)Called at the beginning of training. on_train_end
(net, **kwargs)Called at the end of training. get_params set_params
-
class
skorch.callbacks.
NeptuneLogger
(experiment, log_on_batch_end=False, close_after_train=True, keys_ignored=None)[source]¶ Logs results from history to Neptune
Neptune is a lightweight experiment tracking tool. You can read more about it here: https://neptune.ai
Use this callback to automatically log all interesting values from your net’s history to Neptune.
The best way to log additional information is to log directly to the experiment object or subclass the on_* methods.
To monitor resource consumption, install psutil:
python -m pip install psutil
You can view example experiment logs here: https://ui.neptune.ai/o/shared/org/skorch-integration/e/SKOR-13/charts
Parameters: - experiment : neptune.experiments.Experiment
Instantiated
Experiment
class.- log_on_batch_end : bool (default=False)
Whether to log loss and other metrics on batch level.
- close_after_train : bool (default=True)
Whether to close the
Experiment
object once training finishes. Set this parameter to False if you want to continue logging to the same Experiment or if you use it as a context manager.- keys_ignored : str or list of str (default=None)
Key or list of keys that should not be logged to Neptune. Note that in addition to the keys provided by the user, keys such as those starting with
'event_'
or ending on'_best'
are ignored by default.
Examples
>>> # Install neptune >>> python -m pip install neptune-client >>> # Create a neptune experiment object >>> import neptune ... ... # We are using api token for an anonymous user. ... # For your projects use the token associated with your neptune.ai account >>> neptune.init(api_token='ANONYMOUS', ... project_qualified_name='shared/skorch-integration') ... ... experiment = neptune.create_experiment( ... name='skorch-basic-example', ... params={'max_epochs': 20, ... 'lr': 0.01}, ... upload_source_files=['skorch_example.py'])
>>> # Create a neptune_logger callback >>> neptune_logger = NeptuneLogger(experiment, close_after_train=False)
>>> # Pass a logger to net callbacks argument >>> net = NeuralNetClassifier( ... ClassifierModule, ... max_epochs=20, ... lr=0.01, ... callbacks=[neptune_logger])
>>> # Log additional metrics after training has finished >>> from sklearn.metrics import roc_auc_score ... y_pred = net.predict_proba(X) ... auc = roc_auc_score(y, y_pred[:, 1]) ... ... neptune_logger.experiment.log_metric('roc_auc_score', auc)
>>> # log charts like ROC curve ... from scikitplot.metrics import plot_roc ... import matplotlib.pyplot as plt ... ... fig, ax = plt.subplots(figsize=(16, 12)) ... plot_roc(y, y_pred, ax=ax) ... neptune_logger.experiment.log_image('roc_curve', fig)
>>> # log net object after training ... net.save_params(f_params='basic_model.pkl') ... neptune_logger.experiment.log_artifact('basic_model.pkl')
>>> # close experiment ... neptune_logger.experiment.stop()
Attributes: - first_batch_ : bool
Helper attribute that is set to True at initialization and changes to False on first batch end. Can be used when we want to log things exactly once.
Methods
initialize
()(Re-)Set the initial state of the callback. on_batch_begin
(net[, batch, training])Called at the beginning of each batch. on_batch_end
(net, **kwargs)Called at the end of each batch. on_epoch_begin
(net[, dataset_train, …])Called at the beginning of each epoch. on_epoch_end
(net, **kwargs)Automatically log values from the last history step. on_grad_computed
(net, named_parameters[, X, …])Called once per batch after gradients have been computed but before an update step was performed. on_train_begin
(net[, X, y])Called at the beginning of training. on_train_end
(net, **kwargs)Called at the end of training. get_params set_params
-
class
skorch.callbacks.
ParamMapper
(patterns, fn=<function noop>, at=1, schedule=None)[source]¶ Map arbitrary functions over module parameters filtered by pattern matching.
In the simplest case the function is only applied once at the beginning of a given epoch (at on_epoch_begin), but more complex execution schemes (e.g. periodic application) are possible using at and schedule.
Parameters: - patterns : str or callable or list
The pattern(s) to match parameter names against. Patterns are UNIX globbing patterns as understood by fnmatch(). Patterns can also be callables which will get called with the parameter name and are regarded as a match when the callable returns a truthy value.
This parameter also supports lists of str or callables so that one ParamMapper can match a group of parameters.
Example: 'linear*.weight' or ['linear0.*', 'linear1.bias'] or lambda name: name.startswith('linear').
- fn : function
The function to apply to each parameter separately.
- at : int or callable
In case you specify an integer, it represents the epoch number at which the function fn is applied to the parameters. In case at is a function, it will receive net as parameter and the function is applied to the parameters once at returns True.
- schedule : callable or None
If specified, this callable supersedes the static at/fn combination by dynamically returning the function that is applied on the matched parameters. This way you can, for example, create a schedule that periodically freezes and unfreezes layers.
The callable’s signature is schedule(net: NeuralNet) -> callable.
Notes
When starting the training process after saving and loading a model, ParamMapper might re-initialize parts of your model when the history is not saved along with the model. To avoid this, in case you use ParamMapper (or subclasses, e.g. Initializer) and want to save your model, make sure to either (a) use pickle, (b) save and load the history, or (c) remove the parameter mapper callbacks before continuing training.
Examples
Initialize a layer on first epoch before the first training step:
>>> init = partial(torch.nn.init.uniform_, a=0, b=1) >>> cb = ParamMapper('linear*.weight', at=1, fn=init) >>> net = Net(myModule, callbacks=[cb])
Reset layer initialization if train loss reaches a certain value (e.g. re-initialize on overfit):
>>> at = lambda net: net.history[-1, 'train_loss'] < 0.1 >>> init = partial(torch.nn.init.uniform_, a=0, b=1) >>> cb = ParamMapper('linear0.weight', at=at, fn=init) >>> net = Net(myModule, callbacks=[cb])
Periodically freeze and unfreeze all embedding layers:
>>> def my_sched(net): ... if len(net.history) % 2 == 0: ... return skorch.utils.freeze_parameter ... else: ... return skorch.utils.unfreeze_parameter >>> cb = ParamMapper('embedding*.weight', schedule=my_sched) >>> net = Net(myModule, callbacks=[cb])
Methods
initialize
()(Re-)Set the initial state of the callback. on_batch_begin
(net[, batch, training])Called at the beginning of each batch. on_batch_end
(net[, batch, training])Called at the end of each batch. on_epoch_begin
(net, **kwargs)Called at the beginning of each epoch. on_epoch_end
(net[, dataset_train, dataset_valid])Called at the end of each epoch. on_grad_computed
(net, named_parameters[, X, …])Called once per batch after gradients have been computed but before an update step was performed. on_train_begin
(net[, X, y])Called at the beginning of training. on_train_end
(net[, X, y])Called at the end of training. filter_parameters get_params named_parameters set_params
-
class
skorch.callbacks.
PassthroughScoring
(name, lower_is_better=True, on_train=False)[source]¶ Creates scores on epoch level based on batch level scores
This callback doesn’t calculate any new scores but instead passes through a score that was created on the batch level. Based on that batch-level score, an average across the epoch is created (honoring the batch size) and recorded in the history for the given epoch.
Use this callback when there already is a score calculated on the batch level. If that score has yet to be calculated, use BatchScoring instead.
Parameters: - name : str
Name of the score recorded on a batch level in the history.
- lower_is_better : bool (default=True)
Whether lower (e.g. log loss) or higher (e.g. accuracy) scores are better.
- on_train : bool (default=False)
Whether this should be computed on training or validation data.
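As a sketch, assume a custom net whose training step already writes a batch-level entry called 'my_batch_score' to the history (a hypothetical name); PassthroughScoring then averages it over the epoch:
>>> from skorch.callbacks import PassthroughScoring
>>> cb = PassthroughScoring(name='my_batch_score', on_train=True)
>>> net = MyNet(callbacks=[cb])  # MyNet records 'my_batch_score' per batch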
Methods
initialize
()(Re-)Set the initial state of the callback. on_batch_begin
(net[, batch, training])Called at the beginning of each batch. on_batch_end
(net[, batch, training])Called at the end of each batch. on_epoch_begin
(net[, dataset_train, …])Called at the beginning of each epoch. on_epoch_end
(net, **kwargs)Called at the end of each epoch. on_grad_computed
(net, named_parameters[, X, …])Called once per batch after gradients have been computed but before an update step was performed. on_train_begin
(net[, X, y])Called at the beginning of training. on_train_end
(net[, X, y])Called at the end of training. get_avg_score get_params set_params
-
class
skorch.callbacks.
PrintLog
(keys_ignored=None, sink=<built-in function print>, tablefmt='simple', floatfmt='.4f', stralign='right')[source]¶ Print useful information from the model’s history as a table.
By default,
PrintLog
prints everything from the history except for'batches'
.To determine the best loss,
PrintLog
looks for keys that end on'_best'
and associates them with the corresponding loss. E.g.,'train_loss_best'
will be matched with'train_loss'
. Theskorch.callbacks.EpochScoring
callback takes care of creating those entries, which is whyPrintLog
works best in conjunction with that callback.PrintLog
treats keys with the'event_'
prefix in a special way. They are assumed to contain information about occasionally occuring events. TheFalse
orNone
entries (indicating that an event did not occur) are not printed, resulting in empty cells in the table, andTrue
entries are printed with+
symbol.PrintLog
groups all event columns together and pushes them to the right, just before the'dur'
column.Note:
PrintLog
will not result in good outputs if the number of columns varies between epochs, e.g. if the valid loss is only present on every other epoch.Parameters: - keys_ignored : str or list of str (default=None)
Key or list of keys that should not be part of the printed table. Note that in addition to the keys provided by the user, keys such as those starting with
'event_'
or ending on'_best'
are ignored by default.- sink : callable (default=print)
The target that the output string is sent to. By default, the output is printed to stdout, but the sink could also be a logger, etc.
- tablefmt : str (default=’simple’)
The format of the table. See the documentation of the
tabulate
package for more detail. Can be ‘plain’, ‘grid’, ‘pipe’, ‘html’, ‘latex’, among others.- floatfmt : str (default=’.4f’)
The number formatting. See the documentation of the
tabulate
package for more details.- stralign : str (default=’right’)
The alignment of columns with strings. Can be ‘left’, ‘center’, ‘right’, or
None
(disable alignment). Default is ‘right’ (to be consistent with numerical columns).
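Since a PrintLog instance is installed by default, its parameters can also be adjusted through set_params on the net, following the callbacks__print_log__ naming scheme (a sketch, assuming an already constructed net):
>>> net.set_params(
...     callbacks__print_log__floatfmt='.3f',
...     callbacks__print_log__stralign='left',
... )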
Methods
format_row
(row, key, color)For a given row from the table, format it (i.e. initialize
()(Re-)Set the initial state of the callback. on_batch_begin
(net[, batch, training])Called at the beginning of each batch. on_batch_end
(net[, batch, training])Called at the end of each batch. on_epoch_begin
(net[, dataset_train, …])Called at the beginning of each epoch. on_epoch_end
(net, **kwargs)Called at the end of each epoch. on_grad_computed
(net, named_parameters[, X, …])Called once per batch after gradients have been computed but before an update step was performed. on_train_begin
(net[, X, y])Called at the beginning of training. on_train_end
(net[, X, y])Called at the end of training. get_params set_params table -
format_row
(row, key, color)[source]¶ For a given row from the table, format it (i.e. floating points and color if applicable).
-
class
skorch.callbacks.
ProgressBar
(batches_per_epoch='auto', detect_notebook=True, postfix_keys=None)[source]¶ Display a progress bar for each epoch.
The progress bar includes elapsed and estimated remaining time for the current epoch, the number of batches processed, and other user-defined metrics. The progress bar is erased once the epoch is completed.
ProgressBar needs to know the total number of batches per epoch in order to display a meaningful progress bar. By default, this number is determined automatically using the dataset length and the batch size. If this heuristic does not work for some reason, you may either specify the number of batches explicitly or let the ProgressBar count the actual number of batches in the previous epoch.
For jupyter notebooks a non-ASCII progress bar can be printed instead. To use this feature, you need to have ipywidgets installed.
Parameters: - batches_per_epoch : int, str (default=’auto’)
Either a concrete number or a string specifying the method used to determine the number of batches per epoch automatically. 'auto' means that the number is computed from the length of the dataset and the batch size. 'count' means that the number is determined by counting the batches in the previous epoch. Note that this will leave you without a progress bar at the first epoch.
If enabled, the progress bar determines if its current environment is a jupyter notebook and switches to a non-ASCII progress bar.
- postfix_keys : list of str (default=[‘train_loss’, ‘valid_loss’])
You can use this list to specify additional info displayed in the progress bar such as metrics and losses. A prerequisite to this is that these values are residing in the history on batch level already, i.e. they must be accessible via
>>> net.history[-1, 'batches', -1, key]
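A brief sketch (assuming NeuralNetClassifier and a hypothetical MyModule) that counts the batches of the previous epoch instead of relying on the dataset length, and shows the default loss keys in the bar:
>>> from skorch.callbacks import ProgressBar
>>> pbar = ProgressBar(batches_per_epoch='count',
...                    postfix_keys=['train_loss', 'valid_loss'])
>>> net = NeuralNetClassifier(MyModule, callbacks=[pbar])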
Methods
initialize
()(Re-)Set the initial state of the callback. on_batch_begin
(net[, batch, training])Called at the beginning of each batch. on_batch_end
(net, **kwargs)Called at the end of each batch. on_epoch_begin
(net[, dataset_train, …])Called at the beginning of each epoch. on_epoch_end
(net, **kwargs)Called at the end of each epoch. on_grad_computed
(net, named_parameters[, X, …])Called once per batch after gradients have been computed but before an update step was performed. on_train_begin
(net[, X, y])Called at the beginning of training. on_train_end
(net[, X, y])Called at the end of training. get_params in_ipynb set_params
-
class
skorch.callbacks.
TrainEndCheckpoint
(f_params='params.pt', f_optimizer='optimizer.pt', f_criterion='criterion.pt', f_history='history.json', f_pickle=None, fn_prefix='train_end_', dirname='', sink=<function noop>, **kwargs)[source]¶ Saves the model parameters, optimizer state, and history at the end of training. The default fn_prefix is 'train_end_'.
Parameters: - f_params : file-like object, str, None (default=’params.pt’)
File path to the file or file-like object where the model parameters should be saved. Pass None to disable saving model parameters.
If the value is a string you can also use format specifiers to, for example, indicate the current epoch. Accessible format values are net, last_epoch and last_batch. Example to include the last epoch number in the file name:
>>> cb = Checkpoint(f_params="params_{last_epoch[epoch]}.pt")
- f_optimizer : file-like object, str, None (default=’optimizer.pt’)
File path to the file or file-like object where the optimizer state should be saved. Pass None to disable saving the optimizer state.
Supports the same format specifiers as f_params.
- f_criterion : file-like object, str, None (default=’criterion.pt’)
File path to the file or file-like object where the criterion state should be saved. Pass None to disable saving the criterion state.
Supports the same format specifiers as f_params.
- f_history : file-like object, str, None (default=’history.json’)
File path to the file or file-like object where the model training history should be saved. Pass None to disable saving history.
- f_pickle : file-like object, str, None (default=None)
File path to the file or file-like object where the entire model object should be pickled. Pass None to disable pickling.
Supports the same format specifiers as f_params.
- fn_prefix: str (default=’train_end_’)
Prefix for filenames. If f_params, f_optimizer, f_history, or f_pickle are strings, they will be prefixed by fn_prefix.
- dirname: str (default=’’)
Directory where files are stored.
- sink : callable (default=noop)
The target that the information about created checkpoints is sent to. This can be a logger or the print function (to send to stdout). By default the output is discarded.
Examples
Consider running the following example multiple times:
>>> train_end_cp = TrainEndCheckpoint(dirname='exp1') >>> load_state = LoadInitState(train_end_cp) >>> net = NeuralNet(..., callbacks=[train_end_cp, load_state]) >>> net.fit(X, y)
After the first run, model parameters, optimizer state, and history are saved into a directory named exp1. On the next run, LoadInitState will load the state from the first run and continue training.
Methods
initialize
()(Re-)Set the initial state of the callback. on_batch_begin
(net[, batch, training])Called at the beginning of each batch. on_batch_end
(net[, batch, training])Called at the end of each batch. on_epoch_begin
(net[, dataset_train, …])Called at the beginning of each epoch. on_epoch_end
(net[, dataset_train, dataset_valid])Called at the end of each epoch. on_grad_computed
(net, named_parameters[, X, …])Called once per batch after gradients have been computed but before an update step was performed. on_train_begin
(net[, X, y])Called at the beginning of training. on_train_end
(net, **kwargs)Called at the end of training. get_params set_params
-
class
skorch.callbacks.
TensorBoard
(writer, close_after_train=True, keys_ignored=None, key_mapper=<function rename_tensorboard_key>)[source]¶ Logs results from history to TensorBoard
“TensorBoard provides the visualization and tooling needed for machine learning experimentation” (official docs).
Use this callback to automatically log all interesting values from your net’s history to tensorboard after each epoch.
Parameters: - writer : torch.utils.tensorboard.writer.SummaryWriter
Instantiated SummaryWriter class.
- close_after_train : bool (default=True)
Whether to close the SummaryWriter object once training finishes. Set this parameter to False if you want to continue logging with the same writer or if you use it as a context manager.
- keys_ignored : str or list of str (default=None)
Key or list of keys that should not be logged to tensorboard. Note that in addition to the keys provided by the user, keys such as those starting with 'event_' or ending on '_best' are ignored by default.
- key_mapper : callable or function (default=rename_tensorboard_key)
This function maps a key name from the history to a tag in tensorboard. This is useful because tensorboard can automatically group similar tags if their names start with the same prefix, followed by a forward slash. By default, this callback will prefix all keys that start with “train” or “valid” with the “Loss/” prefix.
Examples
Here is the standard way of using the callback:
>>> # Example: normal usage >>> from skorch.callbacks import TensorBoard >>> from torch.utils.tensorboard import SummaryWriter >>> writer = SummaryWriter(...) >>> net = NeuralNet(..., callbacks=[TensorBoard(writer)]) >>> net.fit(X, y)
The best way to log additional information is to subclass this callback and add your code to one of the on_* methods.
>>> # Example: log the bias parameter as a histogram
>>> def extract_bias(module):
...     return module.hidden.bias
>>> # override on_epoch_end
>>> class MyTensorBoard(TensorBoard):
...     def on_epoch_end(self, net, **kwargs):
...         bias = extract_bias(net.module_)
...         epoch = net.history[-1, 'epoch']
...         self.writer.add_histogram('bias', bias, global_step=epoch)
...         super().on_epoch_end(net, **kwargs)  # call super last
>>> # other code
>>> net = NeuralNet(..., callbacks=[MyTensorBoard(writer)])
Methods
add_scalar_maybe
(history, key, tag[, …])Add a scalar value from the history to TensorBoard initialize
()(Re-)Set the initial state of the callback. on_batch_begin
(net[, batch, training])Called at the beginning of each batch. on_batch_end
(net, **kwargs)Called at the end of each batch. on_epoch_begin
(net[, dataset_train, …])Called at the beginning of each epoch. on_epoch_end
(net, **kwargs)Automatically log values from the last history step. on_grad_computed
(net, named_parameters[, X, …])Called once per batch after gradients have been computed but before an update step was performed. on_train_begin
(net[, X, y])Called at the beginning of training. on_train_end
(net, **kwargs)Called at the end of training. get_params set_params -
add_scalar_maybe
(history, key, tag, global_step=None)[source]¶ Add a scalar value from the history to TensorBoard
Will catch errors like missing keys or wrong value types.
Parameters: - history : skorch.History
History object saved as attribute on the neural net.
- key : str
Key of the desired value in the history.
- tag : str
Name of the tag used in TensorBoard.
- global_step : int or None
Global step value to record.
-
class
skorch.callbacks.
SacredLogger
(experiment, log_on_batch_end=False, log_on_epoch_end=True, batch_suffix=None, epoch_suffix=None, keys_ignored=None)[source]¶ Logs results from history to Sacred.
Sacred is a tool to help you configure, organize, log and reproduce experiments. Developed at IDSIA. See https://github.com/IDSIA/sacred.
Use this callback to automatically log all interesting values from your net’s history to Sacred.
If you want to log additional information, you can simply add it to
History
. See the documentation onCallbacks
, andScoring
for more information. Alternatively you can subclass this callback and extend theon_*
methods.To use this logger, you first have to install Sacred:
python -m pip install sacred
You might also install pymongo to use a mongodb backend. See the upstream documentation for more details. Once you have installed it, you can set up a simple experiment and pass this Logger as a callback to your skorch estimator:
Parameters: - experiment : sacred.Experiment
Instantiated
Experiment
class.- log_on_batch_end : bool (default=False)
Whether to log loss and other metrics on batch level.
- log_on_epoch_end : bool (default=True)
Whether to log loss and other metrics on epoch level.
- batch_suffix : str (default=None)
A string that will be appended to all logged keys. By default (if set to
None
) “_batch” is used if batch and epoch logging are both enabled and no suffix is used otherwise.- epoch_suffix : str (default=None)
A string that will be appended to all logged keys. By default (if set to
None
) “_epoch” is used if batch and epoch logging are both enabled and no suffix is used otherwise.- keys_ignored : str or list of str (default=None)
Key or list of keys that should not be logged to Sacred. Note that in addition to the keys provided by the user, keys such as those starting with
'event_'
or ending on'_best'
are ignored by default.
Examples
>>> # contents of sacred-experiment.py >>> import numpy as np >>> from sacred import Experiment >>> from sklearn.datasets import make_classification >>> from skorch.callbacks.logging import SacredLogger >>> from skorch.callbacks.scoring import EpochScoring >>> from skorch import NeuralNetClassifier >>> from skorch.toy import make_classifier >>> ex = Experiment() >>> @ex.config >>> def my_config(): ... max_epochs = 20 ... lr = 0.01 >>> X, y = make_classification() >>> X, y = X.astype(np.float32), y.astype(np.int64) >>> @ex.automain >>> def main(_run, max_epochs, lr): ... # Take care to add additional scoring callbacks *before* the logger. ... net = NeuralNetClassifier( ... make_classifier(), ... max_epochs=max_epochs, ... lr=0.01, ... callbacks=[EpochScoring("f1"), SacredLogger(_run)] ... ) ... # now fit your estimator to your data ... net.fit(X, y)
Then call this from the command line, e.g. like this:
python sacred-script.py with max_epochs=15
You can also change other options on the command line and optionally specify a backend.
Methods
initialize
()(Re-)Set the initial state of the callback. on_batch_begin
(net[, batch, training])Called at the beginning of each batch. on_batch_end
(net, **kwargs)Called at the end of each batch. on_epoch_begin
(net[, dataset_train, …])Called at the beginning of each epoch. on_epoch_end
(net, **kwargs)Automatically log values from the last history step. on_grad_computed
(net, named_parameters[, X, …])Called once per batch after gradients have been computed but before an update step was performed. on_train_begin
(net[, X, y])Called at the beginning of training. on_train_end
(net[, X, y])Called at the end of training. get_params set_params
-
class
skorch.callbacks.
Unfreezer
(*args, **kwargs)[source]¶ Inverse operation of Freezer.
Methods
initialize
()(Re-)Set the initial state of the callback. on_batch_begin
(net[, batch, training])Called at the beginning of each batch. on_batch_end
(net[, batch, training])Called at the end of each batch. on_epoch_begin
(net, **kwargs)Called at the beginning of each epoch. on_epoch_end
(net[, dataset_train, dataset_valid])Called at the end of each epoch. on_grad_computed
(net, named_parameters[, X, …])Called once per batch after gradients have been computed but before an update step was performed. on_train_begin
(net[, X, y])Called at the beginning of training. on_train_end
(net[, X, y])Called at the end of training. filter_parameters get_params named_parameters set_params
-
class
skorch.callbacks.
WandbLogger
(wandb_run, save_model=True, keys_ignored=None)[source]¶ Logs best model and metrics to Weights & Biases
Use this callback to automatically log best trained model, all metrics from your net’s history, model topology and computer resources to Weights & Biases after each epoch.
Every file saved in wandb_run.dir is automatically logged to W&B servers.
See example run
Parameters: - wandb_run : wandb.wandb_run.Run
wandb Run used to log data.
- save_model : bool (default=True)
Whether to save a checkpoint of the best model and upload it to your Run on W&B servers.
- keys_ignored : str or list of str (default=None)
Key or list of keys that should not be logged to wandb. Note that in addition to the keys provided by the user, keys such as those starting with
'event_'
or ending on'_best'
are ignored by default.
Examples
>>> # Install wandb ... python -m pip install wandb
>>> import wandb >>> from skorch.callbacks import WandbLogger
>>> # Create a wandb Run
... wandb_run = wandb.init()
>>> # Alternative: Create a wandb Run without having a W&B account
... wandb_run = wandb.init(anonymous="allow")
>>> # Log hyper-parameters (optional) ... wandb_run.config.update({"learning rate": 1e-3, "batch size": 32})
>>> net = NeuralNet(..., callbacks=[WandbLogger(wandb_run)]) >>> net.fit(X, y)
Methods
initialize
()(Re-)Set the initial state of the callback. on_batch_begin
(net[, batch, training])Called at the beginning of each batch. on_batch_end
(net[, batch, training])Called at the end of each batch. on_epoch_begin
(net[, dataset_train, …])Called at the beginning of each epoch. on_epoch_end
(net, **kwargs)Log values from the last history step and save best model on_grad_computed
(net, named_parameters[, X, …])Called once per batch after gradients have been computed but before an update step was performed. on_train_begin
(net, **kwargs)Log model topology and add a hook for gradients on_train_end
(net[, X, y])Called at the end of training. get_params set_params
-
class
skorch.callbacks.
WarmRestartLR
(optimizer, min_lr=1e-06, max_lr=0.05, base_period=10, period_mult=2, last_epoch=-1)[source]¶ Stochastic Gradient Descent with Warm Restarts (SGDR) scheduler.
This scheduler sets the learning rate of each parameter group according to the stochastic gradient descent with warm restarts (SGDR) policy. This policy simulates periodic warm restarts of SGD, where in each restart the learning rate is initialized to some value and is scheduled to decrease.
Parameters: - optimizer : torch.optim.Optimizer instance.
Optimizer algorithm.
- min_lr : float or list of float (default=1e-6)
Minimum allowed learning rate during each period for all param groups (float) or each group (list).
- max_lr : float or list of float (default=0.05)
Maximum allowed learning rate during each period for all param groups (float) or each group (list).
- base_period : int (default=10)
Initial restart period to be multiplied at each restart.
- period_mult : int (default=2)
Multiplicative factor to increase the period between restarts.
- last_epoch : int (default=-1)
The index of the last valid epoch.
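In practice this scheduler is usually passed as a policy to the LRScheduler callback rather than instantiated directly, since the callback supplies the optimizer; a sketch (assuming NeuralNetClassifier and a hypothetical MyModule):
>>> from skorch.callbacks import LRScheduler, WarmRestartLR
>>> scheduler = LRScheduler(policy=WarmRestartLR, min_lr=1e-6, max_lr=0.05,
...                         base_period=10, period_mult=2)
>>> net = NeuralNetClassifier(MyModule, callbacks=[scheduler])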
References
[1] Ilya Loshchilov and Frank Hutter, 2017, “Stochastic Gradient Descent with Warm Restarts”, ICLR. https://arxiv.org/pdf/1608.03983.pdf
Methods
get_last_lr
()Return last computed learning rate by current scheduler. load_state_dict
(state_dict)Loads the schedulers state. print_lr
(is_verbose, group, lr[, epoch])Display the current learning rate. state_dict
()Returns the state of the scheduler as a dict
.get_lr step
skorch.classifier¶
NeuralNet subclasses for classification tasks.
-
class
skorch.classifier.
NeuralNetBinaryClassifier
(module, *args, criterion=<class 'torch.nn.modules.loss.BCEWithLogitsLoss'>, train_split=<skorch.dataset.ValidSplit object>, threshold=0.5, **kwargs)[source]¶ NeuralNet for binary classification tasks
Use this specifically if you have a binary classification task, with input data X and target y. y must be 1d.
In addition to the parameters listed below, there are parameters with specific prefixes that are handled separately. To illustrate this, here is an example:
>>> net = NeuralNet(
...     ...,
...     optimizer=torch.optim.SGD,
...     optimizer__momentum=0.95,
... )
This way, when optimizer is initialized, NeuralNet will take care of setting the momentum parameter to 0.95.
(Note that the double underscore notation in optimizer__momentum means that the parameter momentum should be set on the object optimizer. This is the same semantics as used by sklearn.)
Furthermore, this allows you to change those parameters later:
net.set_params(optimizer__momentum=0.99)
This can be useful when you want to change certain parameters using a callback, when using the net in an sklearn grid search, etc.
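A minimal usage sketch, assuming a hypothetical MyBinaryModule that returns logits of shape (batch_size,); note that the 1d target is cast to float32 here, which is an assumption about what BCEWithLogitsLoss expects:
>>> import numpy as np
>>> from sklearn.datasets import make_classification
>>> from skorch import NeuralNetBinaryClassifier
>>> X, y = make_classification(1000, 20, n_classes=2, random_state=0)
>>> X, y = X.astype(np.float32), y.astype(np.float32)
>>> net = NeuralNetBinaryClassifier(MyBinaryModule, max_epochs=10, lr=0.01)
>>> net.fit(X, y)
>>> y_proba = net.predict_proba(X)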
By default an EpochTimer, BatchScoring (for both training and validation datasets), and PrintLog callbacks are installed for the user’s convenience.
Parameters: - module : torch module (class or instance)
A PyTorch
Module
. In general, the uninstantiated class should be passed, although instantiated modules will also work.- criterion : torch criterion (class, default=torch.nn.BCEWithLogitsLoss)
Binary cross entropy loss with logits. Note that the module should return the logit of probabilities with shape (batch_size, ).
- threshold : float (default=0.5)
Probabilities above this threshold are classified as 1.
threshold
is used bypredict
andpredict_proba
for classification.- optimizer : torch optim (class, default=torch.optim.SGD)
The uninitialized optimizer (update rule) used to optimize the module
- lr : float (default=0.01)
Learning rate passed to the optimizer. You may use
lr
instead of usingoptimizer__lr
, which would result in the same outcome.- max_epochs : int (default=10)
The number of epochs to train for each
fit
call. Note that you may keyboard-interrupt training at any time.- batch_size : int (default=128)
Mini-batch size. Use this instead of setting
iterator_train__batch_size
anditerator_test__batch_size
, which would result in the same outcome. Ifbatch_size
is -1, a single batch with all the data will be used during training and validation.- iterator_train : torch DataLoader
The default PyTorch
DataLoader
used for training data.- iterator_valid : torch DataLoader
The default PyTorch
DataLoader
used for validation and test data, i.e. during inference.- dataset : torch Dataset (default=skorch.dataset.Dataset)
The dataset is necessary for the incoming data to work with pytorch’s
DataLoader
. It has to implement the__len__
and__getitem__
methods. The provided dataset should be capable of dealing with a lot of data types out of the box, so only change this if your data is not supported. You should generally pass the uninitializedDataset
class and define additional arguments to X and y by prefixing them withdataset__
. It is also possible to pass an initialzedDataset
, in which case no additional arguments may be passed.- train_split : None or callable (default=skorch.dataset.ValidSplit(5))
If None, there is no train/validation split. Else, train_split should be a function or callable that is called with X and y data and should return the tuple
dataset_train, dataset_valid
. The validation data may be None.- callbacks : None, “disable”, or list of Callback instances (default=None)
Which callbacks to enable. There are three possible values:
If
callbacks=None
, only use default callbacks, those returned byget_default_callbacks
.If
callbacks="disable"
, disable all callbacks, i.e. do not run any of the callbacks.If
callbacks
is a list of callbacks, use those callbacks in addition to the default callbacks. Each callback should be an instance ofCallback
.Callback names are inferred from the class name. Name conflicts are resolved by appending a count suffix starting with 1, e.g.
EpochScoring_1
. Alternatively, a tuple(name, callback)
can be passed, wherename
should be unique. Callbacks may or may not be instantiated. The callback name can be used to set parameters on specific callbacks (e.g., for the callback with name'print_log'
, usenet.set_params(callbacks__print_log__keys_ignored=['epoch', 'train_loss'])
).- predict_nonlinearity : callable, None, or ‘auto’ (default=’auto’)
The nonlinearity to be applied to the prediction. When set to ‘auto’, infers the correct nonlinearity based on the criterion (softmax for
CrossEntropyLoss
and sigmoid forBCEWithLogitsLoss
). If it cannot be inferred or if the parameter is None, just use the identity function. Don’t pass a lambda function if you want the net to be pickleable.In case a callable is passed, it should accept the output of the module (the first output if there is more than one), which is a PyTorch tensor, and return the transformed PyTorch tensor.
This can be useful, e.g., when
predict_proba()
should return probabilities but a criterion is used that does not expect probabilities. In that case, the module can return whatever is required by the criterion and thepredict_nonlinearity
transforms this output into probabilities.The nonlinearity is applied only when calling
predict()
orpredict_proba()
but not anywhere else – notably, the loss is unaffected by this nonlinearity.- warm_start : bool (default=False)
Whether each fit call should lead to a re-initialization of the module (cold start) or whether the module should be trained further (warm start).
- verbose : int (default=1)
This parameter controls how much print output is generated by the net and its callbacks. By setting this value to 0, e.g. the summary scores at the end of each epoch are no longer printed. This can be useful when running a hyperparameter search. The summary scores are always logged in the history attribute, regardless of the verbose setting.
- device : str, torch.device, or None (default=’cpu’)
The compute device to be used. If set to ‘cuda’ in order to use GPU acceleration, data in torch tensors will be pushed to cuda tensors before being sent to the module. If set to None, then all compute devices will be left unmodified.
Attributes: - prefixes_ : list of str
Contains the prefixes to special parameters. E.g., since there is the
'optimizer'
prefix, it is possible to set parameters like so:NeuralNet(..., optimizer__momentum=0.95)
.- cuda_dependent_attributes_ : list of str
Contains a list of all attribute prefixes whose values depend on a CUDA device. If a
NeuralNet
trained with a CUDA-enabled device is unpickled on a machine without CUDA or with CUDA disabled, the listed attributes are mapped to CPU. Expand this list if you want to add other cuda-dependent attributes.- initialized_ : bool
Whether the
NeuralNet
was initialized.- module_ : torch module (instance)
The instantiated module.
- criterion_ : torch criterion (instance)
The instantiated criterion.
- callbacks_ : list of tuples
The complete (i.e. default and other), initialized callbacks, in a tuple with unique names.
- _modules : list of str
List of names of all modules that are torch modules. This list is collected dynamically when the net is initialized. Typically, there is no reason for a user to modify this list.
- _criteria : list of str
List of names of all criteria that are torch modules. This list is collected dynamically when the net is initialized. Typically, there is no reason for a user to modify this list.
- _optimizers : list of str
List of names of all optimizers. This list is collected dynamically when the net is initialized. Typically, there is no reason for a user to modify this list.
Methods
check_data
(X, y)check_is_fitted
([attributes])Checks whether the net is initialized check_training_readiness
()Check that the net is ready to train evaluation_step
(batch[, training])Perform a forward step to produce the output used for prediction and scoring. fit
(X, y, **fit_params)See NeuralNet.fit
.fit_loop
(X[, y, epochs])The proper fit loop. forward
(X[, training, device])Gather and concatenate the output from forward call with input data. forward_iter
(X[, training, device])Yield outputs of module forward calls on each batch of data. get_all_learnable_params
()Yield the learnable parameters of all modules get_dataset
(X[, y])Get a dataset that contains the input data and is passed to the iterator. get_iterator
(dataset[, training])Get an iterator that allows to loop over the batches of the given data. get_loss
(y_pred, y_true[, X, training])Return the loss for this batch. get_params_for
(prefix)Collect and return init parameters for an attribute. get_params_for_optimizer
(prefix, …)Collect and return init parameters for an optimizer. get_split_datasets
(X[, y])Get internal train and validation datasets. get_train_step_accumulator
()Return the train step accumulator. infer
(x, **fit_params)Perform an inference step initialize
()Initializes all of its components and returns self. initialize_callbacks
()Initializes all callbacks and save the result in the callbacks_
attribute.initialize_criterion
()Initializes the criterion. initialize_history
()Initializes the history. initialize_module
()Initializes the module. initialize_optimizer
([triggered_directly])Initialize the model optimizer. initialized_instance
(instance_or_cls, kwargs)Return an instance initialized with the given parameters load_params
([f_params, f_optimizer, …])Loads the module’s parameters, history, and optimizer, not the whole object. notify
(method_name, **cb_kwargs)Call the callback method specified in method_name
with parameters specified incb_kwargs
.on_batch_begin
(net[, batch, training])on_epoch_begin
(net[, dataset_train, …])on_epoch_end
(net[, dataset_train, dataset_valid])on_train_begin
(net[, X, y])on_train_end
(net[, X, y])partial_fit
(X[, y, classes])Fit the module. predict
(X)Where applicable, return class labels for samples in X. predict_proba
(X)Return the output of the module’s forward method as a numpy array. run_single_epoch
(iterator, training, prefix, …)Compute a single epoch of train or validation. save_params
([f_params, f_optimizer, …])Saves the module’s parameters, history, and optimizer, not the whole object. score
(X, y[, sample_weight])Return the mean accuracy on the given test data and labels. set_params
(**kwargs)Set the parameters of this class. train_step
(batch, **fit_params)Prepares a loss function callable and pass it to the optimizer, hence performing one optimization step. train_step_single
(batch, **fit_params)Compute y_pred, loss value, and update net’s gradients. trim_for_prediction
()Remove all attributes not required for prediction. validation_step
(batch, **fit_params)Perform a forward step using batched data and return the resulting loss. get_default_callbacks get_params initialize_virtual_params on_batch_end on_grad_computed -
fit
(X, y, **fit_params)[source]¶ See
NeuralNet.fit
.In contrast to
NeuralNet.fit
,y
is non-optional to avoid mistakenly forgetting abouty
. However,y
can be set toNone
in case it is derived dynamically fromX
.
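For illustration, here is a minimal sketch (not from the skorch docs) that assumes numeric arrays X and y, a torch TensorDataset called train_ds that already yields (X, y) pairs, and a train_split that can handle y=None (e.g. train_split=None):
>>> import torch
>>> from torch.utils.data import TensorDataset
>>> # the targets live inside the dataset, so y is passed as None
>>> train_ds = TensorDataset(torch.as_tensor(X), torch.as_tensor(y))
>>> net.fit(train_ds, y=None)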
-
infer
(x, **fit_params)[source]¶ Perform an inference step
The first output of the
module
must be a single array that has either shape (n,) or shape (n, 1). In the latter case, the output will be reshaped to become 1-dim.
-
predict
(X)[source]¶ Where applicable, return class labels for samples in X.
If the module’s forward method returns multiple outputs as a tuple, it is assumed that the first output contains the relevant information and the other values are ignored. If all values are relevant, consider using
forward()
instead.Parameters: - X : input data, compatible with skorch.dataset.Dataset
By default, you should be able to pass:
- numpy arrays
- torch tensors
- pandas DataFrame or Series
- scipy sparse CSR matrices
- a dictionary of the former three
- a list/tuple of the former three
- a Dataset
If this doesn’t work with your data, you have to pass a
Dataset
that can deal with the data.
Returns: - y_pred : numpy ndarray
-
class
skorch.classifier.
NeuralNetClassifier
(module, *args, criterion=<class 'torch.nn.modules.loss.NLLLoss'>, train_split=<skorch.dataset.ValidSplit object>, classes=None, **kwargs)[source]¶ NeuralNet for classification tasks
Use this specifically if you have a standard classification task, with input data X and target y.
In addition to the parameters listed below, there are parameters with specific prefixes that are handled separately. To illustrate this, here is an example:
>>> net = NeuralNet(
...     ...,
...     optimizer=torch.optim.SGD,
...     optimizer__momentum=0.95,
... )
This way, when
optimizer
is initialized,NeuralNet
will take care of setting themomentum
parameter to 0.95.(Note that the double underscore notation in
optimizer__momentum
means that the parametermomentum
should be set on the objectoptimizer
. This is the same semantic as used by sklearn.)Furthermore, this allows to change those parameters later:
net.set_params(optimizer__momentum=0.99)
This can be useful when you want to change certain parameters using a callback, when using the net in an sklearn grid search, etc.
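For instance, a grid search over such prefixed parameters could look like this (a sketch; the parameter grid, cv, and scoring values are illustrative):
>>> from sklearn.model_selection import GridSearchCV
>>> params = {
...     'lr': [0.01, 0.1],
...     'optimizer__momentum': [0.9, 0.95],
... }
>>> search = GridSearchCV(net, params, cv=3, scoring='accuracy')
>>> search.fit(X, y)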
By default an
EpochTimer
,BatchScoring
(for both training and validation datasets), andPrintLog
callbacks are installed for the user’s convenience.Parameters: - module : torch module (class or instance)
A PyTorch
Module
. In general, the uninstantiated class should be passed, although instantiated modules will also work.- criterion : torch criterion (class, default=torch.nn.NLLLoss)
Negative log likelihood loss. Note that the module should return probabilities; the log is applied during
get_loss
.- classes : None or list (default=None)
If None, the
classes_
attribute will be inferred from they
data passed tofit
. If a non-empty list is passed, that list will be returned asclasses_
. If the initial skorch behavior should be restored, i.e. raising anAttributeError
, pass an empty list.- optimizer : torch optim (class, default=torch.optim.SGD)
The uninitialized optimizer (update rule) used to optimize the module
- lr : float (default=0.01)
Learning rate passed to the optimizer. You may use
lr
instead of usingoptimizer__lr
, which would result in the same outcome.- max_epochs : int (default=10)
The number of epochs to train for each
fit
call. Note that you may keyboard-interrupt training at any time.- batch_size : int (default=128)
Mini-batch size. Use this instead of setting
iterator_train__batch_size
and iterator_valid__batch_size
, which would result in the same outcome. Ifbatch_size
is -1, a single batch with all the data will be used during training and validation.- iterator_train : torch DataLoader
The default PyTorch
DataLoader
used for training data.- iterator_valid : torch DataLoader
The default PyTorch
DataLoader
used for validation and test data, i.e. during inference.- dataset : torch Dataset (default=skorch.dataset.Dataset)
The dataset is necessary for the incoming data to work with pytorch’s
DataLoader
. It has to implement the__len__
and__getitem__
methods. The provided dataset should be capable of dealing with a lot of data types out of the box, so only change this if your data is not supported. You should generally pass the uninitializedDataset
class and define additional arguments to X and y by prefixing them withdataset__
. It is also possible to pass an initialized Dataset
, in which case no additional arguments may be passed.- train_split : None or callable (default=skorch.dataset.ValidSplit(5))
If None, there is no train/validation split. Else, train_split should be a function or callable that is called with X and y data and should return the tuple
dataset_train, dataset_valid
. The validation data may be None.- callbacks : None, “disable”, or list of Callback instances (default=None)
Which callbacks to enable. There are three possible values:
If
callbacks=None
, only use default callbacks, those returned byget_default_callbacks
.If
callbacks="disable"
, disable all callbacks, i.e. do not run any of the callbacks.If
callbacks
is a list of callbacks, use those callbacks in addition to the default callbacks. Each callback should be an instance ofCallback
.Callback names are inferred from the class name. Name conflicts are resolved by appending a count suffix starting with 1, e.g.
EpochScoring_1
. Alternatively, a tuple(name, callback)
can be passed, wherename
should be unique. Callbacks may or may not be instantiated. The callback name can be used to set parameters on specific callbacks (e.g., for the callback with name'print_log'
, usenet.set_params(callbacks__print_log__keys_ignored=['epoch', 'train_loss'])
).- predict_nonlinearity : callable, None, or ‘auto’ (default=’auto’)
The nonlinearity to be applied to the prediction. When set to ‘auto’, infers the correct nonlinearity based on the criterion (softmax for
CrossEntropyLoss
and sigmoid forBCEWithLogitsLoss
). If it cannot be inferred or if the parameter is None, just use the identity function. Don’t pass a lambda function if you want the net to be pickleable.In case a callable is passed, it should accept the output of the module (the first output if there is more than one), which is a PyTorch tensor, and return the transformed PyTorch tensor.
This can be useful, e.g., when
predict_proba()
should return probabilities but a criterion is used that does not expect probabilities. In that case, the module can return whatever is required by the criterion and thepredict_nonlinearity
transforms this output into probabilities.The nonlinearity is applied only when calling
predict()
orpredict_proba()
but not anywhere else – notably, the loss is unaffected by this nonlinearity.- warm_start : bool (default=False)
Whether each fit call should lead to a re-initialization of the module (cold start) or whether the module should be trained further (warm start).
- verbose : int (default=1)
This parameter controls how much print output is generated by the net and its callbacks. By setting this value to 0, e.g. the summary scores at the end of each epoch are no longer printed. This can be useful when running a hyperparameter search. The summary scores are always logged in the history attribute, regardless of the verbose setting.
- device : str, torch.device, or None (default=’cpu’)
The compute device to be used. If set to ‘cuda’ in order to use GPU acceleration, data in torch tensors will be pushed to cuda tensors before being sent to the module. If set to None, then all compute devices will be left unmodified.
Attributes: - prefixes_ : list of str
Contains the prefixes to special parameters. E.g., since there is the
'optimizer'
prefix, it is possible to set parameters like so:NeuralNet(..., optimizer__momentum=0.95)
.- cuda_dependent_attributes_ : list of str
Contains a list of all attribute prefixes whose values depend on a CUDA device. If a
NeuralNet
trained with a CUDA-enabled device is unpickled on a machine without CUDA or with CUDA disabled, the listed attributes are mapped to CPU. Expand this list if you want to add other cuda-dependent attributes.- initialized_ : bool
Whether the
NeuralNet
was initialized.- module_ : torch module (instance)
The instantiated module.
- criterion_ : torch criterion (instance)
The instantiated criterion.
- callbacks_ : list of tuples
The complete (i.e. default and other), initialized callbacks, in a tuple with unique names.
- _modules : list of str
List of names of all modules that are torch modules. This list is collected dynamically when the net is initialized. Typically, there is no reason for a user to modify this list.
- _criteria : list of str
List of names of all criteria that are torch modules. This list is collected dynamically when the net is initialized. Typically, there is no reason for a user to modify this list.
- _optimizers : list of str
List of names of all optimizers. This list is collected dynamically when the net is initialized. Typically, there is no reason for a user to modify this list.
- classes_ : array, shape (n_classes, )
A list of class labels known to the classifier.
Methods
check_data
(X, y)check_is_fitted
([attributes])Checks whether the net is initialized check_training_readiness
()Check that the net is ready to train evaluation_step
(batch[, training])Perform a forward step to produce the output used for prediction and scoring. fit
(X, y, **fit_params)See NeuralNet.fit
.fit_loop
(X[, y, epochs])The proper fit loop. forward
(X[, training, device])Gather and concatenate the output from forward call with input data. forward_iter
(X[, training, device])Yield outputs of module forward calls on each batch of data. get_all_learnable_params
()Yield the learnable parameters of all modules get_dataset
(X[, y])Get a dataset that contains the input data and is passed to the iterator. get_iterator
(dataset[, training])Get an iterator that allows to loop over the batches of the given data. get_loss
(y_pred, y_true, *args, **kwargs)Return the loss for this batch. get_params_for
(prefix)Collect and return init parameters for an attribute. get_params_for_optimizer
(prefix, …)Collect and return init parameters for an optimizer. get_split_datasets
(X[, y])Get internal train and validation datasets. get_train_step_accumulator
()Return the train step accumulator. infer
(x, **fit_params)Perform a single inference step on a batch of data. initialize
()Initializes all of its components and returns self. initialize_callbacks
()Initializes all callbacks and save the result in the callbacks_
attribute.initialize_criterion
()Initializes the criterion. initialize_history
()Initializes the history. initialize_module
()Initializes the module. initialize_optimizer
([triggered_directly])Initialize the model optimizer. initialized_instance
(instance_or_cls, kwargs)Return an instance initialized with the given parameters load_params
([f_params, f_optimizer, …])Loads the module’s parameters, history, and optimizer, not the whole object. notify
(method_name, **cb_kwargs)Call the callback method specified in method_name
with parameters specified incb_kwargs
.on_batch_begin
(net[, batch, training])on_epoch_begin
(net[, dataset_train, …])on_epoch_end
(net[, dataset_train, dataset_valid])on_train_begin
(net[, X, y])on_train_end
(net[, X, y])partial_fit
(X[, y, classes])Fit the module. predict
(X)Where applicable, return class labels for samples in X. predict_proba
(X)Where applicable, return probability estimates for samples. run_single_epoch
(iterator, training, prefix, …)Compute a single epoch of train or validation. save_params
([f_params, f_optimizer, …])Saves the module’s parameters, history, and optimizer, not the whole object. score
(X, y[, sample_weight])Return the mean accuracy on the given test data and labels. set_params
(**kwargs)Set the parameters of this class. train_step
(batch, **fit_params)Prepares a loss function callable and pass it to the optimizer, hence performing one optimization step. train_step_single
(batch, **fit_params)Compute y_pred, loss value, and update net’s gradients. trim_for_prediction
()Remove all attributes not required for prediction. validation_step
(batch, **fit_params)Perform a forward step using batched data and return the resulting loss. get_default_callbacks get_params initialize_virtual_params on_batch_end on_grad_computed -
fit
(X, y, **fit_params)[source]¶ See
NeuralNet.fit
.In contrast to
NeuralNet.fit
,y
is non-optional to avoid mistakenly forgetting abouty
. However,y
can be set toNone
in case it is derived dynamically fromX
.
-
get_loss
(y_pred, y_true, *args, **kwargs)[source]¶ Return the loss for this batch.
Parameters: - y_pred : torch tensor
Predicted target values
- y_true : torch tensor
True target values.
- X : input data, compatible with skorch.dataset.Dataset
By default, you should be able to pass:
- numpy arrays
- torch tensors
- pandas DataFrame or Series
- scipy sparse CSR matrices
- a dictionary of the former three
- a list/tuple of the former three
- a Dataset
If this doesn’t work with your data, you have to pass a
Dataset
that can deal with the data.- training : bool (default=False)
Whether train mode should be used or not.
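As an illustration, a subclass could extend this method to add a penalty term on top of the criterion’s loss. This is only a sketch (the class name RegularizedNet and the L1 factor are made up); the actual loss computation is delegated to the parent class:
>>> class RegularizedNet(NeuralNetClassifier):
...     def get_loss(self, y_pred, y_true, X=None, training=False):
...         loss = super().get_loss(y_pred, y_true, X=X, training=training)
...         # hypothetical L1 penalty on all module parameters
...         l1 = sum(p.abs().sum() for p in self.module_.parameters())
...         return loss + 1e-4 * l1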
-
predict
(X)[source]¶ Where applicable, return class labels for samples in X.
If the module’s forward method returns multiple outputs as a tuple, it is assumed that the first output contains the relevant information and the other values are ignored. If all values are relevant, consider using
forward()
instead.Parameters: - X : input data, compatible with skorch.dataset.Dataset
By default, you should be able to pass:
- numpy arrays
- torch tensors
- pandas DataFrame or Series
- scipy sparse CSR matrices
- a dictionary of the former three
- a list/tuple of the former three
- a Dataset
If this doesn’t work with your data, you have to pass a
Dataset
that can deal with the data.
Returns: - y_pred : numpy ndarray
-
predict_proba
(X)[source]¶ Where applicable, return probability estimates for samples.
If the module’s forward method returns multiple outputs as a tuple, it is assumed that the first output contains the relevant information and the other values are ignored. If all values are relevant, consider using
forward()
instead.Parameters: - X : input data, compatible with skorch.dataset.Dataset
By default, you should be able to pass:
- numpy arrays
- torch tensors
- pandas DataFrame or Series
- scipy sparse CSR matrices
- a dictionary of the former three
- a list/tuple of the former three
- a Dataset
If this doesn’t work with your data, you have to pass a
Dataset
that can deal with the data.
Returns: - y_proba : numpy ndarray
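A short usage sketch contrasting the two prediction methods (module and data names are illustrative):
>>> net = NeuralNetClassifier(MyModule, max_epochs=10, lr=0.1)
>>> net.fit(X, y)
>>> y_proba = net.predict_proba(X)  # probabilities, shape (n_samples, n_classes)
>>> y_pred = net.predict(X)  # class labels, shape (n_samples,)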
skorch.dataset¶
Contains custom skorch Dataset and ValidSplit.
-
class
skorch.dataset.
CVSplit
(*args, **kwargs)[source]¶ Methods
__call__
(dataset[, y, groups])Call self as a function. check_cv
(y)Resolve which cross validation strategy is used.
-
class
skorch.dataset.
Dataset
(X, y=None, length=None)[source]¶ General dataset wrapper that can be used in conjunction with PyTorch
DataLoader
.The dataset will always yield a tuple of two values, the first from the data (
X
) and the second from the target (y
). However, the target is allowed to beNone
. In that case,Dataset
will currently return a dummy tensor, sinceDataLoader
does not work withNone
s.Dataset
currently works with the following data types:- numpy
array
s - PyTorch
Tensor
s - scipy sparse CSR matrices
- pandas NDFrame
- a dictionary of the former three
- a list/tuple of the former three
Parameters: - X : see above
Everything pertaining to the input data.
- y : see above or None (default=None)
Everything pertaining to the target, if there is anything.
- length : int or None (default=None)
If not
None
, determines the length (len
) of the data. Should usually be left atNone
, in which case the length is determined by the data itself.
Methods
transform
(X, y)Additional transformations on X
andy
.-
transform
(X, y)[source]¶ Additional transformations on
X
andy
.By default, they are cast to PyTorch
Tensor
s. Override this if you want a different behavior. Note: If you use this in conjunction with PyTorch
DataLoader
, the latter will call the dataset for each row separately, which means that the incomingX
andy
each are single rows.
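For example, a custom transform() could repackage the input so that it matches the module’s forward signature. This is only a sketch; MyModule, whose forward accepts a keyword argument named data, is an assumption:
>>> import torch
>>> from skorch import NeuralNet
>>> from skorch.dataset import Dataset
>>> class MyDataset(Dataset):
...     def transform(self, X, y):
...         X, y = super().transform(X, y)
...         # wrap X in a dict so the module receives it as the keyword argument `data`
...         return {'data': X}, y
>>> net = NeuralNet(MyModule, criterion=torch.nn.CrossEntropyLoss, dataset=MyDataset)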
-
class
skorch.dataset.
ValidSplit
(cv=5, stratified=False, random_state=None)[source]¶ Class that performs the internal train/valid split on a dataset.
The
cv
argument here works similarly to the regular sklearncv
parameter in, e.g.,GridSearchCV
. However, instead of cycling through all splits, only one fixed split (the first one) is used. To get a full cycle through the splits, don’t useNeuralNet
’s internal validation but instead the corresponding sklearn functions (e.g.cross_val_score
).We additionally support a float, similar to sklearn’s
train_test_split
.Parameters: - cv : int, float, cross-validation generator or an iterable, optional
(Refer to sklearn’s User Guide on cross-validation for the various cross-validation strategies that can be used here.)
Determines the cross-validation splitting strategy. Possible inputs for cv are:
- None, to use the default 3-fold cross validation,
- integer, to specify the number of folds in a
(Stratified)KFold
, - float, to represent the proportion of the dataset to include in the validation split.
- An object to be used as a cross-validation generator.
- An iterable yielding train, validation splits.
- stratified : bool (default=False)
Whether the split should be stratified. Only works if
y
is either binary or multiclass classification.- random_state : int, RandomState instance, or None (default=None)
Control the random state in case that
(Stratified)ShuffleSplit
is used (which is when a float is passed tocv
). For more information, look at the sklearn documentation of(Stratified)ShuffleSplit
.
Methods
__call__
(dataset[, y, groups])Call self as a function. check_cv
(y)Resolve which cross validation strategy is used.
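For instance (a sketch; the 20% split and the stratification flag are illustrative choices):
>>> from skorch.dataset import ValidSplit
>>> # use 20% of the training data as a stratified validation set
>>> net = NeuralNetClassifier(MyModule, train_split=ValidSplit(0.2, stratified=True))
>>> # disable the internal validation split altogether
>>> net = NeuralNetClassifier(MyModule, train_split=None)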
skorch.exceptions¶
Contains skorch-specific exceptions and warnings.
-
exception
skorch.exceptions.
DeviceWarning
[source]¶ A problem with a device (e.g. CUDA) was detected.
-
exception
skorch.exceptions.
NotInitializedError
[source]¶ Module is not initialized, please call the
.initialize
method or train the model by calling.fit(...)
.
-
exception
skorch.exceptions.
SkorchAttributeError
[source]¶ An attribute was set incorrectly on a skorch net.
skorch.helper¶
Helper functions and classes for users.
They should not be used in skorch directly.
-
class
skorch.helper.
DataFrameTransformer
(treat_int_as_categorical=False, float_dtype=<class 'numpy.float32'>, int_dtype=<class 'numpy.int64'>)[source]¶ Transform a DataFrame into a dict useful for working with skorch.
Transforms cardinal data to floats and categorical data to vectors of ints so that they can be embedded.
Although skorch can deal with pandas DataFrames, the default behavior is often not very useful. Use this transformer to transform the DataFrame into a dict with all float columns concatenated using the key “X” and all categorical values encoded as integers, using their respective column names as keys.
Your module must have a matching signature for this to work. It must accept an argument
X
for all cardinal values. Additionally, for all categorical values, it must accept an argument with the same name as the corresponding column (see example below). If you need help with the required signature, use thedescribe_signature
method of this class and pass it your data.You can choose whether you want to treat int columns the same as float columns (default) or as categorical values.
To one-hot encode categorical features, initialize their corresponding embedding layers using the identity matrix.
Parameters: - treat_int_as_categorical : bool (default=False)
Whether to treat integers as categorical values or as cardinal values, i.e. the same as floats.
- float_dtype : numpy dtype or None (default=np.float32)
The dtype to cast the cardinal values to. If None, don’t change them.
- int_dtype : numpy dtype or None (default=np.int64)
The dtype to cast the categorical values to. If None, don’t change them. If you do this, it can happen that the categorical values will have different dtypes, reflecting the number of unique categories.
Notes
The value of X will always be 2-dimensional, even if it only contains 1 column.
Examples
>>> df = pd.DataFrame({
...     'col_floats': np.linspace(0, 1, 12),
...     'col_ints': [11, 11, 10] * 4,
...     'col_cats': ['a', 'b', 'a'] * 4,
... })
>>> # cast to category dtype to later learn embeddings
>>> df['col_cats'] = df['col_cats'].astype('category')
>>> y = np.asarray([0, 1, 0] * 4)
>>> class MyModule(nn.Module):
...     def __init__(self):
...         super().__init__()
...         self.reset_params()
...
...     def reset_params(self):
...         self.embedding = nn.Embedding(2, 10)
...         self.linear = nn.Linear(2, 10)
...         self.out = nn.Linear(20, 2)
...         self.nonlin = nn.Softmax(dim=-1)
...
...     def forward(self, X, col_cats):
...         # "X" contains the values from col_floats and col_ints
...         # "col_cats" contains the values from "col_cats"
...         X_lin = self.linear(X)
...         X_cat = self.embedding(col_cats)
...         X_concat = torch.cat((X_lin, X_cat), dim=1)
...         return self.nonlin(self.out(X_concat))
>>> net = NeuralNetClassifier(MyModule)
>>> pipe = Pipeline([
...     ('transform', DataFrameTransformer()),
...     ('net', net),
... ])
>>> pipe.fit(df, y)
Methods
describe_signature
(df)Describe the signature required for the given data. fit
(df[, y])fit_transform
(X[, y])Fit to data, then transform it. get_params
([deep])Get parameters for this estimator. set_params
(**params)Set the parameters of this estimator. transform
(df)Transform DataFrame to become a dict that works well with skorch. -
describe_signature
(df)[source]¶ Describe the signature required for the given data.
Pass the DataFrame to receive a description of the signature required for the module’s forward method. The description consists of three parts:
1. The names of the arguments that the forward method needs.
2. The dtypes of the torch tensors passed to forward.
3. The number of input units required for the corresponding argument. For the float parameter, this is just the number of dimensions of the tensor; for categorical parameters, it is the number of unique elements.
Returns: - signature : dict
Returns a dict with each key corresponding to one key required for the forward method. The values are dictionaries of two elements. The key “dtype” describes the torch dtype of the resulting tensor, the key “input_units” describes the required number of input units.
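Building on the DataFrame from the example above, a call might look roughly like this (a sketch; the exact dtypes and unit counts depend on the data):
>>> DataFrameTransformer(treat_int_as_categorical=True).describe_signature(df)
{'X': {'dtype': torch.float32, 'input_units': 1},
 'col_ints': {'dtype': torch.int64, 'input_units': 2},
 'col_cats': {'dtype': torch.int64, 'input_units': 2}}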
-
transform
(df)[source]¶ Transform DataFrame to become a dict that works well with skorch.
Parameters: - df : pd.DataFrame
Incoming DataFrame.
Returns: - X_dict: dict
Dictionary with all floats concatenated using the key “X” and all categorical values encoded as integers, using their respective column names as keys.
-
class
skorch.helper.
SliceDataset
(dataset, idx=0, indices=None)[source]¶ Helper class that wraps a torch dataset to make it work with sklearn.
Sometimes, sklearn will touch the input data, e.g. when splitting the data for a grid search. This will fail when the input data is a torch dataset. To prevent this, use this wrapper class for your dataset.
Note: This class will only return the X value by default (i.e. the first value returned by indexing the original dataset). Sklearn, and hence skorch, always require 2 values, X and y. Therefore, you still need to provide the y data separately.
Note: This class behaves similarly to a PyTorch
Subset
when it is indexed by a slice or numpy array: It will return anotherSliceDataset
that references the subset instead of the actual values. Only when it is indexed by an int does it return the actual values. The reason for this is to avoid loading all data into memory when sklearn, for instance, creates a train/validation split on the dataset. Data will only be loaded in batches during the fit loop.Parameters: - dataset : torch.utils.data.Dataset
A valid torch dataset.
- idx : int (default=0)
Indicates which element of the dataset should be returned. Typically, the dataset returns both X and y values. SliceDataset can only return 1 value. If you want to get X, choose idx=0 (default), if you want y, choose idx=1.
- indices : list, np.ndarray, or None (default=None)
If you only want to return a subset of the dataset, indicate which subset that is by passing this argument. Typically, this can be left to be None, which returns all the data. See also
Subset
.
Examples
>>> X = MyCustomDataset()
>>> search = GridSearchCV(net, params, ...)
>>> search.fit(X, y) # raises error
>>> ds = SliceDataset(X)
>>> search.fit(ds, y) # works
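Since a SliceDataset only exposes one of the dataset’s values, a common pattern is to wrap the same dataset twice, once for X and once for y (a sketch; MyCustomDataset is assumed to yield (X, y) pairs):
>>> ds = MyCustomDataset()
>>> X_sl = SliceDataset(ds, idx=0)  # acts like X
>>> y_sl = SliceDataset(ds, idx=1)  # acts like y
>>> search.fit(X_sl, y_sl)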
Attributes: - shape
Methods
count
(value)index
(value, [start, [stop]])Raises ValueError if the value is not present. transform
(data)Additional transformations on data
.-
transform
(data)[source]¶ Additional transformations on
data
. Note: If you use this in conjunction with PyTorch
DataLoader
, the latter will call the dataset for each row separately, which means that the incomingdata
is a single row.
-
class
skorch.helper.
SliceDict
(**kwargs)[source]¶ Wrapper for Python dict that makes it sliceable across values.
Use this if your input data is a dictionary and you have problems with sklearn not being able to slice it. Wrap your dict with SliceDict and it should usually work.
Note:
- SliceDict cannot be indexed by integers; if you want one row, say row 3, use [3:4].
- SliceDict accepts numpy arrays and torch tensors as values.
Examples
>>> X = {'key0': val0, 'key1': val1}
>>> search = GridSearchCV(net, params, ...)
>>> search.fit(X, y) # raises error
>>> Xs = SliceDict(key0=val0, key1=val1) # or Xs = SliceDict(**X)
>>> search.fit(Xs, y) # works
Attributes: - shape
Methods
clear
()copy
()fromkeys
(*args, **kwargs)fromkeys method makes no sense with SliceDict and is thus not supported. get
($self, key[, default])Return the value for key if key is in the dictionary, else default. items
()keys
()pop
(k[,d])If key is not found, d is returned if given, otherwise KeyError is raised popitem
()2-tuple; but raise KeyError if D is empty. setdefault
($self, key[, default])Insert key with a value of default if key is not in the dictionary. update
([E, ]**F)If E is present and has a .keys() method, then does: for k in E: D[k] = E[k] If E is present and lacks a .keys() method, then does: for k, v in E: D[k] = v In either case, this is followed by: for k in F: D[k] = F[k] values
()
skorch.hf¶
Classes to work with Hugging Face ecosystem (https://huggingface.co/)
E.g. transformers or tokenizers
This module should be treated as a leaf node in the dependency tree, i.e. no other skorch modules should depend on these classes or import from here. Even so, don’t import any Hugging Face libraries on the root level because skorch should not depend on them.
-
class
skorch.hf.
AccelerateMixin
(*args, accelerator, device=None, callbacks__print_log__sink='auto', **kwargs)[source]¶ Mixin class to add support for Hugging Face accelerate
This is an experimental feature.
Use this mixin class with one of the neural net classes (e.g.
NeuralNet
,NeuralNetClassifier
, orNeuralNetRegressor
) and pass an instance ofAccelerator
for mixed precision, multi-GPU, or TPU training.Install the accelerate library using:
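python -m pip install accelerate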
skorch does not itself provide any facilities to enable these training features. A lot of them can still be implemented by the user with a little bit of extra work but it can be a daunting task. That is why this helper class was added: Using this mixin in conjunction with the accelerate library should cover a lot of common use cases.
Note
Under the hood, accelerate uses
GradScaler
, which does not support passing the training step as a closure. Therefore, if your optimizer requires that (e.g.torch.optim.LBFGS
), you cannot use accelerate.Warning
Since accelerate is still quite young and backwards-compatibility-breaking features might be added, we treat its integration as an experimental feature. When accelerate’s API stabilizes, we will consider adding it to skorch proper.
Also, models accelerated this way cannot be pickled. If you need to save and load the net, either use
skorch.net.NeuralNet.save_params()
andskorch.net.NeuralNet.load_params()
or don’t useaccelerate
.Parameters: - accelerator : accelerate.Accelerator
In addition to the usual parameters, pass an instance of
accelerate.Accelerator
with the desired settings.- device : str, torch.device, or None (default=None)
The compute device to be used. When using accelerate, it is recommended to leave device handling to accelerate. Therefore, it is best to leave this argument to be None, which means that skorch does not set the device.
- callbacks__print_log__sink : ‘auto’ or callable
If ‘auto’, uses the
print
function of the accelerator, if it has one. This avoids printing the same output multiple times when training concurrently on multiple machines. If the accelerator does not have aprint
function, use Python’sprint
function instead.
Examples
>>> from skorch import NeuralNetClassifier
>>> from skorch.hf import AccelerateMixin
>>> from accelerate import Accelerator
>>>
>>> class AcceleratedNet(AccelerateMixin, NeuralNetClassifier):
...     '''NeuralNetClassifier with accelerate support'''
>>>
>>> accelerator = Accelerator(...)
>>> # you may pass gradient_accumulation_steps to enable grad accumulation
>>> net = AcceleratedNet(MyModule, accelerator=accelerator)
>>> net.fit(X, y)
The same approach works with all the other skorch net classes.
Methods
on_train_end
(net[, X, y])get_iterator initialize_callbacks train_step train_step_single
-
class
skorch.hf.
HuggingfacePretrainedTokenizer
(tokenizer, train=False, max_length=256, return_tensors='pt', return_attention_mask=True, return_token_type_ids=False, return_length=False, verbose=0, vocab_size=None)[source]¶ Wraps a pretrained Huggingface tokenizer to work as an sklearn transformer
From the tokenizers docs:
🤗 Tokenizers provides an implementation of today’s most used tokenizers, with a focus on performance and versatility.
Use pretrained Hugging Face tokenizers in an sklearn compatible transformer.
Parameters: - tokenizer : str or os.PathLike or transformers.PreTrainedTokenizerFast
If a string, the model id of a predefined tokenizer hosted inside a model repo on huggingface.co. Valid model ids can be located at the root-level, like bert-base-uncased, or namespaced under a user or organization name, like dbmdz/bert-base-german-cased. If a path, it should point to a directory containing vocabulary files required by the tokenizer, e.g., ./my_model_directory/. Otherwise, it should be an instantiated
PreTrainedTokenizerFast
.- train : bool (default=False)
Whether to use the pre-trained tokenizer directly as is or to retrain it on your data. If you just want to use the pre-trained tokenizer without further modification, leave this parameter as False. However, if you want to fit the tokenizer on your own data (completely from scratch, forgetting what it has learned previously), set this argument to True. The latter option is useful if you want to use the same hyper-parameters as the pre-trained tokenizer but want the vocabulary to be fitted to your dataset. The vocabulary size of this new tokenizer can be set explicitly by passing the
vocab_size
argument.- max_length : int (default=256)
Maximum number of tokens used per sequence.
- return_tensors : one of None, str, ‘pt’, ‘np’, ‘tf’ (default=’pt’)
What type of result values to return. By default, return a padded and truncated (to
max_length
) PyTorch Tensor. Similarly, ‘np’ results in a padded and truncated numpy array. TensorFlow tensors are not officially supported but should also work. If None or str, return a list of lists instead. These lists are not padded or truncated, thus each row may have different numbers of elements.- return_attention_mask : bool (default=True)
Whether to return the attention mask.
- return_token_type_ids : bool (default=False)
Whether to return the token type ids.
- return_length : bool (default=False)
Whether to return the length of the encoded inputs.
- pad_token : str (default=’[PAD]’)
A special token used to make arrays of tokens the same size for batching purpose. Will then be ignored by attention mechanisms.
- vocab_size : int or None (default=None)
Change this parameter only if you use
train=True
. In that case, this parameter will determine the vocabulary size of the newly trained tokenizer. If you settrain=True
but leave this parameter as None, the same vocabulary size as the one from the initial tokenizer will be used.- verbose : int (default=0)
Whether the tokenizer should print more information and warnings.
Examples
>>> from skorch.hf import HuggingfacePretrainedTokenizer
>>> # pass the model name to be downloaded
>>> hf_tokenizer = HuggingfacePretrainedTokenizer('bert-base-uncased')
>>> data = ['hello there', 'this is a text']
>>> hf_tokenizer.fit(data) # only loads the model
>>> hf_tokenizer.transform(data)
>>> # pass pretrained tokenizer as object
>>> my_tokenizer = ...
>>> hf_tokenizer = HuggingfacePretrainedTokenizer(my_tokenizer)
>>> hf_tokenizer.fit(data)
>>> hf_tokenizer.transform(data)
>>> # use hyper params from pretrained tokenizer to fit on own data
>>> hf_tokenizer = HuggingfacePretrainedTokenizer(
...     'bert-base-uncased', train=True, vocab_size=12345)
>>> data = ...
>>> hf_tokenizer.fit(data) # fits new tokenizer on data
>>> hf_tokenizer.transform(data)
Attributes: - vocabulary_ : dict
A mapping of terms to feature indices.
- fast_tokenizer_ : transformers.PreTrainedTokenizerFast
If you want to extract the Hugging Face tokenizer to use it without skorch, use this attribute.
- tokenizers docs: https://huggingface.co/docs/tokenizers/python/latest/index.html
Methods
fit
(X[, y])Load the pretrained tokenizer fit_transform
(X[, y])Fit to data, then transform it. get_feature_names_out
([input_features])Array mapping from feature integer indices to feature name. get_params
([deep])Get parameters for this estimator. inverse_transform
(X)Decode encodings back into strings set_params
(**params)Set the parameters of this estimator. tokenize
(X, **kwargs)Convenience method to use the trained tokenizer for tokenization transform
(X)Transform the given data
-
class
skorch.hf.
HuggingfaceTokenizer
(tokenizer, model=None, trainer='auto', normalizer=None, pre_tokenizer=None, post_processor=None, max_length=256, return_tensors='pt', return_attention_mask=True, return_token_type_ids=False, return_length=False, pad_token='[PAD]', verbose=0, **kwargs)[source]¶ Wraps a Hugging Face tokenizer to work as an sklearn transformer
From the tokenizers docs:
🤗 Tokenizers provides an implementation of today’s most used tokenizers, with a focus on performance and versatility.
Train Hugging Face tokenizers on custom data using an sklearn compatible API.
Parameters: - tokenizer : tokenizers.Tokenizer
The tokenizer to train.
- model : tokenizers.models.Model
The model represents the actual tokenization algorithm, e.g.
BPE
.- trainer : tokenizers.trainers.Trainer or ‘auto’ (default=’auto’)
Class responsible for training the tokenizer. If ‘auto’, the correct trainer will be inferred from the used model using
model.get_trainer()
.- normalizer : tokenizers.normalizers.Normalizer or None (default=None)
Optional normalizer, e.g. for casting the text to lowercase.
- pre_tokenizer : tokenizers.pre_tokenizers.PreTokenizer or None (default=None)
Optional pre-tokenization, e.g. splitting on space.
- post_processor : tokenizers.processors.PostProcessor
Optional post-processor, mostly used to add special tokens for BERT etc.
- max_length : int (default=256)
Maximum number of tokens used per sequence.
- return_tensors : one of None, str, ‘pt’, ‘np’, ‘tf’ (default=’pt’)
What type of result values to return. By default, return a padded and truncated (to
max_length
) PyTorch Tensor. Similarly, ‘np’ results in a padded and truncated numpy array. TensorFlow tensors are not officially supported but should also work. If None or str, return a list of lists instead. These lists are not padded or truncated, thus each row may have different numbers of elements.- return_attention_mask : bool (default=True)
Whether to return the attention mask.
- return_token_type_ids : bool (default=False)
Whether to return the token type ids.
- return_length : bool (default=False)
Whether to return the length of the encoded inputs.
- pad_token : str (default=’[PAD]’)
A special token used to make arrays of tokens the same size for batching purpose. Will then be ignored by attention mechanisms.
- verbose : int (default=0)
Whether the tokenizer should print more information and warnings.
Examples
>>> # train a BERT tokenizer from scratch
>>> from tokenizers import Tokenizer
>>> from tokenizers.models import WordPiece
>>> from tokenizers import normalizers
>>> from tokenizers.normalizers import Lowercase, NFD, StripAccents
>>> from tokenizers.pre_tokenizers import Whitespace
>>> from tokenizers.processors import TemplateProcessing
>>> bert_tokenizer = Tokenizer(WordPiece(unk_token="[UNK]"))
>>> normalizer = normalizers.Sequence([NFD(), Lowercase(), StripAccents()])
>>> pre_tokenizer = Whitespace()
>>> post_processor = TemplateProcessing(
...     single="[CLS] $A [SEP]",
...     pair="[CLS] $A [SEP] $B:1 [SEP]:1",
...     special_tokens=[
...         ("[CLS]", 1),
...         ("[SEP]", 2),
...     ],
... )
>>> from skorch.hf import HuggingfaceTokenizer
>>> hf_tokenizer = HuggingfaceTokenizer(
...     tokenizer=bert_tokenizer,
...     pre_tokenizer=pre_tokenizer,
...     post_processor=post_processor,
...     trainer__vocab_size=30522,
...     trainer__special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"],
... )
>>> data = ['hello there', 'this is a text']
>>> hf_tokenizer.fit(data)
>>> hf_tokenizer.transform(data)
In general, you can pass both initialized objects and uninitialized objects as parameters:
# initialized
HuggingfaceTokenizer(tokenizer=Tokenizer(model=WordPiece()))

# uninitialized
HuggingfaceTokenizer(tokenizer=Tokenizer, model=WordPiece)
Both approaches work equally well and allow you to, for instance, grid search on the tokenizer parameters. However, it is recommended not to pass an initialized trainer. This is because the trainer will then be saved as an attribute on the object, which can be wasteful. Instead, it is best to leave the default
trainer='auto'
, which results in the trainer being derived from the model.Note
If you want to train the
HuggingfaceTokenizer
in parallel (e.g. during a grid search), you should probably set the environment variableTOKENIZERS_PARALLELISM=false
. Otherwise, you may experience slowdowns or deadlocks.Attributes: - vocabulary_ : dict
A mapping of terms to feature indices.
- fast_tokenizer_ : transformers.PreTrainedTokenizerFast
If you want to extract the Hugging Face tokenizer to use it without skorch, use this attribute.
- tokenizers docs: https://huggingface.co/docs/tokenizers/python/latest/index.html
Methods
fit
(X[, y])Train the tokenizer on given data fit_transform
(X[, y])Fit to data, then transform it. get_feature_names_out
([input_features])Array mapping from feature integer indices to feature name. get_params
([deep])Get parameters for this estimator. get_params_for
(prefix)Collect and return init parameters for an attribute. initialize
()Initialize the individual tokenizer components initialize_trainer
()Initialize the trainer initialized_instance
(instance_or_cls, kwargs)Return an instance initialized with the given parameters inverse_transform
(X)Decode encodings back into strings set_params
(**kwargs)Set the parameters of this class. tokenize
(X, **kwargs)Convenience method to use the trained tokenizer for tokenization transform
(X)Transform the given data initialize_model initialize_normalizer initialize_post_processor initialize_pre_tokenizer initialize_tokenizer -
fit
(X, y=None, **fit_params)[source]¶ Train the tokenizer on given data
Parameters: - X : iterable of str
A list/array of strings or an iterable that generates strings.
- y : None
This parameter is ignored.
- fit_params : dict
This parameter is ignored.
Returns: - self : HuggingfaceTokenizer
The fitted instance of the tokenizer.
-
get_params
(deep=False)[source]¶ Get parameters for this estimator.
Parameters: - deep : bool (default=False)
If True, will return the parameters for this estimator and contained subobjects that are estimators.
Returns: - params : dict
Parameter names mapped to their values.
-
initialize_trainer
()[source]¶ Initialize the trainer
Infer the trainer type from the model if necessary.
-
initialized_instance
(instance_or_cls, kwargs)[source]¶ Return an instance initialized with the given parameters
This is a helper method that deals with several possibilities for a component that might need to be initialized:
- It is already an instance that’s good to go
- It is an instance but it needs to be re-initialized
- It’s not an instance and needs to be initialized
For the majority of use cases, this just comes down to initializing the class with its arguments.
Parameters: - instance_or_cls
The instance or class or callable to be initialized.
- kwargs : dict
The keyword arguments to initialize the instance or class. Can be an empty dict.
Returns: - instance
The initialized component.
skorch.history¶
Contains history class and helper functions.
-
class
skorch.history.
History
[source]¶ History contains the information about the training history of a
NeuralNet
, facilitating some of the more common tasks that occur during training. When you want to log certain information during training (say, a particular score or the norm of the gradients), you should write it to the net’s history object.
It is basically a list of dicts, one per epoch, each of which in turn contains a list of dicts, one per batch. For convenience, it has enhanced slicing notation and some methods to write new items.
To access items from history, you may pass a tuple of up to four items:
- Slices along the epochs.
- Selects columns from history epochs, may be a single one or a tuple of column names.
- Slices along the batches.
- Selects columns from history batches, may be a single one or a tuple of column names.
You may use a combination of the four items.
If you select columns that are not present in all epochs/batches, only those epochs/batches are chosen that contain said columns. If this set is empty, a
KeyError
is raised.Examples
>>> # ACCESSING ITEMS
>>> # history of a fitted neural net
>>> history = net.history
>>> # get current epoch, a dict
>>> history[-1]
>>> # get train losses from all epochs, a list of floats
>>> history[:, 'train_loss']
>>> # get train and valid losses from all epochs, a list of tuples
>>> history[:, ('train_loss', 'valid_loss')]
>>> # get current batches, a list of dicts
>>> history[-1, 'batches']
>>> # get latest batch, a dict
>>> history[-1, 'batches', -1]
>>> # get train losses from current batch, a list of floats
>>> history[-1, 'batches', :, 'train_loss']
>>> # get train and valid losses from current batch, a list of tuples
>>> history[-1, 'batches', :, ('train_loss', 'valid_loss')]
>>> # WRITING ITEMS
>>> # add new epoch row
>>> history.new_epoch()
>>> # add an entry to current epoch
>>> history.record('my-score', 123)
>>> # add a batch row to the current epoch
>>> history.new_batch()
>>> # add an entry to the current batch
>>> history.record_batch('my-batch-score', 456)
>>> # overwrite entry of current batch
>>> history.record_batch('my-batch-score', 789)
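As an example of writing to the history from a callback, the following sketch logs the summed absolute gradient values after each training batch (GradNorm is a hypothetical callback name, not part of skorch):
>>> from skorch.callbacks import Callback
>>> class GradNorm(Callback):
...     def on_batch_end(self, net, training=None, **kwargs):
...         if not training:
...             return
...         # sum of absolute gradient values across all module parameters
...         norm = sum(
...             p.grad.abs().sum().item()
...             for p in net.module_.parameters()
...             if p.grad is not None
...         )
...         net.history.record_batch('grad_norm', norm)
>>> net = NeuralNetClassifier(MyModule, callbacks=[GradNorm()])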
Methods
append
($self, object, /)Append object to the end of the list. clear
($self, /)Remove all items from list. copy
($self, /)Return a shallow copy of the list. count
($self, value, /)Return number of occurrences of value. extend
($self, iterable, /)Extend list by appending elements from the iterable. from_file
(f)Load the history of a NeuralNet
from a json file.index
($self, value[, start, stop])Return first index of value. insert
($self, index, object, /)Insert object before index. new_batch
()Register a new batch row for the current epoch. new_epoch
()Register a new epoch row. pop
($self[, index])Remove and return item at index (default last). record
(attr, value)Add a new value to the given column for the current epoch. record_batch
(attr, value)Add a new value to the given column for the current batch. remove
($self, value, /)Remove first occurrence of value. reverse
($self, /)Reverse IN PLACE. sort
($self, /, *[, key, reverse])Stable sort IN PLACE. to_file
(f)Saves the history as a json file. to_list
()Return history object as a list. -
classmethod
from_file
(f)[source]¶ Load the history of a
NeuralNet
from a json file.Parameters: - f : file-like object or str
skorch.net¶
Neural net base class
This is the most flexible class, making no assumptions about the kind of task being performed. Subclass this to create more specialized and sklearn-conforming classes like NeuralNetClassifier.
-
class
skorch.net.
NeuralNet
(module, criterion, optimizer=<class 'torch.optim.sgd.SGD'>, lr=0.01, max_epochs=10, batch_size=128, iterator_train=<class 'torch.utils.data.dataloader.DataLoader'>, iterator_valid=<class 'torch.utils.data.dataloader.DataLoader'>, dataset=<class 'skorch.dataset.Dataset'>, train_split=<skorch.dataset.ValidSplit object>, callbacks=None, predict_nonlinearity='auto', warm_start=False, verbose=1, device='cpu', **kwargs)[source]¶ NeuralNet base class.
The base class covers more generic cases. Depending on your use case, you might want to use
NeuralNetClassifier
orNeuralNetRegressor
.In addition to the parameters listed below, there are parameters with specific prefixes that are handled separately. To illustrate this, here is an example:
>>> net = NeuralNet(
...     ...,
...     optimizer=torch.optim.SGD,
...     optimizer__momentum=0.95,
... )
This way, when
optimizer
is initialized,NeuralNet
will take care of setting themomentum
parameter to 0.95.(Note that the double underscore notation in
optimizer__momentum
means that the parametermomentum
should be set on the objectoptimizer
. This is the same semantic as used by sklearn.)Furthermore, this allows to change those parameters later:
net.set_params(optimizer__momentum=0.99)
This can be useful when you want to change certain parameters using a callback, when using the net in an sklearn grid search, etc.
By default an
EpochTimer
,BatchScoring
(for both training and validation datasets), andPrintLog
callbacks are installed for the user’s convenience.Parameters: - module : torch module (class or instance)
A PyTorch
Module
. In general, the uninstantiated class should be passed, although instantiated modules will also work.- criterion : torch criterion (class)
The uninitialized criterion (loss) used to optimize the module.
- optimizer : torch optim (class, default=torch.optim.SGD)
The uninitialized optimizer (update rule) used to optimize the module
- lr : float (default=0.01)
Learning rate passed to the optimizer. You may use
lr
instead of usingoptimizer__lr
, which would result in the same outcome.- max_epochs : int (default=10)
The number of epochs to train for each
fit
call. Note that you may keyboard-interrupt training at any time.- batch_size : int (default=128)
Mini-batch size. Use this instead of setting
iterator_train__batch_size
and iterator_valid__batch_size
, which would result in the same outcome. Ifbatch_size
is -1, a single batch with all the data will be used during training and validation.- iterator_train : torch DataLoader
The default PyTorch
DataLoader
used for training data.- iterator_valid : torch DataLoader
The default PyTorch
DataLoader
used for validation and test data, i.e. during inference.- dataset : torch Dataset (default=skorch.dataset.Dataset)
The dataset is necessary for the incoming data to work with pytorch’s
DataLoader
. It has to implement the__len__
and__getitem__
methods. The provided dataset should be capable of dealing with a lot of data types out of the box, so only change this if your data is not supported. You should generally pass the uninitializedDataset
class and define additional arguments to X and y by prefixing them withdataset__
. It is also possible to pass an initialized Dataset
, in which case no additional arguments may be passed.- train_split : None or callable (default=skorch.dataset.ValidSplit(5))
If None, there is no train/validation split. Else, train_split should be a function or callable that is called with X and y data and should return the tuple
dataset_train, dataset_valid
. The validation data may be None.- callbacks : None, “disable”, or list of Callback instances (default=None)
Which callbacks to enable. There are three possible values:
If
callbacks=None
, only use default callbacks, those returned byget_default_callbacks
.If
callbacks="disable"
, disable all callbacks, i.e. do not run any of the callbacks.If
callbacks
is a list of callbacks, use those callbacks in addition to the default callbacks. Each callback should be an instance ofCallback
.Callback names are inferred from the class name. Name conflicts are resolved by appending a count suffix starting with 1, e.g.
EpochScoring_1
. Alternatively, a tuple(name, callback)
can be passed, wherename
should be unique. Callbacks may or may not be instantiated. The callback name can be used to set parameters on specific callbacks (e.g., for the callback with name'print_log'
, usenet.set_params(callbacks__print_log__keys_ignored=['epoch', 'train_loss'])
).- predict_nonlinearity : callable, None, or ‘auto’ (default=’auto’)
The nonlinearity to be applied to the prediction. When set to ‘auto’, infers the correct nonlinearity based on the criterion (softmax for
CrossEntropyLoss
and sigmoid forBCEWithLogitsLoss
). If it cannot be inferred or if the parameter is None, just use the identity function. Don’t pass a lambda function if you want the net to be pickleable.In case a callable is passed, it should accept the output of the module (the first output if there is more than one), which is a PyTorch tensor, and return the transformed PyTorch tensor.
This can be useful, e.g., when
predict_proba()
should return probabilities but a criterion is used that does not expect probabilities. In that case, the module can return whatever is required by the criterion and thepredict_nonlinearity
transforms this output into probabilities.The nonlinearity is applied only when calling
predict()
orpredict_proba()
but not anywhere else – notably, the loss is unaffected by this nonlinearity.- warm_start : bool (default=False)
Whether each fit call should lead to a re-initialization of the module (cold start) or whether the module should be trained further (warm start).
- verbose : int (default=1)
This parameter controls how much print output is generated by the net and its callbacks. By setting this value to 0, e.g. the summary scores at the end of each epoch are no longer printed. This can be useful when running a hyperparameter search. The summary scores are always logged in the history attribute, regardless of the verbose setting.
- device : str, torch.device, or None (default=’cpu’)
The compute device to be used. If set to ‘cuda’ in order to use GPU acceleration, data in torch tensors will be pushed to cuda tensors before being sent to the module. If set to None, then all compute devices will be left unmodified.
Attributes: - prefixes_ : list of str
Contains the prefixes to special parameters. E.g., since there is the
'optimizer'
prefix, it is possible to set parameters like so:NeuralNet(..., optimizer__momentum=0.95)
.- cuda_dependent_attributes_ : list of str
Contains a list of all attribute prefixes whose values depend on a CUDA device. If a
NeuralNet
trained with a CUDA-enabled device is unpickled on a machine without CUDA or with CUDA disabled, the listed attributes are mapped to CPU. Expand this list if you want to add other cuda-dependent attributes.- initialized_ : bool
Whether the
NeuralNet
was initialized.- module_ : torch module (instance)
The instantiated module.
- criterion_ : torch criterion (instance)
The instantiated criterion.
- callbacks_ : list of tuples
The complete (i.e. default and other), initialized callbacks, in a tuple with unique names.
- _modules : list of str
List of names of all modules that are torch modules. This list is collected dynamically when the net is initialized. Typically, there is no reason for a user to modify this list.
- _criteria : list of str
List of names of all criteria that are torch modules. This list is collected dynamically when the net is initialized. Typically, there is no reason for a user to modify this list.
- _optimizers : list of str
List of names of all optimizers. This list is collected dynamically when the net is initialized. Typically, there is no reason for a user to modify this list.
Methods
check_is_fitted
([attributes])Checks whether the net is initialized check_training_readiness
()Check that the net is ready to train evaluation_step
(batch[, training])Perform a forward step to produce the output used for prediction and scoring. fit
(X[, y])Initialize and fit the module. fit_loop
(X[, y, epochs])The proper fit loop. forward
(X[, training, device])Gather and concatenate the output from forward call with input data. forward_iter
(X[, training, device])Yield outputs of module forward calls on each batch of data. get_all_learnable_params
()Yield the learnable parameters of all modules get_dataset
(X[, y])Get a dataset that contains the input data and is passed to the iterator. get_iterator
(dataset[, training])Get an iterator that allows to loop over the batches of the given data. get_loss
(y_pred, y_true[, X, training])Return the loss for this batch. get_params_for
(prefix)Collect and return init parameters for an attribute. get_params_for_optimizer
(prefix, …)Collect and return init parameters for an optimizer. get_split_datasets
(X[, y])Get internal train and validation datasets. get_train_step_accumulator
()Return the train step accumulator. infer
(x, **fit_params)Perform a single inference step on a batch of data. initialize
()Initializes all of its components and returns self. initialize_callbacks
()Initializes all callbacks and save the result in the callbacks_
attribute.initialize_criterion
()Initializes the criterion. initialize_history
()Initializes the history. initialize_module
()Initializes the module. initialize_optimizer
([triggered_directly])Initialize the model optimizer. initialized_instance
(instance_or_cls, kwargs)Return an instance initialized with the given parameters load_params
([f_params, f_optimizer, …])Loads the module’s parameters, history, and optimizer, not the whole object. notify
(method_name, **cb_kwargs)Call the callback method specified in method_name
with parameters specified incb_kwargs
.on_batch_begin
(net[, batch, training])on_epoch_begin
(net[, dataset_train, …])on_epoch_end
(net[, dataset_train, dataset_valid])on_train_begin
(net[, X, y])on_train_end
(net[, X, y])partial_fit
(X[, y, classes])Fit the module. predict
(X)Where applicable, return class labels for samples in X. predict_proba
(X)Return the output of the module’s forward method as a numpy array. run_single_epoch
(iterator, training, prefix, …)Compute a single epoch of train or validation. save_params
([f_params, f_optimizer, …])Saves the module’s parameters, history, and optimizer, not the whole object. set_params
(**kwargs)Set the parameters of this class. train_step
(batch, **fit_params)Prepares a loss function callable and pass it to the optimizer, hence performing one optimization step. train_step_single
(batch, **fit_params)Compute y_pred, loss value, and update net’s gradients. trim_for_prediction
()Remove all attributes not required for prediction. validation_step
(batch, **fit_params)Perform a forward step using batched data and return the resulting loss. check_data get_default_callbacks get_params initialize_virtual_params on_batch_end on_grad_computed -
check_is_fitted
(attributes=None, *args, **kwargs)[source]¶ Checks whether the net is initialized
Parameters: - attributes : iterable of str or None (default=None)
All the attributes that are strictly required of a fitted net. By default, this is the module_ attribute.
- Other arguments as in ``sklearn.utils.validation.check_is_fitted``.
Raises: - skorch.exceptions.NotInitializedError
When the given attributes are not present.
-
evaluation_step
(batch, training=False)[source]¶ Perform a forward step to produce the output used for prediction and scoring.
To that end, the module is set to evaluation mode by default; this can be overridden to re-enable features like dropout by setting
training=True
.Parameters: - batch
A single batch returned by the data loader.
- training : bool (default=False)
Whether to set the module to train mode or not.
Returns: - y_infer
The prediction generated by the module.
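For instance, to keep dropout active at inference time (Monte Carlo dropout), you could run evaluation_step yourself with training=True. This is only a sketch, assuming net is a fitted NeuralNet and X is input data the net can handle:

    import torch

    dataset = net.get_dataset(X)
    samples = []
    for batch in net.get_iterator(dataset, training=False):
        # training=True re-enables dropout for this forward pass
        samples.append(net.evaluation_step(batch, training=True))
    y_sampled = torch.cat(samples)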
-
fit
(X, y=None, **fit_params)[source]¶ Initialize and fit the module.
If the module was already initialized, by calling fit, the module will be re-initialized (unless
warm_start
is True).Parameters: - X : input data, compatible with skorch.dataset.Dataset
By default, you should be able to pass:
- numpy arrays
- torch tensors
- pandas DataFrame or Series
- scipy sparse CSR matrices
- a dictionary of the former three
- a list/tuple of the former three
- a Dataset
If this doesn’t work with your data, you have to pass a
Dataset
that can deal with the data.- y : target data, compatible with skorch.dataset.Dataset
The same data types as for
X
are supported. If your X is a Dataset that contains the target,y
may be set to None.- **fit_params : dict
Additional parameters passed to the
forward
method of the module and to theself.train_split
call.
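A minimal usage sketch; MyModule and the toy data below are illustrative assumptions, not part of skorch:

    import numpy as np
    import torch
    from skorch import NeuralNet

    class MyModule(torch.nn.Module):
        # hypothetical regression module for illustration
        def __init__(self, num_units=10):
            super().__init__()
            self.dense = torch.nn.Linear(20, num_units)
            self.output = torch.nn.Linear(num_units, 1)

        def forward(self, X):
            return self.output(torch.relu(self.dense(X)))

    X = np.random.randn(100, 20).astype(np.float32)
    y = np.random.randn(100, 1).astype(np.float32)

    net = NeuralNet(MyModule, criterion=torch.nn.MSELoss, max_epochs=5, lr=0.1)
    net.fit(X, y)  # initializes the module and runs the training loop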
-
fit_loop
(X, y=None, epochs=None, **fit_params)[source]¶ The proper fit loop.
Contains the logic of what actually happens during the fit loop.
Parameters: - X : input data, compatible with skorch.dataset.Dataset
By default, you should be able to pass:
- numpy arrays
- torch tensors
- pandas DataFrame or Series
- scipy sparse CSR matrices
- a dictionary of the former three
- a list/tuple of the former three
- a Dataset
If this doesn’t work with your data, you have to pass a
Dataset
that can deal with the data.- y : target data, compatible with skorch.dataset.Dataset
The same data types as for
X
are supported. If your X is a Dataset that contains the target,y
may be set to None.- epochs : int or None (default=None)
If int, train for this number of epochs; if None, use
self.max_epochs
.- **fit_params : dict
Additional parameters passed to the
forward
method of the module and to theself.train_split
call.
-
forward
(X, training=False, device='cpu')[source]¶ Gather and concatenate the output from forward call with input data.
The outputs from
self.module_.forward
are gathered on the compute device specified bydevice
and then concatenated using PyTorchcat()
. If multiple outputs are returned byself.module_.forward
, each one of them must be able to be concatenated this way.Parameters: - X : input data, compatible with skorch.dataset.Dataset
By default, you should be able to pass:
- numpy arrays
- torch tensors
- pandas DataFrame or Series
- scipy sparse CSR matrices
- a dictionary of the former three
- a list/tuple of the former three
- a Dataset
If this doesn’t work with your data, you have to pass a
Dataset
that can deal with the data.- training : bool (default=False)
Whether to set the module to train mode or not.
- device : string (default=’cpu’)
The device to store each inference result on. This defaults to CPU memory since there is generally more memory available there. For performance reasons this might be changed to a specific CUDA device, e.g. ‘cuda:0’.
Returns: - y_infer : torch tensor
The result from the forward step.
-
forward_iter
(X, training=False, device='cpu')[source]¶ Yield outputs of module forward calls on each batch of data. The storage device of the yielded tensors is determined by the
device
parameter.Parameters: - X : input data, compatible with skorch.dataset.Dataset
By default, you should be able to pass:
- numpy arrays
- torch tensors
- pandas DataFrame or Series
- scipy sparse CSR matrices
- a dictionary of the former three
- a list/tuple of the former three
- a Dataset
If this doesn’t work with your data, you have to pass a
Dataset
that can deal with the data.- training : bool (default=False)
Whether to set the module to train mode or not.
- device : string (default=’cpu’)
The device to store each inference result on. This defaults to CPU memory since there is generally more memory available there. For performance reasons this might be changed to a specific CUDA device, e.g. ‘cuda:0’.
Yields: - yp : torch tensor
Result from a forward call on an individual batch.
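For example, to process module outputs batch by batch instead of materializing everything at once (a sketch, assuming a fitted net and compatible X):

    import torch

    outputs = []
    for yp in net.forward_iter(X, training=False, device='cpu'):
        # yp is the module output for one batch, stored on the CPU
        outputs.append(yp)
    y_all = torch.cat(outputs)  # roughly what forward(X) would return for a single output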
-
get_all_learnable_params
()[source]¶ Yield the learnable parameters of all modules
Typically, this will yield the
named_parameters
of the standard module of the net. However, if you add custom modules or if your criterion has learnable parameters, these are returned as well.If you want your optimizer to only update the parameters of some but not all modules, you should override
initialize_optimizer()
and match the corresponding modules and optimizers there:

    class MyNet(NeuralNet):
        def initialize_optimizer(self, *args, **kwargs):
            # first initialize the normal optimizer
            named_params = self.module_.named_parameters()
            args, kwargs = self.get_params_for_optimizer('optimizer', named_params)
            self.optimizer_ = self.optimizer(*args, **kwargs)

            # next add another optimizer called 'optimizer2_' that is
            # only responsible for training 'module2_'
            named_params = self.module2_.named_parameters()
            args, kwargs = self.get_params_for_optimizer('optimizer2', named_params)
            self.optimizer2_ = torch.optim.SGD(*args, **kwargs)
            return self
Yields: - named_parameters : generator of parameter name and parameter
A generator over all module parameters, yielding both the name of the parameter as well as the parameter itself. Use this, for instance, to pass the named parameters to
get_params_for_optimizer()
.
-
get_dataset
(X, y=None)[source]¶ Get a dataset that contains the input data and is passed to the iterator.
Override this if you want to initialize your dataset differently.
Parameters: - X : input data, compatible with skorch.dataset.Dataset
By default, you should be able to pass:
- numpy arrays
- torch tensors
- pandas DataFrame or Series
- scipy sparse CSR matrices
- a dictionary of the former three
- a list/tuple of the former three
- a Dataset
If this doesn’t work with your data, you have to pass a
Dataset
that can deal with the data.- y : target data, compatible with skorch.dataset.Dataset
The same data types as for
X
are supported. If your X is a Dataset that contains the target,y
may be set to None.
Returns: - dataset
The initialized dataset.
-
get_iterator
(dataset, training=False)[source]¶ Get an iterator that allows looping over the batches of the given data.
If
self.iterator_train__batch_size
and/orself.iterator_test__batch_size
are not set, useself.batch_size
instead.Parameters: - dataset : torch Dataset (default=skorch.dataset.Dataset)
Usually,
self.dataset
, initialized with the corresponding data, is passed toget_iterator
.- training : bool (default=False)
Whether to use
iterator_train
oriterator_test
.
Returns: - iterator
An instantiated iterator that allows looping over the mini-batches.
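For example, training and validation batch sizes can be set separately through the iterator parameters, with unset values falling back to batch_size. This sketch uses the iterator_valid prefix from the class parameters; MyModule is a placeholder module:

    from skorch import NeuralNet

    net = NeuralNet(
        MyModule,                        # placeholder module
        batch_size=128,                  # fallback for both iterators
        iterator_train__batch_size=64,   # used for the training DataLoader
        iterator_valid__batch_size=256,  # used for the validation DataLoader
    )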
-
get_loss
(y_pred, y_true, X=None, training=False)[source]¶ Return the loss for this batch.
Parameters: - y_pred : torch tensor
Predicted target values
- y_true : torch tensor
True target values.
- X : input data, compatible with skorch.dataset.Dataset
By default, you should be able to pass:
- numpy arrays
- torch tensors
- pandas DataFrame or Series
- scipy sparse CSR matrices
- a dictionary of the former three
- a list/tuple of the former three
- a Dataset
If this doesn’t work with your data, you have to pass a
Dataset
that can deal with the data.- training : bool (default=False)
Whether train mode should be used or not.
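A common reason to override this method is to add a penalty term on top of the criterion’s loss. A minimal sketch (the L2 penalty and its weight are assumptions for illustration, not skorch behavior):

    from skorch import NeuralNet

    class RegularizedNet(NeuralNet):
        def get_loss(self, y_pred, y_true, X=None, training=False):
            loss = super().get_loss(y_pred, y_true, X=X, training=training)
            # hypothetical L2 penalty over all module parameters
            l2 = sum((p ** 2).sum() for p in self.module_.parameters())
            return loss + 1e-4 * l2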
-
get_params_for
(prefix)[source]¶ Collect and return init parameters for an attribute.
Attributes could be, for instance, pytorch modules, criteria, or data loaders (for optimizers, use
get_params_for_optimizer()
instead). Use the returned arguments to initialize the given attribute like this:# inside initialize_module method kwargs = self.get_params_for('module') self.module_ = self.module(**kwargs)
Proceed analogously for the criterion etc.
The reason to use this method is so that it’s possible to change the init parameters with
set_params()
, which in turn makes grid search and other similar things work.Note that in general, as a user, you never have to deal with this method because
initialize_module()
etc. are already taking care of this. You only need to deal with this if you overrideinitialize_module()
(or similar methods) because you have some custom code that requires it.Parameters: - prefix : str
The name of the attribute whose arguments should be returned. E.g. for the module, it should be
'module'
.
Returns: - kwargs : dict
Keyword arguments to be used as init parameters.
-
get_params_for_optimizer
(prefix, named_parameters)[source]¶ Collect and return init parameters for an optimizer.
Parse kwargs configuration for the optimizer identified by the given prefix. Supports param group assignment using wildcards:
    optimizer__lr=0.05,
    optimizer__param_groups=[
        ('rnn*.period', {'lr': 0.3, 'momentum': 0}),
        ('rnn0', {'lr': 0.1}),
    ]
Generally, use this method like this:
    # inside initialize_optimizer method
    named_params = self.module_.named_parameters()
    pgroups, kwargs = self.get_params_for_optimizer('optimizer', named_params)
    if 'lr' not in kwargs:
        kwargs['lr'] = self.lr
    self.optimizer_ = self.optimizer(*pgroups, **kwargs)
The reason to use this method is so that it’s possible to change the init parameters with
set_params()
, which in turn makes grid search and other similar things work.Note that in general, as a user, you never have to deal with this method because
initialize_optimizer()
is already taking care of this. You only need to deal with this if you overrideinitialize_optimizer()
because you have some custom code that requires it.Parameters: - prefix : str
The name of the optimizer whose arguments should be returned. Typically, this should just be
'optimizer'
. There can be exceptions, however, e.g. if you want to use more than one optimizer.- named_parameters : iterator
Iterator over the parameters of the module that is intended to be optimized. It’s the return value of
my_module.named_parameters()
.
Returns: - args : tuple
All positional arguments for this optimizer (right now only one, the parameter groups).
- kwargs : dict
All other parameters for this optimizer, e.g. the learning rate.
-
get_split_datasets
(X, y=None, **fit_params)[source]¶ Get internal train and validation datasets.
The validation dataset can be None if
self.train_split
is set to None; then internal validation will be skipped.Override this if you want to change how the net splits incoming data into train and validation part.
Parameters: - X : input data, compatible with skorch.dataset.Dataset
By default, you should be able to pass:
- numpy arrays
- torch tensors
- pandas DataFrame or Series
- scipy sparse CSR matrices
- a dictionary of the former three
- a list/tuple of the former three
- a Dataset
If this doesn’t work with your data, you have to pass a
Dataset
that can deal with the data.- y : target data, compatible with skorch.dataset.Dataset
The same data types as for
X
are supported. If your X is a Dataset that contains the target,y
may be set to None.- **fit_params : dict
Additional parameters passed to the
self.train_split
call.
Returns: - dataset_train
The initialized training dataset.
- dataset_valid
The initialized validation dataset or None
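If you already have a dedicated validation set, you usually do not need to override this method; pass a train_split callable instead. A sketch, assuming skorch.helper.predefined_split is available in your skorch version and X_valid/y_valid are your held-out data (MyModule is a placeholder):

    from skorch import NeuralNet
    from skorch.dataset import Dataset
    from skorch.helper import predefined_split

    valid_ds = Dataset(X_valid, y_valid)          # assumed held-out data
    net = NeuralNet(
        MyModule,                                  # placeholder module
        train_split=predefined_split(valid_ds),    # always validate on valid_ds
    )
    net.fit(X_train, y_train)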
-
get_train_step_accumulator
()[source]¶ Return the train step accumulator.
By default, the accumulator stores and retrieves the first value from the optimizer call. Most optimizers make only one call, so the first value is at the same time the only value.
In case of some optimizers, e.g. LBFGS,
train_step_calc_gradient
is called multiple times, as the loss function is evaluated multiple times per optimizer call. If you don’t want to return the first value in that case, override this method to return your custom accumulator.
-
infer
(x, **fit_params)[source]¶ Perform a single inference step on a batch of data.
Parameters: - x : input data
A batch of the input data.
- **fit_params : dict
Additional parameters passed to the
forward
method of the module and to theself.train_split
call.
-
initialize_callbacks
()[source]¶ Initializes all callbacks and saves the result in the
callbacks_
attribute.Both
default_callbacks
andcallbacks
are used (in that order). Callbacks may either be initialized or not, and if they don’t have a name, the name is inferred from the class name. Theinitialize
method is called on all callbacks.The final result will be a list of tuples, where each tuple consists of a name and an initialized callback. If names are not unique, a ValueError is raised.
-
initialize_criterion
()[source]¶ Initializes the criterion.
If the criterion is already initialized and no parameter was changed, it will be left as is.
-
initialize_module
()[source]¶ Initializes the module.
If the module is already initialized and no parameter was changed, it will be left as is.
-
initialize_optimizer
(triggered_directly=None)[source]¶ Initialize the model optimizer. If
self.optimizer__lr
is not set, useself.lr
instead.Parameters: - triggered_directly
Deprecated, don’t use it anymore.
-
initialized_instance
(instance_or_cls, kwargs)[source]¶ Return an instance initialized with the given parameters
This is a helper method that deals with several possibilities for a component that might need to be initialized:
- It is already an instance that’s good to go
- It is an instance but it needs to be re-initialized
- It’s not an instance and needs to be initialized
For the majority of use cases, this just comes down to initializing the class with its arguments.
Parameters: - instance_or_cls
The instance or class or callable to be initialized, e.g.
self.module
.- kwargs : dict
The keyword arguments to initialize the instance or class. Can be an empty dict.
Returns: - instance
The initialized component.
-
load_params
(f_params=None, f_optimizer=None, f_criterion=None, f_history=None, checkpoint=None, **kwargs)[source]¶ Loads the module’s parameters, history, and optimizer, not the whole object.
To save and load the whole object, use pickle.
f_params
,f_optimizer
, etc. uses PyTorch’sload()
.If you’ve created a custom module, e.g.
net.mymodule_
, you can save that as well by passingf_mymodule
.Parameters: - f_params : file-like object, str, None (default=None)
Path of module parameters. Pass
None
to not load.- f_optimizer : file-like object, str, None (default=None)
Path of optimizer. Pass
None
to not load.- f_criterion : file-like object, str, None (default=None)
Path of criterion. Pass
None
to not load.- f_history : file-like object, str, None (default=None)
Path to history. Pass
None
to not load.- checkpoint :
Checkpoint
, None (default=None) Checkpoint to load params from. If a checkpoint and a
f_*
path is passed in, thef_*
will be loaded. PassNone
to not load.
Examples
    >>> before = NeuralNetClassifier(mymodule)
    >>> before.save_params(f_params='model.pkl',
    ...                    f_optimizer='optimizer.pkl',
    ...                    f_history='history.json')
    >>> after = NeuralNetClassifier(mymodule).initialize()
    >>> after.load_params(f_params='model.pkl',
    ...                   f_optimizer='optimizer.pkl',
    ...                   f_history='history.json')
-
notify
(method_name, **cb_kwargs)[source]¶ Call the callback method specified in
method_name
with parameters specified incb_kwargs
Method names can be one of:
- on_train_begin
- on_train_end
- on_epoch_begin
- on_epoch_end
- on_batch_begin
- on_batch_end
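notify is how the net dispatches these events to its callbacks; a custom callback only needs to implement the methods it cares about. A minimal sketch (MyModule is a placeholder module, and the printed history key assumes the default validation scoring is active):

    from skorch import NeuralNet
    from skorch.callbacks import Callback

    class PrintValidLoss(Callback):
        def on_epoch_end(self, net, dataset_train=None, dataset_valid=None, **kwargs):
            # invoked via net.notify('on_epoch_end', ...) at the end of each epoch
            print('valid loss:', net.history[-1, 'valid_loss'])

    net = NeuralNet(MyModule, callbacks=[PrintValidLoss()])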
-
partial_fit
(X, y=None, classes=None, **fit_params)[source]¶ Fit the module.
If the module is initialized, it is not re-initialized, which means that this method should be used if you want to continue training a model (warm start).
Parameters: - X : input data, compatible with skorch.dataset.Dataset
By default, you should be able to pass:
- numpy arrays
- torch tensors
- pandas DataFrame or Series
- scipy sparse CSR matrices
- a dictionary of the former three
- a list/tuple of the former three
- a Dataset
If this doesn’t work with your data, you have to pass a
Dataset
that can deal with the data.- y : target data, compatible with skorch.dataset.Dataset
The same data types as for
X
are supported. If your X is a Dataset that contains the target,y
may be set to None.- classes : array, sahpe (n_classes,)
Solely for sklearn compatibility, currently unused.
- **fit_params : dict
Additional parameters passed to the
forward
method of the module and to theself.train_split
call.
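For example, to train incrementally across several calls without losing the learned weights (a sketch; net is an already constructed NeuralNet and X, y are compatible data):

    net.partial_fit(X, y)   # first round of training
    net.partial_fit(X, y)   # continues from the current weights for another max_epochs

    # alternatively, set warm_start=True and keep using fit
    net.set_params(warm_start=True)
    net.fit(X, y)           # also continues instead of re-initializing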
-
predict
(X)[source]¶ Where applicable, return class labels for samples in X.
If the module’s forward method returns multiple outputs as a tuple, it is assumed that the first output contains the relevant information and the other values are ignored. If all values are relevant, consider using
forward()
instead.Parameters: - X : input data, compatible with skorch.dataset.Dataset
By default, you should be able to pass:
- numpy arrays
- torch tensors
- pandas DataFrame or Series
- scipy sparse CSR matrices
- a dictionary of the former three
- a list/tuple of the former three
- a Dataset
If this doesn’t work with your data, you have to pass a
Dataset
that can deal with the data.
Returns: - y_pred : numpy ndarray
-
predict_proba
(X)[source]¶ Return the output of the module’s forward method as a numpy array.
If the module’s forward method returns multiple outputs as a tuple, it is assumed that the first output contains the relevant information and the other values are ignored. If all values are relevant, consider using
forward()
instead.Parameters: - X : input data, compatible with skorch.dataset.Dataset
By default, you should be able to pass:
- numpy arrays
- torch tensors
- pandas DataFrame or Series
- scipy sparse CSR matrices
- a dictionary of the former three
- a list/tuple of the former three
- a Dataset
If this doesn’t work with your data, you have to pass a
Dataset
that can deal with the data.
Returns: - y_proba : numpy ndarray
-
run_single_epoch
(iterator, training, prefix, step_fn, **fit_params)[source]¶ Compute a single epoch of train or validation.
Parameters: - iterator : torch DataLoader or None
The initialized
DataLoader
to loop over. If None, skip this step.- training : bool
Whether to set the module to train mode or not.
- prefix : str
Prefix to use when saving to the history.
- step_fn : callable
Function to call for each batch.
- **fit_params : dict
Additional parameters passed to the
step_fn
.
-
save_params
(f_params=None, f_optimizer=None, f_criterion=None, f_history=None, **kwargs)[source]¶ Saves the module’s parameters, history, and optimizer, not the whole object.
To save the whole object, use pickle. This is necessary when you need additional learned attributes on the net, e.g. the
classes_
attribute onskorch.classifier.NeuralNetClassifier
.f_params
,f_optimizer
, etc. use PyTorch’ssave()
.If you’ve created a custom module, e.g.
net.mymodule_
, you can save that as well by passingf_mymodule
.Parameters: - f_params : file-like object, str, None (default=None)
Path of module parameters. Pass
None
to not save- f_optimizer : file-like object, str, None (default=None)
Path of optimizer. Pass
None
to not save- f_criterion : file-like object, str, None (default=None)
Path of criterion. Pass
None
to not save- f_history : file-like object, str, None (default=None)
Path to history. Pass
None
to not save
Examples
    >>> before = NeuralNetClassifier(mymodule)
    >>> before.save_params(f_params='model.pkl',
    ...                    f_optimizer='optimizer.pkl',
    ...                    f_history='history.json')
    >>> after = NeuralNetClassifier(mymodule).initialize()
    >>> after.load_params(f_params='model.pkl',
    ...                   f_optimizer='optimizer.pkl',
    ...                   f_history='history.json')
-
set_params
(**kwargs)[source]¶ Set the parameters of this class.
Valid parameter keys can be listed with
get_params()
.Returns: - self
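Nested components are addressed with the double-underscore notation, which is also what makes grid search over net parameters work. A sketch (module__num_units is a hypothetical argument of your module):

    net.set_params(lr=0.05, max_epochs=20, module__num_units=20)
    # callbacks can be addressed by their name as well
    net.set_params(callbacks__print_log__keys_ignored=['dur'])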
-
train_step
(batch, **fit_params)[source]¶ Prepares a loss function callable and passes it to the optimizer, hence performing one optimization step.
Loss function callable as required by some optimizers (and accepted by all of them): https://pytorch.org/docs/master/optim.html#optimizer-step-closure
The module is set to be in train mode (e.g. dropout is applied).
Parameters: - batch
A single batch returned by the data loader.
- **fit_params : dict
Additional parameters passed to the
forward
method of the module and to the train_split call.
Returns: - step : dict
A dictionary
{'loss': loss, 'y_pred': y_pred}
, where the floatloss
is the result of the loss function andy_pred
the prediction generated by the PyTorch module.
-
train_step_single
(batch, **fit_params)[source]¶ Compute y_pred, loss value, and update net’s gradients.
The module is set to be in train mode (e.g. dropout is applied).
Parameters: - batch
A single batch returned by the data loader.
- **fit_params : dict
Additional parameters passed to the
forward
method of the module and to theself.train_split
call.
Returns: - step : dict
A dictionary
{'loss': loss, 'y_pred': y_pred}
, where the floatloss
is the result of the loss function andy_pred
the prediction generated by the PyTorch module.
-
trim_for_prediction
()[source]¶ Remove all attributes not required for prediction.
Use this method after you finished training your net, with the goal of reducing its size. All attributes only required during training (e.g. the optimizer) are set to None. This can lead to a considerable decrease in memory footprint. It also makes it more likely that the net can be loaded with different library versions.
After calling this function, the net can only be used for prediction (e.g.
net.predict
ornet.predict_proba
) but no longer for training (e.g.net.fit(X, y)
will raise an exception).This operation is irreversible. Once the net has been trimmed for prediction, it is no longer possible to restore the original state. Morevoer, this operation mutates the net. If you need the unmodified net, create a deepcopy before trimming:
    from copy import deepcopy
    net = NeuralNet(...)
    net.fit(X, y)  # training finished
    net_original = deepcopy(net)
    net.trim_for_prediction()
    net.predict(X)
-
validation_step
(batch, **fit_params)[source]¶ Perform a forward step using batched data and return the resulting loss.
The module is set to be in evaluation mode (e.g. dropout is not applied).
Parameters: - batch
A single batch returned by the data loader.
- **fit_params : dict
Additional parameters passed to the
forward
method of the module and to theself.train_split
call.
skorch.probabilistic¶
Integrate GPyTorch for Gaussian Processes
- The criterion always takes likelihood and module as input arguments
- Always optimize the negative objective function
- Need elaboration on how batching works - are distributions disjoint?
-
class
skorch.probabilistic.
ExactGPRegressor
(module, *args, likelihood=<class 'gpytorch.likelihoods.gaussian_likelihood.GaussianLikelihood'>, criterion=<class 'gpytorch.mlls.exact_marginal_log_likelihood.ExactMarginalLogLikelihood'>, batch_size=-1, **kwargs)[source]¶ Exact Gaussian Process regressor
Use this specifically if you want an exact solution to the Gaussian Process. This implies that the module should be an
ExactGP
module and you cannot use batching (i.e. batch size should be -1).Parameters: - Module : gpytorch.models.ExactGP (class or instance)
The module needs to return a
MultivariateNormal
distribution.- likelihood : gpytorch.likelihoods.GaussianLikelihood (class or instance)
The likelihood used for the exact GP regressor. Usually doesn’t need to be changed.
- criterion : gpytorch.mlls.ExactMarginalLogLikelihood
The objective function to learn the posterior of the GP regressor. Usually doesn’t need to be changed.
- optimizer : torch optim (class, default=torch.optim.SGD)
The uninitialized optimizer (update rule) used to optimize the module
- lr : float (default=0.01)
Learning rate passed to the optimizer. You may use
lr
instead of usingoptimizer__lr
, which would result in the same outcome.- max_epochs : int (default=10)
The number of epochs to train for each
fit
call. Note that you may keyboard-interrupt training at any time.- batch_size : int (default=-1)
Mini-batch size. For exact GPs, it must be set to -1, since the exact solution cannot deal with batching. To make use of batching, use
GPRegressor
in conjunction with a variational strategy.- iterator_train : torch DataLoader
The default PyTorch
DataLoader
used for training data.- iterator_valid : torch DataLoader
The default PyTorch
DataLoader
used for validation and test data, i.e. during inference.- dataset : torch Dataset (default=skorch.dataset.Dataset)
The dataset is necessary for the incoming data to work with pytorch’s
DataLoader
. It has to implement the__len__
and__getitem__
methods. The provided dataset should be capable of dealing with a lot of data types out of the box, so only change this if your data is not supported. You should generally pass the uninitializedDataset
class and define additional arguments to X and y by prefixing them withdataset__
. It is also possible to pass an initializedDataset
, in which case no additional arguments may be passed.- train_split : None or callable (default=None)
If None, there is no train/validation split. Else, train_split should be a function or callable that is called with X and y data and should return the tuple
dataset_train, dataset_valid
. The validation data may be None. There is no default train split for GP regressors because random splitting is typically not desired, e.g. because there is a temporal relationship between samples.- callbacks : None, “disable”, or list of Callback instances (default=None)
Which callbacks to enable. There are three possible values:
If
callbacks=None
, only use default callbacks, those returned byget_default_callbacks
.If
callbacks="disable"
, disable all callbacks, i.e. do not run any of the callbacks.If
callbacks
is a list of callbacks, use those callbacks in addition to the default callbacks. Each callback should be an instance ofCallback
.Callback names are inferred from the class name. Name conflicts are resolved by appending a count suffix starting with 1, e.g.
EpochScoring_1
. Alternatively, a tuple(name, callback)
can be passed, wherename
should be unique. Callbacks may or may not be instantiated. The callback name can be used to set parameters on specific callbacks (e.g., for the callback with name'print_log'
, usenet.set_params(callbacks__print_log__keys_ignored=['epoch', 'train_loss'])
).- predict_nonlinearity : callable, None, or ‘auto’ (default=’auto’)
The nonlinearity to be applied to the prediction. When set to ‘auto’, infers the correct nonlinearity based on the criterion (softmax for
CrossEntropyLoss
and sigmoid forBCEWithLogitsLoss
). If it cannot be inferred or if the parameter is None, just use the identity function. Don’t pass a lambda function if you want the net to be pickleable.In case a callable is passed, it should accept the output of the module (the first output if there is more than one), which is a PyTorch tensor, and return the transformed PyTorch tensor.
This can be useful, e.g., when
predict_proba()
should return probabilities but a criterion is used that does not expect probabilities. In that case, the module can return whatever is required by the criterion and thepredict_nonlinearity
transforms this output into probabilities.The nonlinearity is applied only when calling
predict()
orpredict_proba()
but not anywhere else – notably, the loss is unaffected by this nonlinearity.- warm_start : bool (default=False)
Whether each fit call should lead to a re-initialization of the module (cold start) or whether the module should be trained further (warm start).
- verbose : int (default=1)
This parameter controls how much print output is generated by the net and its callbacks. By setting this value to 0, e.g. the summary scores at the end of each epoch are no longer printed. This can be useful when running a hyperparameter search. The summary scores are always logged in the history attribute, regardless of the verbose setting.
- device : str, torch.device, or None (default=’cpu’)
The compute device to be used. If set to ‘cuda’ in order to use GPU acceleration, data in torch tensors will be pushed to cuda tensors before being sent to the module. If set to None, then all compute devices will be left unmodified.
Attributes: - prefixes_ : list of str
Contains the prefixes to special parameters. E.g., since there is the
'module'
prefix, it is possible to set parameters like so:NeuralNet(..., optimizer__momentum=0.95)
.- cuda_dependent_attributes_ : list of str
Contains a list of all attribute prefixes whose values depend on a CUDA device. If a
NeuralNet
trained with a CUDA-enabled device is unpickled on a machine without CUDA or with CUDA disabled, the listed attributes are mapped to CPU. Expand this list if you want to add other cuda-dependent attributes.- initialized_ : bool
Whether the
NeuralNet
was initialized.- module_ : torch module (instance)
The instantiated module.
- criterion_ : torch criterion (instance)
The instantiated criterion.
- callbacks_ : list of tuples
The complete (i.e. default and other), initialized callbacks, in a tuple with unique names.
- _modules : list of str
List of names of all modules that are torch modules. This list is collected dynamically when the net is initialized. Typically, there is no reason for a user to modify this list.
- _criteria : list of str
List of names of all criteria that are torch modules. This list is collected dynamically when the net is initialized. Typically, there is no reason for a user to modify this list.
- _optimizers : list of str
List of names of all optimizers. This list is collected dynamically when the net is initialized. Typically, there is no reason for a user to modify this list.
- likelihood_: torch module (instance)
The instantiated likelihood.
Methods
check_is_fitted
([attributes])Checks whether the GP is initialized. check_training_readiness
()Check that the net is ready to train confidence_region
(X[, sigmas])Returns 2 standard deviations above and below the mean. evaluation_step
(batch[, training])Perform a forward step to produce the output used for prediction and scoring. fit
(X[, y])Initialize and fit the module. fit_loop
(X[, y, epochs])The proper fit loop. forward
(X[, training, device])Gather and concatenate the output from forward call with input data. forward_iter
(X, *args, **kwargs)Yield outputs of module forward calls on each batch of data. get_all_learnable_params
()Yield the learnable parameters of all modules get_dataset
(X[, y])Get a dataset that contains the input data and is passed to the iterator. get_iterator
(dataset[, training])Get an iterator that allows to loop over the batches of the given data. get_loss
(y_pred, y_true[, X, training])Return the loss for this batch. get_params_for
(prefix)Collect and return init parameters for an attribute. get_params_for_optimizer
(prefix, …)Collect and return init parameters for an optimizer. get_split_datasets
(X[, y])Get internal train and validation datasets. get_train_step_accumulator
()Return the train step accumulator. infer
(x, **fit_params)Perform a single inference step on a batch of data. initialize
()Initializes all of its components and returns self. initialize_callbacks
()Initializes all callbacks and save the result in the callbacks_
attribute.initialize_criterion
()Initializes the criterion. initialize_history
()Initializes the history. initialize_module
()Initializes likelihood and module. initialize_optimizer
([triggered_directly])Initialize the model optimizer. initialized_instance
(instance_or_cls, kwargs)Return an instance initialized with the given parameters load_params
([f_params, f_optimizer, …])Loads the module’s parameters, history, and optimizer, not the whole object. notify
(method_name, **cb_kwargs)Call the callback method specified in method_name
with parameters specified incb_kwargs
.on_batch_begin
(net[, batch, training])on_epoch_begin
(net[, dataset_train, …])on_epoch_end
(net[, dataset_train, dataset_valid])on_train_begin
(net[, X, y])on_train_end
(net[, X, y])partial_fit
(X[, y, classes])Fit the module. predict
(X[, return_std, return_cov])Returns the predicted mean and optionally standard deviation. predict_proba
(X)Return the output of the module’s forward method as a numpy array. run_single_epoch
(iterator, training, prefix, …)Compute a single epoch of train or validation. sample
(X, n_samples[, axis])Return samples conditioned on input data. save_params
([f_params, f_optimizer, …])Saves the module’s parameters, history, and optimizer, not the whole object. set_params
(**kwargs)Set the parameters of this class. train_step
(batch, **fit_params)Prepares a loss function callable and pass it to the optimizer, hence performing one optimization step. train_step_single
(batch, **fit_params)Compute y_pred, loss value, and update net’s gradients. trim_for_prediction
()Remove all attributes not required for prediction. validation_step
(batch, **fit_params)Perform a forward step using batched data and return the resulting loss. check_data get_default_callbacks get_params initialize_virtual_params on_batch_end on_grad_computed -
fit
(X, y=None, **fit_params)[source]¶ Initialize and fit the module.
If the module was already initialized, by calling fit, the module will be re-initialized (unless
warm_start
is True).Parameters: - X : input data, compatible with skorch.dataset.Dataset
By default, you should be able to pass:
- numpy arrays
- torch tensors
- pandas DataFrame or Series
- scipy sparse CSR matrices
- a dictionary of the former three
- a list/tuple of the former three
- a Dataset
If this doesn’t work with your data, you have to pass a
Dataset
that can deal with the data.- y : target data, compatible with skorch.dataset.Dataset
The same data types as for
X
are supported. If your X is a Dataset that contains the target,y
may be set to None.- **fit_params : dict
Additional parameters passed to the
forward
method of the module and to theself.train_split
call.
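A minimal end-to-end sketch of an exact GP module wrapped by ExactGPRegressor; the kernel choice and the toy data are assumptions, GPyTorch must be installed, and passing the likelihood to the module follows the pattern used in the skorch Gaussian Process documentation:

    import gpytorch
    import numpy as np
    from skorch.probabilistic import ExactGPRegressor

    class RbfModule(gpytorch.models.ExactGP):
        def __init__(self, likelihood):
            # train inputs/targets are provided later through fit, hence None here
            super().__init__(None, None, likelihood)
            self.mean_module = gpytorch.means.ConstantMean()
            self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())

        def forward(self, x):
            mean_x = self.mean_module(x)
            covar_x = self.covar_module(x)
            return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)

    X = np.random.randn(50, 1).astype(np.float32)
    y = np.sin(X).ravel().astype(np.float32)

    gpr = ExactGPRegressor(RbfModule, lr=0.1, max_epochs=20)
    gpr.fit(X, y)
    y_pred, y_std = gpr.predict(X, return_std=True)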
-
class
skorch.probabilistic.
GPRegressor
(module, *args, likelihood=<class 'gpytorch.likelihoods.gaussian_likelihood.GaussianLikelihood'>, criterion=<class 'gpytorch.mlls.variational_elbo.VariationalELBO'>, **kwargs)[source]¶ Gaussian Process regressor
Use this for variational and approximate Gaussian process regression. This implies that the module should be an
ApproximateGP
module.Parameters: - Module : gpytorch.models.ApproximateGP (class or instance)
The GPyTorch module; in contrast to exact GP, the return distribution does not need to be Gaussian.
- likelihood : gpytorch.likelihoods.GaussianLikelihood (class or instance)
The likelihood used for the GP regressor. Usually doesn’t need to be changed.
- criterion : gpytorch.mlls.VariationalELBO
The objective function to learn the approximate posterior of the GP regressor.
- optimizer : torch optim (class, default=torch.optim.SGD)
The uninitialized optimizer (update rule) used to optimize the module
- lr : float (default=0.01)
Learning rate passed to the optimizer. You may use
lr
instead of usingoptimizer__lr
, which would result in the same outcome.- max_epochs : int (default=10)
The number of epochs to train for each
fit
call. Note that you may keyboard-interrupt training at any time.- batch_size : int (default=128)
Mini-batch size. Use this instead of setting
iterator_train__batch_size
anditerator_test__batch_size
, which would result in the same outcome. Ifbatch_size
is -1, a single batch with all the data will be used during training and validation.- iterator_train : torch DataLoader
The default PyTorch
DataLoader
used for training data.- iterator_valid : torch DataLoader
The default PyTorch
DataLoader
used for validation and test data, i.e. during inference.- dataset : torch Dataset (default=skorch.dataset.Dataset)
The dataset is necessary for the incoming data to work with pytorch’s
DataLoader
. It has to implement the__len__
and__getitem__
methods. The provided dataset should be capable of dealing with a lot of data types out of the box, so only change this if your data is not supported. You should generally pass the uninitializedDataset
class and define additional arguments to X and y by prefixing them withdataset__
. It is also possible to pass an initializedDataset
, in which case no additional arguments may be passed.- train_split : None or callable (default=None)
If None, there is no train/validation split. Else, train_split should be a function or callable that is called with X and y data and should return the tuple
dataset_train, dataset_valid
. The validation data may be None. There is no default train split for GP regressors because random splitting is typically not desired, e.g. because there is a temporal relationship between samples.- callbacks : None, “disable”, or list of Callback instances (default=None)
Which callbacks to enable. There are three possible values:
If
callbacks=None
, only use default callbacks, those returned byget_default_callbacks
.If
callbacks="disable"
, disable all callbacks, i.e. do not run any of the callbacks.If
callbacks
is a list of callbacks, use those callbacks in addition to the default callbacks. Each callback should be an instance ofCallback
.Callback names are inferred from the class name. Name conflicts are resolved by appending a count suffix starting with 1, e.g.
EpochScoring_1
. Alternatively, a tuple(name, callback)
can be passed, wherename
should be unique. Callbacks may or may not be instantiated. The callback name can be used to set parameters on specific callbacks (e.g., for the callback with name'print_log'
, usenet.set_params(callbacks__print_log__keys_ignored=['epoch', 'train_loss'])
).- predict_nonlinearity : callable, None, or ‘auto’ (default=’auto’)
The nonlinearity to be applied to the prediction. When set to ‘auto’, infers the correct nonlinearity based on the criterion (softmax for
CrossEntropyLoss
and sigmoid forBCEWithLogitsLoss
). If it cannot be inferred or if the parameter is None, just use the identity function. Don’t pass a lambda function if you want the net to be pickleable.In case a callable is passed, it should accept the output of the module (the first output if there is more than one), which is a PyTorch tensor, and return the transformed PyTorch tensor.
This can be useful, e.g., when
predict_proba()
should return probabilities but a criterion is used that does not expect probabilities. In that case, the module can return whatever is required by the criterion and thepredict_nonlinearity
transforms this output into probabilities.The nonlinearity is applied only when calling
predict()
orpredict_proba()
but not anywhere else – notably, the loss is unaffected by this nonlinearity.- warm_start : bool (default=False)
Whether each fit call should lead to a re-initialization of the module (cold start) or whether the module should be trained further (warm start).
- verbose : int (default=1)
This parameter controls how much print output is generated by the net and its callbacks. By setting this value to 0, e.g. the summary scores at the end of each epoch are no longer printed. This can be useful when running a hyperparameter search. The summary scores are always logged in the history attribute, regardless of the verbose setting.
- device : str, torch.device, or None (default=’cpu’)
The compute device to be used. If set to ‘cuda’ in order to use GPU acceleration, data in torch tensors will be pushed to cuda tensors before being sent to the module. If set to None, then all compute devices will be left unmodified.
Attributes: - prefixes_ : list of str
Contains the prefixes to special parameters. E.g., since there is the
'module'
prefix, it is possible to set parameters like so:NeuralNet(..., optimizer__momentum=0.95)
.- cuda_dependent_attributes_ : list of str
Contains a list of all attribute prefixes whose values depend on a CUDA device. If a
NeuralNet
trained with a CUDA-enabled device is unpickled on a machine without CUDA or with CUDA disabled, the listed attributes are mapped to CPU. Expand this list if you want to add other cuda-dependent attributes.- initialized_ : bool
Whether the
NeuralNet
was initialized.- module_ : torch module (instance)
The instantiated module.
- criterion_ : torch criterion (instance)
The instantiated criterion.
- callbacks_ : list of tuples
The complete (i.e. default and other), initialized callbacks, in a tuple with unique names.
- _modules : list of str
List of names of all modules that are torch modules. This list is collected dynamically when the net is initialized. Typically, there is no reason for a user to modify this list.
- _criteria : list of str
List of names of all criteria that are torch modules. This list is collected dynamically when the net is initialized. Typically, there is no reason for a user to modify this list.
- _optimizers : list of str
List of names of all optimizers. This list is collected dynamically when the net is initialized. Typically, there is no reason for a user to modify this list.
- likelihood_: torch module (instance)
The instantiated likelihood.
Methods
check_is_fitted
([attributes])Checks whether the GP is initialized. check_training_readiness
()Check that the net is ready to train confidence_region
(X[, sigmas])Returns 2 standard deviations above and below the mean. evaluation_step
(batch[, training])Perform a forward step to produce the output used for prediction and scoring. fit
(X[, y])Initialize and fit the module. fit_loop
(X[, y, epochs])The proper fit loop. forward
(X[, training, device])Gather and concatenate the output from forward call with input data. forward_iter
(X, *args, **kwargs)Yield outputs of module forward calls on each batch of data. get_all_learnable_params
()Yield the learnable parameters of all modules get_dataset
(X[, y])Get a dataset that contains the input data and is passed to the iterator. get_iterator
(dataset[, training])Get an iterator that allows to loop over the batches of the given data. get_loss
(y_pred, y_true[, X, training])Return the loss for this batch. get_params_for
(prefix)Collect and return init parameters for an attribute. get_params_for_optimizer
(prefix, …)Collect and return init parameters for an optimizer. get_split_datasets
(X[, y])Get internal train and validation datasets. get_train_step_accumulator
()Return the train step accumulator. infer
(x, **fit_params)Perform a single inference step on a batch of data. initialize
()Initializes all of its components and returns self. initialize_callbacks
()Initializes all callbacks and save the result in the callbacks_
attribute.initialize_criterion
()Initializes the criterion. initialize_history
()Initializes the history. initialize_module
()Initializes likelihood and module. initialize_optimizer
([triggered_directly])Initialize the model optimizer. initialized_instance
(instance_or_cls, kwargs)Return an instance initialized with the given parameters load_params
([f_params, f_optimizer, …])Loads the module’s parameters, history, and optimizer, not the whole object. notify
(method_name, **cb_kwargs)Call the callback method specified in method_name
with parameters specified incb_kwargs
.on_batch_begin
(net[, batch, training])on_epoch_begin
(net[, dataset_train, …])on_epoch_end
(net[, dataset_train, dataset_valid])on_train_begin
(net[, X, y])on_train_end
(net[, X, y])partial_fit
(X[, y, classes])Fit the module. predict
(X[, return_std, return_cov])Returns the predicted mean and optionally standard deviation. predict_proba
(X)Return the output of the module’s forward method as a numpy array. run_single_epoch
(iterator, training, prefix, …)Compute a single epoch of train or validation. sample
(X, n_samples[, axis])Return samples conditioned on input data. save_params
([f_params, f_optimizer, …])Saves the module’s parameters, history, and optimizer, not the whole object. set_params
(**kwargs)Set the parameters of this class. train_step
(batch, **fit_params)Prepares a loss function callable and pass it to the optimizer, hence performing one optimization step. train_step_single
(batch, **fit_params)Compute y_pred, loss value, and update net’s gradients. trim_for_prediction
()Remove all attributes not required for prediction. validation_step
(batch, **fit_params)Perform a forward step using batched data and return the resulting loss. check_data get_default_callbacks get_params initialize_virtual_params on_batch_end on_grad_computed
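A sketch of a variational module usable with GPRegressor; the inducing points, kernel, and the criterion__num_data value are assumptions for illustration (VariationalELBO needs the number of training samples):

    import gpytorch
    import torch
    from skorch.probabilistic import GPRegressor

    class VariationalModule(gpytorch.models.ApproximateGP):
        def __init__(self, inducing_points):
            variational_distribution = gpytorch.variational.CholeskyVariationalDistribution(
                inducing_points.size(0))
            variational_strategy = gpytorch.variational.VariationalStrategy(
                self, inducing_points, variational_distribution,
                learn_inducing_locations=True)
            super().__init__(variational_strategy)
            self.mean_module = gpytorch.means.ConstantMean()
            self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())

        def forward(self, x):
            mean_x = self.mean_module(x)
            covar_x = self.covar_module(x)
            return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)

    gpr = GPRegressor(
        VariationalModule,
        module__inducing_points=torch.linspace(0, 1, 25).unsqueeze(-1),
        criterion__num_data=1000,   # assumed number of training samples
        batch_size=64,
    )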
-
class
skorch.probabilistic.
GPBinaryClassifier
(module, *args, likelihood=<class 'gpytorch.likelihoods.bernoulli_likelihood.BernoulliLikelihood'>, criterion=<class 'gpytorch.mlls.variational_elbo.VariationalELBO'>, train_split=<skorch.dataset.ValidSplit object>, threshold=0.5, **kwargs)[source]¶ Gaussian Process binary classifier
Use this for variational and approximate Gaussian process binary classification. This implies that the module should be an
ApproximateGP
module.Parameters: - Module : gpytorch.models.ApproximateGP (class or instance)
The GPyTorch module; in contrast to exact GP, the return distribution does not need to be Gaussian.
- likelihood : gpytorch.likelihoods.BernoulliLikelihood (class or instance)
The likelihood used for GP binary classification. Usually doesn’t need to be changed.
- criterion : gpytorch.mlls.VariationalELBO
The objective function to learn the approximate posterior of the GP binary classification.
- optimizer : torch optim (class, default=torch.optim.SGD)
The uninitialized optimizer (update rule) used to optimize the module
- lr : float (default=0.01)
Learning rate passed to the optimizer. You may use
lr
instead of usingoptimizer__lr
, which would result in the same outcome.- max_epochs : int (default=10)
The number of epochs to train for each
fit
call. Note that you may keyboard-interrupt training at any time.- batch_size : int (default=128)
Mini-batch size. Use this instead of setting
iterator_train__batch_size
anditerator_test__batch_size
, which would result in the same outcome. Ifbatch_size
is -1, a single batch with all the data will be used during training and validation.- iterator_train : torch DataLoader
The default PyTorch
DataLoader
used for training data.- iterator_valid : torch DataLoader
The default PyTorch
DataLoader
used for validation and test data, i.e. during inference.- dataset : torch Dataset (default=skorch.dataset.Dataset)
The dataset is necessary for the incoming data to work with pytorch’s
DataLoader
. It has to implement the__len__
and__getitem__
methods. The provided dataset should be capable of dealing with a lot of data types out of the box, so only change this if your data is not supported. You should generally pass the uninitializedDataset
class and define additional arguments to X and y by prefixing them withdataset__
. It is also possible to pass an initializedDataset
, in which case no additional arguments may be passed.- train_split : None or callable (default=skorch.dataset.ValidSplit(5))
If None, there is no train/validation split. Else, train_split should be a function or callable that is called with X and y data and should return the tuple
dataset_train, dataset_valid
. The validation data may be None.- callbacks : None, “disable”, or list of Callback instances (default=None)
Which callbacks to enable. There are three possible values:
If
callbacks=None
, only use default callbacks, those returned byget_default_callbacks
.If
callbacks="disable"
, disable all callbacks, i.e. do not run any of the callbacks.If
callbacks
is a list of callbacks, use those callbacks in addition to the default callbacks. Each callback should be an instance ofCallback
.Callback names are inferred from the class name. Name conflicts are resolved by appending a count suffix starting with 1, e.g.
EpochScoring_1
. Alternatively, a tuple(name, callback)
can be passed, wherename
should be unique. Callbacks may or may not be instantiated. The callback name can be used to set parameters on specific callbacks (e.g., for the callback with name'print_log'
, usenet.set_params(callbacks__print_log__keys_ignored=['epoch', 'train_loss'])
).- predict_nonlinearity : callable, None, or ‘auto’ (default=’auto’)
The nonlinearity to be applied to the prediction. When set to ‘auto’, infers the correct nonlinearity based on the criterion (softmax for
CrossEntropyLoss
and sigmoid forBCEWithLogitsLoss
). If it cannot be inferred or if the parameter is None, just use the identity function. Don’t pass a lambda function if you want the net to be pickleable.In case a callable is passed, it should accept the output of the module (the first output if there is more than one), which is a PyTorch tensor, and return the transformed PyTorch tensor.
This can be useful, e.g., when
predict_proba()
should return probabilities but a criterion is used that does not expect probabilities. In that case, the module can return whatever is required by the criterion and thepredict_nonlinearity
transforms this output into probabilities.The nonlinearity is applied only when calling
predict()
orpredict_proba()
but not anywhere else – notably, the loss is unaffected by this nonlinearity.- warm_start : bool (default=False)
Whether each fit call should lead to a re-initialization of the module (cold start) or whether the module should be trained further (warm start).
- verbose : int (default=1)
This parameter controls how much print output is generated by the net and its callbacks. By setting this value to 0, e.g. the summary scores at the end of each epoch are no longer printed. This can be useful when running a hyperparameter search. The summary scores are always logged in the history attribute, regardless of the verbose setting.
- device : str, torch.device, or None (default=’cpu’)
The compute device to be used. If set to ‘cuda’ in order to use GPU acceleration, data in torch tensors will be pushed to cuda tensors before being sent to the module. If set to None, then all compute devices will be left unmodified.
Attributes: - prefixes_ : list of str
Contains the prefixes to special parameters. E.g., since there is the
'module'
prefix, it is possible to set parameters like so:NeuralNet(..., optimizer__momentum=0.95)
.- cuda_dependent_attributes_ : list of str
Contains a list of all attribute prefixes whose values depend on a CUDA device. If a
NeuralNet
trained with a CUDA-enabled device is unpickled on a machine without CUDA or with CUDA disabled, the listed attributes are mapped to CPU. Expand this list if you want to add other cuda-dependent attributes.- initialized_ : bool
Whether the
NeuralNet
was initialized.- module_ : torch module (instance)
The instantiated module.
- criterion_ : torch criterion (instance)
The instantiated criterion.
- callbacks_ : list of tuples
The complete (i.e. default and other), initialized callbacks, in a tuple with unique names.
- _modules : list of str
List of names of all modules that are torch modules. This list is collected dynamically when the net is initialized. Typically, there is no reason for a user to modify this list.
- _criteria : list of str
List of names of all criteria that are torch modules. This list is collected dynamically when the net is initialized. Typically, there is no reason for a user to modify this list.
- _optimizers : list of str
List of names of all optimizers. This list is collected dynamically when the net is initialized. Typically, there is no reason for a user to modify this list.
- likelihood_: torch module (instance)
The instantiated likelihood.
Methods
check_data
(X, y)check_is_fitted
([attributes])Checks whether the GP is initialized. check_training_readiness
()Check that the net is ready to train confidence_region
(X[, sigmas])Returns 2 standard deviations above and below the mean. evaluation_step
(batch[, training])Perform a forward step to produce the output used for prediction and scoring. fit
(X[, y])Initialize and fit the module. fit_loop
(X[, y, epochs])The proper fit loop. forward
(X[, training, device])Gather and concatenate the output from forward call with input data. forward_iter
(X, *args, **kwargs)Yield outputs of module forward calls on each batch of data. get_all_learnable_params
()Yield the learnable parameters of all modules get_dataset
(X[, y])Get a dataset that contains the input data and is passed to the iterator. get_iterator
(dataset[, training])Get an iterator that allows to loop over the batches of the given data. get_loss
(y_pred, y_true[, X, training])Return the loss for this batch. get_params_for
(prefix)Collect and return init parameters for an attribute. get_params_for_optimizer
(prefix, …)Collect and return init parameters for an optimizer. get_split_datasets
(X[, y])Get internal train and validation datasets. get_train_step_accumulator
()Return the train step accumulator. infer
(x, **fit_params)Perform a single inference step on a batch of data. initialize
()Initializes all of its components and returns self. initialize_callbacks
()Initializes all callbacks and save the result in the callbacks_
attribute.initialize_criterion
()Initializes the criterion. initialize_history
()Initializes the history. initialize_module
()Initializes likelihood and module. initialize_optimizer
([triggered_directly])Initialize the model optimizer. initialized_instance
(instance_or_cls, kwargs)Return an instance initialized with the given parameters load_params
([f_params, f_optimizer, …])Loads the module’s parameters, history, and optimizer, not the whole object. notify
(method_name, **cb_kwargs)Call the callback method specified in method_name
with parameters specified incb_kwargs
.on_batch_begin
(net[, batch, training])on_epoch_begin
(net[, dataset_train, …])on_epoch_end
(net[, dataset_train, dataset_valid])on_train_begin
(net[, X, y])on_train_end
(net[, X, y])partial_fit
(X[, y, classes])Fit the module. predict
(X)Return class labels for samples in X. predict_proba
(X)Return probability estimates for the samples. run_single_epoch
(iterator, training, prefix, …)Compute a single epoch of train or validation. sample
(X, n_samples[, axis])Return samples conditioned on input data. save_params
([f_params, f_optimizer, …])Saves the module’s parameters, history, and optimizer, not the whole object. set_params
(**kwargs)Set the parameters of this class. train_step
(batch, **fit_params)Prepares a loss function callable and passes it to the optimizer, hence performing one optimization step. train_step_single
(batch, **fit_params)Compute y_pred, loss value, and update net’s gradients. trim_for_prediction
()Remove all attributes not required for prediction. validation_step
(batch, **fit_params)Perform a forward step using batched data and return the resulting loss. get_default_callbacks get_params initialize_virtual_params on_batch_end on_grad_computed -
predict
(X)[source]¶ Return class labels for samples in X.
If the module’s forward method returns multiple outputs as a tuple, it is assumed that the first output contains the relevant information and the other values are ignored. If all values are relevant, consider using forward() instead.
Parameters: - X : input data, compatible with skorch.dataset.Dataset
By default, you should be able to pass:
- numpy arrays
- torch tensors
- pandas DataFrame or Series
- scipy sparse CSR matrices
- a dictionary of the former three
- a list/tuple of the former three
- a Dataset
If this doesn’t work with your data, you have to pass a
Dataset
that can deal with the data.
Returns: - y_pred : numpy ndarray
Predicted target values for
X
.
-
predict_proba
(X)[source]¶ Return probability estimates for the samples.
If the module’s forward method returns multiple outputs as a tuple, it is assumed that the first output contains the relevant information and the other values are ignored. If all values are relevant, consider using forward() instead.
Parameters: - X : input data, compatible with skorch.dataset.Dataset
By default, you should be able to pass:
- numpy arrays
- torch tensors
- pandas DataFrame or Series
- scipy sparse CSR matrices
- a dictionary of the former three
- a list/tuple of the former three
- a Dataset
If this doesn’t work with your data, you have to pass a
Dataset
that can deal with the data.
Returns: - y_proba : numpy ndarray
Probabilities for the samples, with the first column corresponding to class 0 and the second to class 1.
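A minimal usage sketch (X_test is placeholder data and net is assumed to be an already fitted instance of this class):
>>> y_pred = net.predict(X_test)         # class labels as a numpy array
>>> y_proba = net.predict_proba(X_test)  # shape (n_samples, 2): P(class 0), P(class 1)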
skorch.regressor¶
NeuralNet subclasses for regression tasks.
-
class
skorch.regressor.
NeuralNetRegressor
(module, *args, criterion=<class 'torch.nn.modules.loss.MSELoss'>, **kwargs)[source]¶ NeuralNet for regression tasks
Use this specifically if you have a standard regression task, with input data X and target y. y must be 2d.
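For illustration, a minimal sketch with made-up data (RegressorModule is a placeholder for your own PyTorch module; note the float32 dtype and the 2d target):
>>> import numpy as np
>>> from skorch import NeuralNetRegressor
>>> X = np.random.randn(100, 20).astype(np.float32)
>>> y = np.random.randn(100, 1).astype(np.float32)  # y must be 2d
>>> net = NeuralNetRegressor(RegressorModule, max_epochs=10, lr=0.01)
>>> net.fit(X, y)
>>> y_pred = net.predict(X)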
In addition to the parameters listed below, there are parameters with specific prefixes that are handled separately. To illustrate this, here is an example:
>>> net = NeuralNet(
...     ...,
...     optimizer=torch.optim.SGD,
...     optimizer__momentum=0.95,
... )
This way, when optimizer is initialized, NeuralNet will take care of setting the momentum parameter to 0.95.
(Note that the double underscore notation in optimizer__momentum means that the parameter momentum should be set on the object optimizer. This is the same semantics as used by sklearn.)
Furthermore, this allows changing those parameters later:
net.set_params(optimizer__momentum=0.99)
This can be useful when you want to change certain parameters using a callback, when using the net in an sklearn grid search, etc.
By default, the EpochTimer, BatchScoring (for both training and validation datasets), and PrintLog callbacks are installed for the user’s convenience.
Parameters: - module : torch module (class or instance)
A PyTorch
Module
. In general, the uninstantiated class should be passed, although instantiated modules will also work.- criterion : torch criterion (class, default=torch.nn.MSELoss)
Mean squared error loss.
- optimizer : torch optim (class, default=torch.optim.SGD)
The uninitialized optimizer (update rule) used to optimize the module
- lr : float (default=0.01)
Learning rate passed to the optimizer. You may use
lr
instead of using optimizer__lr
, which would result in the same outcome.- max_epochs : int (default=10)
The number of epochs to train for each
fit
call. Note that you may keyboard-interrupt training at any time.- batch_size : int (default=128)
Mini-batch size. Use this instead of setting
iterator_train__batch_size
and iterator_valid__batch_size
, which would result in the same outcome. If batch_size
is -1, a single batch with all the data will be used during training and validation.- iterator_train : torch DataLoader
The default PyTorch
DataLoader
used for training data.- iterator_valid : torch DataLoader
The default PyTorch
DataLoader
used for validation and test data, i.e. during inference.- dataset : torch Dataset (default=skorch.dataset.Dataset)
The dataset is necessary for the incoming data to work with pytorch’s
DataLoader
. It has to implement the __len__
and __getitem__
methods. The provided dataset should be capable of dealing with a lot of data types out of the box, so only change this if your data is not supported. You should generally pass the uninitialized Dataset
class and define additional arguments to X and y by prefixing them with dataset__
. It is also possible to pass an initialized Dataset
, in which case no additional arguments may be passed.- train_split : None or callable (default=skorch.dataset.ValidSplit(5))
If None, there is no train/validation split. Else, train_split should be a function or callable that is called with X and y data and should return the tuple
dataset_train, dataset_valid
. The validation data may be None.- callbacks : None, “disable”, or list of Callback instances (default=None)
Which callbacks to enable. There are three possible values:
If
callbacks=None
, only use default callbacks, those returned byget_default_callbacks
.If
callbacks="disable"
, disable all callbacks, i.e. do not run any of the callbacks.If
callbacks
is a list of callbacks, use those callbacks in addition to the default callbacks. Each callback should be an instance ofCallback
.Callback names are inferred from the class name. Name conflicts are resolved by appending a count suffix starting with 1, e.g.
EpochScoring_1
. Alternatively, a tuple(name, callback)
can be passed, wherename
should be unique. Callbacks may or may not be instantiated. The callback name can be used to set parameters on specific callbacks (e.g., for the callback with name'print_log'
, usenet.set_params(callbacks__print_log__keys_ignored=['epoch', 'train_loss'])
).- predict_nonlinearity : callable, None, or ‘auto’ (default=’auto’)
The nonlinearity to be applied to the prediction. When set to ‘auto’, infers the correct nonlinearity based on the criterion (softmax for
CrossEntropyLoss
and sigmoid for BCEWithLogitsLoss
). If it cannot be inferred or if the parameter is None, just use the identity function. Don’t pass a lambda function if you want the net to be pickleable. In case a callable is passed, it should accept the output of the module (the first output if there is more than one), which is a PyTorch tensor, and return the transformed PyTorch tensor.
This can be useful, e.g., when
predict_proba()
should return probabilities but a criterion is used that does not expect probabilities. In that case, the module can return whatever is required by the criterion and the predict_nonlinearity transforms this output into probabilities.
The nonlinearity is applied only when calling predict() or predict_proba(), but not anywhere else – notably, the loss is unaffected by this nonlinearity.
- warm_start : bool (default=False)
Whether each fit call should lead to a re-initialization of the module (cold start) or whether the module should be trained further (warm start).
- verbose : int (default=1)
This parameter controls how much print output is generated by the net and its callbacks. By setting this value to 0, e.g. the summary scores at the end of each epoch are no longer printed. This can be useful when running a hyperparameter search. The summary scores are always logged in the history attribute, regardless of the verbose setting.
- device : str, torch.device, or None (default=’cpu’)
The compute device to be used. If set to ‘cuda’ in order to use GPU acceleration, data in torch tensors will be pushed to cuda tensors before being sent to the module. If set to None, then all compute devices will be left unmodified.
Attributes: - prefixes_ : list of str
Contains the prefixes to special parameters. E.g., since there is the 'optimizer' prefix, it is possible to set parameters like so: NeuralNet(..., optimizer__momentum=0.95).
- cuda_dependent_attributes_ : list of str
Contains a list of all attribute prefixes whose values depend on a CUDA device. If a NeuralNet trained with a CUDA-enabled device is unpickled on a machine without CUDA or with CUDA disabled, the listed attributes are mapped to CPU. Expand this list if you want to add other cuda-dependent attributes.
- initialized_ : bool
Whether the NeuralNet was initialized.
- module_ : torch module (instance)
The instantiated module.
- criterion_ : torch criterion (instance)
The instantiated criterion.
- callbacks_ : list of tuples
The complete (i.e. default and other), initialized callbacks, in a tuple with unique names.
- _modules : list of str
List of names of all modules that are torch modules. This list is collected dynamically when the net is initialized. Typically, there is no reason for a user to modify this list.
- _criteria : list of str
List of names of all criteria that are torch modules. This list is collected dynamically when the net is initialized. Typically, there is no reason for a user to modify this list.
- _optimizers : list of str
List of names of all optimizers. This list is collected dynamically when the net is initialized. Typically, there is no reason for a user to modify this list.
Methods
check_data
(X, y)check_is_fitted
([attributes])Checks whether the net is initialized check_training_readiness
()Check that the net is ready to train evaluation_step
(batch[, training])Perform a forward step to produce the output used for prediction and scoring. fit
(X, y, **fit_params)See NeuralNet.fit
.fit_loop
(X[, y, epochs])The proper fit loop. forward
(X[, training, device])Gather and concatenate the output from forward call with input data. forward_iter
(X[, training, device])Yield outputs of module forward calls on each batch of data. get_all_learnable_params
()Yield the learnable parameters of all modules get_dataset
(X[, y])Get a dataset that contains the input data and is passed to the iterator. get_iterator
(dataset[, training])Get an iterator that allows to loop over the batches of the given data. get_loss
(y_pred, y_true[, X, training])Return the loss for this batch. get_params_for
(prefix)Collect and return init parameters for an attribute. get_params_for_optimizer
(prefix, …)Collect and return init parameters for an optimizer. get_split_datasets
(X[, y])Get internal train and validation datasets. get_train_step_accumulator
()Return the train step accumulator. infer
(x, **fit_params)Perform a single inference step on a batch of data. initialize
()Initializes all of its components and returns self. initialize_callbacks
()Initializes all callbacks and save the result in the callbacks_
attribute.initialize_criterion
()Initializes the criterion. initialize_history
()Initializes the history. initialize_module
()Initializes the module. initialize_optimizer
([triggered_directly])Initialize the model optimizer. initialized_instance
(instance_or_cls, kwargs)Return an instance initialized with the given parameters load_params
([f_params, f_optimizer, …])Loads the module’s parameters, history, and optimizer, not the whole object. notify
(method_name, **cb_kwargs)Call the callback method specified in method_name
with parameters specified incb_kwargs
.on_batch_begin
(net[, batch, training])on_epoch_begin
(net[, dataset_train, …])on_epoch_end
(net[, dataset_train, dataset_valid])on_train_begin
(net[, X, y])on_train_end
(net[, X, y])partial_fit
(X[, y, classes])Fit the module. predict
(X)Where applicable, return class labels for samples in X. predict_proba
(X)Return the output of the module’s forward method as a numpy array. run_single_epoch
(iterator, training, prefix, …)Compute a single epoch of train or validation. save_params
([f_params, f_optimizer, …])Saves the module’s parameters, history, and optimizer, not the whole object. score
(X, y[, sample_weight])Return the coefficient of determination of the prediction. set_params
(**kwargs)Set the parameters of this class. train_step
(batch, **fit_params)Prepares a loss function callable and passes it to the optimizer, hence performing one optimization step. train_step_single
(batch, **fit_params)Compute y_pred, loss value, and update net’s gradients. trim_for_prediction
()Remove all attributes not required for prediction. validation_step
(batch, **fit_params)Perform a forward step using batched data and return the resulting loss. get_default_callbacks get_params initialize_virtual_params on_batch_end on_grad_computed
skorch.scoring¶
Custom scoring functions
-
skorch.scoring.
loss_scoring
(net, X, y=None, sample_weight=None)[source]¶ Calculate score using the criterion of the net
Use the exact same logic as during model training to calculate the score.
This function can be used to implement the score method for a NeuralNet through sub-classing. This is useful, for example, when combining skorch models with sklearn objects that rely on the model’s score method. For example:
>>> class ScoredNet(skorch.NeuralNetClassifier):
...     def score(self, X, y=None):
...         return loss_scoring(self, X, y)
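A possible usage sketch of such a subclass (X and y are placeholder data, MyModule is a hypothetical module class):
>>> net = ScoredNet(MyModule, max_epochs=5)
>>> net.fit(X, y)
>>> score = net.score(X, y)  # loss computed with the net's own criterion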
Parameters: - net : skorch.NeuralNet
A fitted Skorch
NeuralNet
object.- X : input data, compatible with skorch.dataset.Dataset
By default, you should be able to pass:
- numpy arrays
- torch tensors
- pandas DataFrame or Series
- scipy sparse CSR matrices
- a dictionary of the former three
- a list/tuple of the former three
- a Dataset
If this doesn’t work with your data, you have to pass a
Dataset
that can deal with the data.- y : target data, compatible with skorch.dataset.Dataset
The same data types as for
X
are supported. If your X is a Dataset that contains the target,y
may be set to None.- sample_weight : array-like of shape (n_samples,)
Sample weights.
Returns: - loss_value : float32 or np.ndarray
Return type depends on
net.criterion_.reduction
, and will be a float if reduction is'sum'
or'mean'
. If reduction is'none'
then this function returns anp.ndarray
object.
skorch.toy¶
Contains toy functions and classes for quick prototyping and testing.
-
class
skorch.toy.
MLPModule
(input_units=20, output_units=2, hidden_units=10, num_hidden=1, nonlin=ReLU(), output_nonlin=None, dropout=0, squeeze_output=False)[source]¶ A simple multi-layer perceptron module.
This can be adapted for usage in different contexts, e.g. binary and multi-class classification, regression, etc.
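For illustration, a small sketch (sizes chosen arbitrarily):
>>> import torch
>>> from skorch.toy import MLPModule
>>> module = MLPModule(input_units=20, output_units=2, hidden_units=10, num_hidden=2, dropout=0.5)
>>> module(torch.randn(8, 20)).shape  # forward pass on a batch of 8 samples
torch.Size([8, 2])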
Parameters: - input_units : int (default=20)
Number of input units.
- output_units : int (default=2)
Number of output units.
- hidden_units : int (default=10)
Number of units in hidden layers.
- num_hidden : int (default=1)
Number of hidden layers.
- nonlin : torch.nn.Module instance (default=torch.nn.ReLU())
Non-linearity to apply after hidden layers.
- output_nonlin : torch.nn.Module instance or None (default=None)
Non-linearity to apply after last layer, if any.
- dropout : float (default=0)
Dropout rate. Dropout is applied between layers.
- squeeze_output : bool (default=False)
Whether to squeeze output. Squeezing can be helpful if you wish your output to be 1-dimensional (e.g. for NeuralNetBinaryClassifier).
Methods
add_module
(name, module)Adds a child module to the current module. apply
(fn, None])Applies fn
recursively to every submodule (as returned by.children()
) as well as self.bfloat16
()Casts all floating point parameters and buffers to bfloat16
datatype.buffers
(recurse)Returns an iterator over module buffers. children
()Returns an iterator over immediate children modules. cpu
()Moves all model parameters and buffers to the CPU. cuda
(device, torch.device, None] = None)Moves all model parameters and buffers to the GPU. double
()Casts all floating point parameters and buffers to double
datatype.eval
()Sets the module in evaluation mode. extra_repr
()Set the extra representation of the module float
()Casts all floating point parameters and buffers to float
datatype.forward
(X)Defines the computation performed at every call. get_buffer
(target)Returns the buffer given by target
if it exists, otherwise throws an error.get_extra_state
()Returns any extra state to include in the module’s state_dict. get_parameter
(target)Returns the parameter given by target
if it exists, otherwise throws an error.get_submodule
(target)Returns the submodule given by target
if it exists, otherwise throws an error.half
()Casts all floating point parameters and buffers to half
datatype.ipu
(device, torch.device, None] = None)Moves all model parameters and buffers to the IPU. load_state_dict
(state_dict, Any], strict)Copies parameters and buffers from state_dict
into this module and its descendants.modules
()Returns an iterator over all modules in the network. named_buffers
(prefix, recurse)Returns an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself. named_children
()Returns an iterator over immediate children modules, yielding both the name of the module as well as the module itself. named_modules
(memo, prefix, remove_duplicate)Returns an iterator over all modules in the network, yielding both the name of the module as well as the module itself. named_parameters
(prefix, recurse)Returns an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself. parameters
(recurse)Returns an iterator over module parameters. register_backward_hook
(hook, …)Registers a backward hook on the module. register_buffer
(name, tensor, persistent)Adds a buffer to the module. register_forward_hook
(hook, None])Registers a forward hook on the module. register_forward_pre_hook
(hook, None])Registers a forward pre-hook on the module. register_full_backward_hook
(hook, …)Registers a backward hook on the module. register_load_state_dict_post_hook
(hook)Registers a post hook to be run after module’s load_state_dict
is called.register_module
(name, module)Alias for add_module()
.register_parameter
(name, param)Adds a parameter to the module. requires_grad_
(requires_grad)Change if autograd should record operations on parameters in this module. reset_params
()(Re)set all parameters. set_extra_state
(state)This function is called from load_state_dict()
to handle any extra state found within the state_dict.share_memory
()See torch.Tensor.share_memory_()
state_dict
(*args[, destination, prefix, …])Returns a dictionary containing a whole state of the module. to
(*args, **kwargs)Moves and/or casts the parameters and buffers. to_empty
(*, device, torch.device])Moves the parameters and buffers to the specified device without copying storage. train
(mode)Sets the module in training mode. type
(dst_type, str])Casts all parameters and buffers to dst_type
.xpu
(device, torch.device, None] = None)Moves all model parameters and buffers to the XPU. zero_grad
(set_to_none)Sets gradients of all model parameters to zero. __call__ -
forward
(X)[source]¶ Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the
Module
instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
-
skorch.toy.
make_binary_classifier
(squeeze_output=True, **kwargs)[source]¶ Return a multi-layer perceptron to be used with NeuralNetBinaryClassifier.
Parameters: - input_units : int (default=20)
Number of input units.
- output_units : int (default=2)
Number of output units.
- hidden_units : int (default=10)
Number of units in hidden layers.
- num_hidden : int (default=1)
Number of hidden layers.
- nonlin : torch.nn.Module instance (default=torch.nn.ReLU())
Non-linearity to apply after hidden layers.
- dropout : float (default=0)
Dropout rate. Dropout is applied between layers.
-
skorch.toy.
make_classifier
(output_nonlin=Softmax(dim=-1), **kwargs)[source]¶ Return a multi-layer perceptron to be used with NeuralNetClassifier.
Parameters: - input_units : int (default=20)
Number of input units.
- output_units : int (default=2)
Number of output units.
- hidden_units : int (default=10)
Number of units in hidden layers.
- num_hidden : int (default=1)
Number of hidden layers.
- nonlin : torch.nn.Module instance (default=torch.nn.ReLU())
Non-linearity to apply after hidden layers.
- dropout : float (default=0)
Dropout rate. Dropout is applied between layers.
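A hedged usage sketch, assuming the returned object can be passed directly as the module argument:
>>> from skorch import NeuralNetClassifier
>>> from skorch.toy import make_classifier
>>> module_cls = make_classifier(input_units=20, output_units=3, num_hidden=2)
>>> net = NeuralNetClassifier(module_cls, max_epochs=10)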
-
skorch.toy.
make_regressor
(output_units=1, **kwargs)[source]¶ Return a multi-layer perceptron to be used with NeuralNetRegressor.
Parameters: - input_units : int (default=20)
Number of input units.
- output_units : int (default=1)
Number of output units.
- hidden_units : int (default=10)
Number of units in hidden layers.
- num_hidden : int (default=1)
Number of hidden layers.
- nonlin : torch.nn.Module instance (default=torch.nn.ReLU())
Non-linearity to apply after hidden layers.
- dropout : float (default=0)
Dropout rate. Dropout is applied between layers.
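Analogously to make_classifier above, a hedged sketch:
>>> from skorch import NeuralNetRegressor
>>> from skorch.toy import make_regressor
>>> module_cls = make_regressor(input_units=20, num_hidden=2)
>>> net = NeuralNetRegressor(module_cls, max_epochs=10)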
skorch.utils¶
skorch utilities.
Should not have any dependency on other skorch packages.
-
class
skorch.utils.
FirstStepAccumulator
[source]¶ Store and retrieve the train step data.
This class simply stores the first step value and returns it.
For most uses,
skorch.utils.FirstStepAccumulator
is what you want, since the optimizer calls the train step exactly once. However, some optimizers such as LBFGS make more than one call. In that case, if you don’t want the first value to be returned (but instead, say, the last value), implement your own accumulator and make sure it is returned by the NeuralNet.get_train_step_accumulator
method.Methods
get_step
()Return the stored step. store_step
(step)Store the first step.
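A hedged sketch of such a custom accumulator (the class name is made up; only the two methods shown above are assumed to be required):
>>> from skorch import NeuralNet
>>> class LastStepAccumulator:
...     """Store every step and return the most recent one."""
...     def __init__(self):
...         self.step = None
...     def store_step(self, step):
...         self.step = step
...     def get_step(self):
...         return self.step
>>> class MyNet(NeuralNet):
...     def get_train_step_accumulator(self):
...         return LastStepAccumulator()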
-
class
skorch.utils.
TeeGenerator
(gen)[source]¶ Stores a generator and calls
tee
on it to create new generators when TeeGenerator
is iterated over to let you iterate over the given generator more than once.
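A small sketch of the behavior described above (toy generator):
>>> from skorch.utils import TeeGenerator
>>> gen = TeeGenerator(x ** 2 for x in range(3))
>>> [list(gen) for _ in range(2)]  # the same values can be consumed more than once
[[0, 1, 4], [0, 1, 4]]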
-
skorch.utils.
check_indexing
(data)[source]¶ Check how incoming data should be indexed and return an appropriate indexing function with signature f(data, index).
This is useful for determining upfront how data should be indexed instead of doing it repeatedly for each batch, thus saving some time.
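For example, a hedged sketch with a toy array:
>>> import numpy as np
>>> from skorch.utils import check_indexing
>>> X = np.arange(12).reshape(4, 3)
>>> indexing_fn = check_indexing(X)  # determine the indexing strategy once
>>> indexing_fn(X, 2)                # then reuse it for each index
array([6, 7, 8])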
-
skorch.utils.
check_is_fitted
(estimator, attributes=None, msg=None, all_or_any=<built-in function all>)[source]¶ Checks whether the net is initialized.
Note: This calls
sklearn.utils.validation.check_is_fitted
under the hood, using exactly the same arguments and logic. The only difference is that this function has an adapted error message and raises askorch.exception.NotInitializedError
instead of ansklearn.exceptions.NotFittedError
.
-
skorch.utils.
data_from_dataset
(dataset, X_indexing=None, y_indexing=None)[source]¶ Try to access X and y attribute from dataset.
Also works when dataset is a subset.
Parameters: - dataset : skorch.dataset.Dataset or torch.utils.data.Subset
The incoming dataset should be a
skorch.dataset.Dataset
or atorch.utils.data.Subset
of askorch.dataset.Dataset
.- X_indexing : function/callable or None (default=None)
If not None, use this function for indexing into the X data. If None, try to automatically determine how to index data.
- y_indexing : function/callable or None (default=None)
If not None, use this function for indexing into the y data. If None, try to automatically determine how to index data.
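A minimal sketch with toy arrays wrapped in a skorch Dataset:
>>> import numpy as np
>>> from skorch.dataset import Dataset
>>> from skorch.utils import data_from_dataset
>>> ds = Dataset(np.zeros((5, 3)), np.ones(5))
>>> X, y = data_from_dataset(ds)  # recovers the original X and y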
-
skorch.utils.
duplicate_items
(*collections)[source]¶ Search for duplicate items in all collections.
Examples
>>> duplicate_items([1, 2], [3])
set()
>>> duplicate_items({1: 'a', 2: 'a'})
set()
>>> duplicate_items(['a', 'b', 'a'])
{'a'}
>>> duplicate_items([1, 2], {3: 'hi', 4: 'ha'}, (2, 3))
{2, 3}
-
skorch.utils.
freeze_parameter
(param)[source]¶ Convenience function to freeze a passed torch parameter. Used by
skorch.callbacks.Freezer
-
skorch.utils.
get_dim
(y)[source]¶ Return the number of dimensions of a torch tensor or numpy array-like object.
-
skorch.utils.
get_map_location
(target_device, fallback_device='cpu')[source]¶ Determine the location to map loaded data (e.g., weights) for a given target device (e.g. ‘cuda’).
-
skorch.utils.
is_skorch_dataset
(ds)[source]¶ Checks if the supplied dataset is an instance of
skorch.dataset.Dataset
even when it is nested inside torch.utils.data.Subset
.
-
skorch.utils.
multi_indexing
(data, i, indexing=None)[source]¶ Perform indexing on multiple data structures.
Currently supported data types:
- numpy arrays
- torch tensors
- pandas NDFrame
- a dictionary of the former three
- a list/tuple of the former three
i
can be an integer or a slice.Parameters: - data
Data of a type mentioned above.
- i : int or slice
Slicing index.
- indexing : function/callable or None (default=None)
If not None, use this function for indexing into the data. If None, try to automatically determine how to index data.
Examples
>>> multi_indexing(np.asarray([1, 2, 3]), 0)
1
>>> multi_indexing(np.asarray([1, 2, 3]), np.s_[:2])
array([1, 2])
>>> multi_indexing(torch.arange(0, 4), np.s_[1:3])
tensor([ 1.,  2.])
>>> multi_indexing([[1, 2, 3], [4, 5, 6]], np.s_[:2])
[[1, 2], [4, 5]]
>>> multi_indexing({'a': [1, 2, 3], 'b': [4, 5, 6]}, np.s_[-2:])
{'a': [2, 3], 'b': [5, 6]}
>>> multi_indexing(pd.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6]}), [1, 2])
   a  b
1  2  5
2  3  6
-
skorch.utils.
noop
(*args, **kwargs)[source]¶ No-op function that does nothing and returns
None
.This is useful for defining scoring callbacks that do not need a target extractor.
-
skorch.utils.
params_for
(prefix, kwargs)[source]¶ Extract parameters that belong to a given sklearn module prefix from
kwargs
. This is useful to obtain parameters that belong to a submodule.Examples
>>> kwargs = {'encoder__a': 3, 'encoder__b': 4, 'decoder__a': 5}
>>> params_for('encoder', kwargs)
{'a': 3, 'b': 4}
-
skorch.utils.
to_device
(X, device)[source]¶ Generic function to modify the device type of the tensor(s) or module.
PyTorch distribution objects are left untouched, since they don’t support an API to move between devices.
Parameters: - X : input data
Deals with X being a:
- torch tensor
- tuple of torch tensors
- dict of torch tensors
- PackedSequence instance
- torch.nn.Module
- device : str, torch.device
The compute device to be used. If device=None, return the input unmodified
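For example (a CPU-only sketch, so it runs without a GPU):
>>> import torch
>>> from skorch.utils import to_device
>>> batch = {'x': torch.zeros(2, 3), 'mask': torch.ones(2, 3)}
>>> batch = to_device(batch, 'cpu')  # dicts of tensors are handled
>>> to_device(torch.zeros(2), None)  # device=None returns the input unmodified
tensor([0., 0.])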
-
skorch.utils.
to_numpy
(X)[source]¶ Generic function to convert a pytorch tensor to numpy.
This function tries to unpack the tensor(s) from supported data structures (e.g., dicts, lists, etc.) but doesn’t go beyond.
Returns X when it already is a numpy array.
-
skorch.utils.
to_tensor
(X, device, accept_sparse=False)[source]¶ Turn input data to torch tensor.
Parameters: - X : input data
- Handles the cases:
- PackedSequence
- numpy array
- torch Tensor
- scipy sparse CSR matrix
- list or tuple of one of the former
- dict with values of one of the former
- device : str, torch.device
The compute device to be used. If set to ‘cuda’, data in torch tensors will be pushed to cuda tensors before being sent to the module.
- accept_sparse : bool (default=False)
Whether to accept scipy sparse matrices as input. If False, passing a sparse matrix raises an error. If True, it is converted to a torch COO tensor.
Returns: - output : torch Tensor
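A short sketch combining to_tensor() with to_numpy() from above (plain numpy input, CPU only):
>>> import numpy as np
>>> from skorch.utils import to_tensor, to_numpy
>>> t = to_tensor(np.array([[1.0, 2.0]]), device='cpu')  # numpy array -> torch tensor
>>> a = to_numpy(t)                                       # torch tensor -> numpy array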