skorch.net

Neural net classes.

class skorch.net.NeuralNet(module, criterion, optimizer=<class 'torch.optim.sgd.SGD'>, lr=0.01, max_epochs=10, batch_size=128, iterator_train=<class 'torch.utils.data.dataloader.DataLoader'>, iterator_valid=<class 'torch.utils.data.dataloader.DataLoader'>, dataset=<class 'skorch.dataset.Dataset'>, train_split=<skorch.dataset.CVSplit object>, callbacks=None, warm_start=False, verbose=1, device='cpu', **kwargs)[source]

NeuralNet base class.

The base class covers more generic cases. Depending on your use case, you might want to use NeuralNetClassifier or NeuralNetRegressor.

In addition to the parameters listed below, there are parameters with specific prefixes that are handled separately. To illustrate this, here is an example:

>>> net = NeuralNet(
...    ...,
...    optimizer=torch.optim.SGD,
...    optimizer__momentum=0.95,
...)

This way, when optimizer is initialized, NeuralNet will take care of setting the momentum parameter to 0.95.

(Note that the double underscore notation in optimizer__momentum means that the parameter momentum should be set on the object optimizer. This is the same semantic as used by sklearn.)

Furthermore, this allows you to change those parameters later:

net.set_params(optimizer__momentum=0.99)

This can be useful when you want to change certain parameters using a callback, when using the net in an sklearn grid search, etc.
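
For example, here is a hedged sketch of using the same double-underscore keys in an sklearn grid search (MyModule, X, and y are placeholders for your own module and data; an explicit scorer is passed because the base NeuralNet does not provide a score method of its own):

>>> import torch.nn as nn
>>> from sklearn.model_selection import GridSearchCV
>>> from skorch import NeuralNet
>>> net = NeuralNet(MyModule, criterion=nn.MSELoss)
>>> params = {
...     'lr': [0.01, 0.05],
...     'optimizer__momentum': [0.9, 0.95],
... }
>>> gs = GridSearchCV(net, params, cv=3, scoring='neg_mean_squared_error')
>>> gs.fit(X, y)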

By default an EpochTimer, BatchScoring (for both training and validation datasets), and PrintLog callbacks are installed for the user’s convenience.

Parameters:
module : torch module (class or instance)

A PyTorch Module. In general, the uninstantiated class should be passed, although instantiated modules will also work.

criterion : torch criterion (class)

The uninitialized criterion (loss) used to optimize the module.

optimizer : torch optim (class, default=torch.optim.SGD)

The uninitialized optimizer (update rule) used to optimize the module.

lr : float (default=0.01)

Learning rate passed to the optimizer. You may use lr instead of optimizer__lr, which would result in the same outcome.

max_epochs : int (default=10)

The number of epochs to train for each fit call. Note that you may keyboard-interrupt training at any time.

batch_size : int (default=128)

Mini-batch size. Use this instead of setting iterator_train__batch_size and iterator_valid__batch_size, which would result in the same outcome. If batch_size is -1, a single batch with all the data will be used during training and validation.

iterator_train : torch DataLoader

The default PyTorch DataLoader used for training data.

iterator_valid : torch DataLoader

The default PyTorch DataLoader used for validation and test data, i.e. during inference.
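
As a hedged sketch, arguments of the underlying DataLoader can be set through these prefixes (MyModule is a placeholder for your own module):

>>> import torch.nn as nn
>>> from skorch import NeuralNet
>>> net = NeuralNet(
...     MyModule,                       # placeholder module
...     criterion=nn.MSELoss,
...     iterator_train__shuffle=True,   # shuffle only the training batches
...     iterator_train__num_workers=2,  # standard DataLoader argument
... )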

dataset : torch Dataset (default=skorch.dataset.Dataset)

The dataset is necessary for the incoming data to work with pytorch’s DataLoader. It has to implement the __len__ and __getitem__ methods. The provided dataset should be capable of dealing with a lot of data types out of the box, so only change this if your data is not supported. You should generally pass the uninitialized Dataset class and pass additional arguments (besides X and y) to its constructor by prefixing them with dataset__. It is also possible to pass an initialized Dataset, in which case no additional arguments may be passed.
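
For illustration, a hedged sketch with a hypothetical custom Dataset whose __init__ accepts an extra transform argument besides X and y:

>>> import torch.nn as nn
>>> from skorch import NeuralNet
>>> # MyDataset and my_transform are hypothetical; skorch passes the
>>> # dataset__ arguments on to the Dataset constructor.
>>> net = NeuralNet(
...     MyModule,
...     criterion=nn.MSELoss,
...     dataset=MyDataset,
...     dataset__transform=my_transform,
... )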

train_split : None or callable (default=skorch.dataset.CVSplit(5))

If None, there is no train/validation split. Else, train_split should be a function or callable that is called with X and y data and should return the tuple dataset_train, dataset_valid. The validation data may be None.
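
As a hedged sketch, the default split can be replaced by a simple hold-out split or disabled altogether (MyModule is a placeholder):

>>> import torch.nn as nn
>>> from skorch import NeuralNet
>>> from skorch.dataset import CVSplit
>>> # hold out 20% of the data for validation instead of the default 5-fold split
>>> net = NeuralNet(MyModule, criterion=nn.MSELoss, train_split=CVSplit(0.2))
>>> # or skip internal validation entirely
>>> net = NeuralNet(MyModule, criterion=nn.MSELoss, train_split=None)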

callbacks : None or list of Callback instances (default=None)

More callbacks, in addition to those returned by get_default_callbacks. Each callback should inherit from Callback. If not None, a list of callbacks is expected where the callback names are inferred from the class name. Name conflicts are resolved by appending a count suffix starting with 1, e.g. EpochScoring_1. Alternatively, a tuple (name, callback) can be passed, where name should be unique. Callbacks may or may not be instantiated. The callback name can be used to set parameters on specific callbacks (e.g., for the callback with name 'print_log', use net.set_params(callbacks__print_log__keys_ignored=['epoch', 'train_loss'])).
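
A hedged sketch of passing a named callback and later adjusting one of its parameters through that name (MyModule is a placeholder):

>>> import torch.nn as nn
>>> from skorch import NeuralNet
>>> from skorch.callbacks import EpochScoring
>>> net = NeuralNet(
...     MyModule,
...     criterion=nn.MSELoss,
...     callbacks=[
...         ('valid_r2', EpochScoring('r2', lower_is_better=False)),
...     ],
... )
>>> # parameters of a specific callback are addressed via its name
>>> net.set_params(callbacks__valid_r2__lower_is_better=False)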

warm_start : bool (default=False)

Whether each fit call should lead to a re-initialization of the module (cold start) or whether the module should be trained further (warm start).

verbose : int (default=1)

Control the verbosity level.

device : str, torch.device (default=’cpu’)

The compute device to be used. If set to ‘cuda’, data in torch tensors will be pushed to cuda tensors before being sent to the module.

Attributes:
prefixes_ : list of str

Contains the prefixes to special parameters. E.g., since there is the 'optimizer' prefix, it is possible to set parameters like so: NeuralNet(..., optimizer__momentum=0.95).

cuda_dependent_attributes_ : list of str

Contains a list of all attribute prefixes whose values depend on a CUDA device. If a NeuralNet trained with a CUDA-enabled device is unpickled on a machine without CUDA or with CUDA disabled, the listed attributes are mapped to CPU. Expand this list if you want to add other cuda-dependent attributes.

initialized_ : bool

Whether the NeuralNet was initialized.

module_ : torch module (instance)

The instantiated module.

criterion_ : torch criterion (instance)

The instantiated criterion.

callbacks_ : list of tuples

The complete list of initialized callbacks (i.e. default and user-supplied), each stored as a tuple of a unique name and the callback instance.

Methods

check_is_fitted([attributes]) Checks whether the net is initialized
evaluation_step(Xi[, training]) Perform a forward step to produce the output used for prediction and scoring.
fit(X[, y]) Initialize and fit the module.
fit_loop(X[, y, epochs]) The proper fit loop.
forward(X[, training, device]) Gather and concatenate the output from forward call with input data.
forward_iter(X[, training, device]) Yield outputs of module forward calls on each batch of data.
get_dataset(X[, y]) Get a dataset that contains the input data and is passed to the iterator.
get_iterator(dataset[, training]) Get an iterator that allows looping over the batches of the given data.
get_loss(y_pred, y_true[, X, training]) Return the loss for this batch.
get_split_datasets(X[, y]) Get internal train and validation datasets.
get_train_step_accumulator() Return the train step accumulator.
infer(x, **fit_params) Perform a single inference step on a batch of data.
initialize() Initializes all components of the NeuralNet and returns self.
initialize_callbacks() Initializes all callbacks and saves the result in the callbacks_ attribute.
initialize_criterion() Initializes the criterion.
initialize_history() Initializes the history.
initialize_module() Initializes the module.
initialize_optimizer([triggered_directly]) Initialize the model optimizer.
load_params([f_params, f_optimizer, …]) Loads the module’s parameters, history, and optimizer, not the whole object.
notify(method_name, **cb_kwargs) Call the callback method specified in method_name with parameters specified in cb_kwargs.
on_batch_begin(net[, Xi, yi, training])
on_epoch_begin(net[, dataset_train, …])
on_epoch_end(net[, dataset_train, dataset_valid])
on_train_begin(net[, X, y])
on_train_end(net[, X, y])
partial_fit(X[, y, classes]) Fit the module.
predict(X) Where applicable, return class labels for samples in X.
predict_proba(X) Return the output of the module’s forward method as a numpy array.
save_params([f_params, f_optimizer, f_history]) Saves the module’s parameters, history, and optimizer, not the whole object.
set_params(**kwargs) Set the parameters of this class.
train_step(Xi, yi, **fit_params) Prepares a loss function callable and passes it to the optimizer, hence performing one optimization step.
train_step_single(Xi, yi, **fit_params) Compute y_pred, loss value, and update net’s gradients.
validation_step(Xi, yi, **fit_params) Perform a forward step using batched data and return the resulting loss.
check_data  
get_default_callbacks  
get_params  
initialize_virtual_params  
on_batch_end  
on_grad_computed  
check_is_fitted(attributes=None, *args, **kwargs)[source]

Checks whether the net is initialized

Parameters:
attributes : iterable of str or None (default=None)

All the attributes that are strictly required of a fitted net. By default, this is the module_ attribute.

Other arguments are as in ``sklearn.utils.validation.check_is_fitted``.
Raises:
skorch.exceptions.NotInitializedError

When the given attributes are not present.

evaluation_step(Xi, training=False)[source]

Perform a forward step to produce the output used for prediction and scoring.

The module is therefore set to evaluation mode beforehand by default; this can be overridden by setting training=True to re-enable features like dropout.
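
A hedged sketch of using this to keep dropout active, e.g. for Monte-Carlo-style repeated predictions (net is assumed to be a fitted NeuralNet and Xi a single batch of input data):

>>> samples = [net.evaluation_step(Xi, training=True) for _ in range(10)]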

fit(X, y=None, **fit_params)[source]

Initialize and fit the module.

If the module was already initialized, calling fit will re-initialize it (unless warm_start is True).

Parameters:
X : input data, compatible with skorch.dataset.Dataset

By default, you should be able to pass:

  • numpy arrays
  • torch tensors
  • pandas DataFrame or Series
  • scipy sparse CSR matrices
  • a dictionary of the former three
  • a list/tuple of the former three
  • a Dataset

If this doesn’t work with your data, you have to pass a Dataset that can deal with the data.

y : target data, compatible with skorch.dataset.Dataset

The same data types as for X are supported. If your X is a Dataset that contains the target, y may be set to None.

**fit_params : dict

Additional parameters passed to the forward method of the module and to the self.train_split call.
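
A minimal, hedged example of an end-to-end fit call (MyModule is a placeholder for your own torch module):

>>> import numpy as np
>>> import torch.nn as nn
>>> from skorch import NeuralNet
>>> X = np.random.randn(100, 20).astype('float32')
>>> y = np.random.randn(100, 1).astype('float32')
>>> net = NeuralNet(MyModule, criterion=nn.MSELoss, max_epochs=5, lr=0.05)
>>> net = net.fit(X, y)  # initializes the module and trains for 5 epochs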

fit_loop(X, y=None, epochs=None, **fit_params)[source]

The proper fit loop.

Contains the logic of what actually happens during the fit loop.

Parameters:
X : input data, compatible with skorch.dataset.Dataset

By default, you should be able to pass:

  • numpy arrays
  • torch tensors
  • pandas DataFrame or Series
  • scipy sparse CSR matrices
  • a dictionary of the former three
  • a list/tuple of the former three
  • a Dataset

If this doesn’t work with your data, you have to pass a Dataset that can deal with the data.

y : target data, compatible with skorch.dataset.Dataset

The same data types as for X are supported. If your X is a Dataset that contains the target, y may be set to None.

epochs : int or None (default=None)

If int, train for this number of epochs; if None, use self.max_epochs.

**fit_params : dict

Additional parameters passed to the forward method of the module and to the self.train_split call.

forward(X, training=False, device='cpu')[source]

Gather and concatenate the output from forward call with input data.

The outputs from self.module_.forward are gathered on the compute device specified by device and then concatenated using PyTorch’s cat(). If self.module_.forward returns multiple outputs, each of them must support being concatenated this way.

Parameters:
X : input data, compatible with skorch.dataset.Dataset

By default, you should be able to pass:

  • numpy arrays
  • torch tensors
  • pandas DataFrame or Series
  • scipy sparse CSR matrices
  • a dictionary of the former three
  • a list/tuple of the former three
  • a Dataset

If this doesn’t work with your data, you have to pass a Dataset that can deal with the data.

training : bool (default=False)

Whether to set the module to train mode or not.

device : string (default=’cpu’)

The device to store each inference result on. This defaults to CPU memory since there is generally more memory available there. For performance reasons this might be changed to a specific CUDA device, e.g. ‘cuda:0’.

Returns:
y_infer : torch tensor

The result from the forward step.

forward_iter(X, training=False, device='cpu')[source]

Yield outputs of module forward calls on each batch of data. The storage device of the yielded tensors is determined by the device parameter.

Parameters:
X : input data, compatible with skorch.dataset.Dataset

By default, you should be able to pass:

  • numpy arrays
  • torch tensors
  • pandas DataFrame or Series
  • scipy sparse CSR matrices
  • a dictionary of the former three
  • a list/tuple of the former three
  • a Dataset

If this doesn’t work with your data, you have to pass a Dataset that can deal with the data.

training : bool (default=False)

Whether to set the module to train mode or not.

device : string (default=’cpu’)

The device to store each inference result on. This defaults to CPU memory since there is generally more memory available there. For performance reasons this might be changed to a specific CUDA device, e.g. ‘cuda:0’.

Yields:
yp : torch tensor

Result from a forward call on an individual batch.
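
A hedged sketch of collecting predictions batch by batch, which avoids holding all outputs in memory at once (net is assumed to be a fitted NeuralNet whose module returns a single tensor):

>>> import numpy as np
>>> y_parts = [yp.detach().cpu().numpy() for yp in net.forward_iter(X)]
>>> y_pred = np.concatenate(y_parts)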

get_dataset(X, y=None)[source]

Get a dataset that contains the input data and is passed to the iterator.

Override this if you want to initialize your dataset differently.

Parameters:
X : input data, compatible with skorch.dataset.Dataset

By default, you should be able to pass:

  • numpy arrays
  • torch tensors
  • pandas DataFrame or Series
  • scipy sparse CSR matrices
  • a dictionary of the former three
  • a list/tuple of the former three
  • a Dataset

If this doesn’t work with your data, you have to pass a Dataset that can deal with the data.

y : target data, compatible with skorch.dataset.Dataset

The same data types as for X are supported. If your X is a Dataset that contains the target, y may be set to None.

Returns:
dataset

The initialized dataset.

get_iterator(dataset, training=False)[source]

Get an iterator that allows looping over the batches of the given data.

If self.iterator_train__batch_size and/or self.iterator_valid__batch_size are not set, use self.batch_size instead.

Parameters:
dataset : torch Dataset (default=skorch.dataset.Dataset)

Usually, self.dataset, initialized with the corresponding data, is passed to get_iterator.

training : bool (default=False)

Whether to use iterator_train or iterator_valid.

Returns:
iterator

An instantiated iterator that allows looping over the mini-batches.

get_loss(y_pred, y_true, X=None, training=False)[source]

Return the loss for this batch.

Parameters:
y_pred : torch tensor

Predicted target values.

y_true : torch tensor

True target values.

X : input data, compatible with skorch.dataset.Dataset

By default, you should be able to pass:

  • numpy arrays
  • torch tensors
  • pandas DataFrame or Series
  • scipy sparse CSR matrices
  • a dictionary of the former three
  • a list/tuple of the former three
  • a Dataset

If this doesn’t work with your data, you have to pass a Dataset that can deal with the data.

training : bool (default=False)

Whether train mode should be used or not.
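
Since get_loss can be overridden, here is a hedged sketch of a subclass that adds an L1 penalty on top of the criterion’s loss (the 1e-4 factor is a placeholder):

>>> class RegularizedNet(NeuralNet):
...     def get_loss(self, y_pred, y_true, X=None, training=False):
...         loss = super().get_loss(y_pred, y_true, X=X, training=training)
...         # placeholder L1 penalty over all module parameters
...         loss = loss + 1e-4 * sum(
...             w.abs().sum() for w in self.module_.parameters())
...         return loss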

get_split_datasets(X, y=None, **fit_params)[source]

Get internal train and validation datasets.

The validation dataset can be None if self.train_split is set to None; then internal validation will be skipped.

Override this if you want to change how the net splits incoming data into train and validation part.

Parameters:
X : input data, compatible with skorch.dataset.Dataset

By default, you should be able to pass:

  • numpy arrays
  • torch tensors
  • pandas DataFrame or Series
  • scipy sparse CSR matrices
  • a dictionary of the former three
  • a list/tuple of the former three
  • a Dataset

If this doesn’t work with your data, you have to pass a Dataset that can deal with the data.

y : target data, compatible with skorch.dataset.Dataset

The same data types as for X are supported. If your X is a Dataset that contains the target, y may be set to None.

**fit_params : dict

Additional parameters passed to the self.train_split call.

Returns:
dataset_train

The initialized training dataset.

dataset_valid

The initialized validation dataset or None.

get_train_step_accumulator()[source]

Return the train step accumulator.

By default, the accumulator stores and retrieves the first value from the optimizer call. Most optimizers make only one call, so the first value is also the only value.

For some optimizers, e.g. LBFGS, train_step_calc_gradient is called multiple times, since the loss function is evaluated multiple times per optimizer call. If you don’t want to return the first value in that case, override this method to return your custom accumulator.

infer(x, **fit_params)[source]

Perform a single inference step on a batch of data.

Parameters:
x : input data

A batch of the input data.

**fit_params : dict

Additional parameters passed to the forward method of the module and to the self.train_split call.

initialize()[source]

Initializes all components of the NeuralNet and returns self.

initialize_callbacks()[source]

Initializes all callbacks and saves the result in the callbacks_ attribute.

Both default_callbacks and callbacks are used (in that order). Callbacks may either be initialized or not, and if they don’t have a name, the name is inferred from the class name. The initialize method is called on all callbacks.

The final result will be a list of tuples, where each tuple consists of a name and an initialized callback. If names are not unique, a ValueError is raised.

initialize_criterion()[source]

Initializes the criterion.

initialize_history()[source]

Initializes the history.

initialize_module()[source]

Initializes the module.

Note that if the module has learned parameters, those will be reset.

initialize_optimizer(triggered_directly=True)[source]

Initialize the model optimizer. If self.optimizer__lr is not set, use self.lr instead.

Parameters:
triggered_directly : bool (default=True)

Only relevant when optimizer is re-initialized. Initialization of the optimizer can be triggered directly (e.g. when lr was changed) or indirectly (e.g. when the module was re-initialized). If and only if the former happens, the user should receive a message informing them about the parameters that caused the re-initialization.

load_params(f_params=None, f_optimizer=None, f_history=None, checkpoint=None)[source]

Loads the module’s parameters, history, and optimizer, not the whole object.

To save and load the whole object, use pickle.

f_params and f_optimizer are loaded using PyTorch’s load().

Parameters:
f_params : file-like object, str, None (default=None)

Path of module parameters. Pass None to not load.

f_optimizer : file-like object, str, None (default=None)

Path of optimizer. Pass None to not load.

f_history : file-like object, str, None (default=None)

Path to history. Pass None to not load.

checkpoint : Checkpoint, None (default=None)

Checkpoint to load params from. If both a checkpoint and an f_* path are passed in, the f_* path will be loaded. Pass None to not load.

Examples

>>> before = NeuralNetClassifier(mymodule)
>>> before.save_params(f_params='model.pkl',
...                    f_optimizer='optimizer.pkl',
...                    f_history='history.json')
>>> after = NeuralNetClassifier(mymodule).initialize()
>>> after.load_params(f_params='model.pkl',
...                   f_optimizer='optimizer.pkl',
...                   f_history='history.json')

notify(method_name, **cb_kwargs)[source]

Call the callback method specified in method_name with parameters specified in cb_kwargs.

Method names can be one of:

  • on_train_begin
  • on_train_end
  • on_epoch_begin
  • on_epoch_end
  • on_batch_begin
  • on_batch_end

partial_fit(X, y=None, classes=None, **fit_params)[source]

Fit the module.

If the module is initialized, it is not re-initialized, which means that this method should be used if you want to continue training a model (warm start).

Parameters:
X : input data, compatible with skorch.dataset.Dataset

By default, you should be able to pass:

  • numpy arrays
  • torch tensors
  • pandas DataFrame or Series
  • scipy sparse CSR matrices
  • a dictionary of the former three
  • a list/tuple of the former three
  • a Dataset

If this doesn’t work with your data, you have to pass a Dataset that can deal with the data.

y : target data, compatible with skorch.dataset.Dataset

The same data types as for X are supported. If your X is a Dataset that contains the target, y may be set to None.

classes : array, shape (n_classes,)

Solely for sklearn compatibility, currently unused.

**fit_params : dict

Additional parameters passed to the forward method of the module and to the self.train_split call.
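
A hedged sketch of incremental training on data that arrives in chunks (MyModule and chunk_generator are placeholders):

>>> import torch.nn as nn
>>> from skorch import NeuralNet
>>> net = NeuralNet(MyModule, criterion=nn.MSELoss, max_epochs=1)
>>> for X_chunk, y_chunk in chunk_generator():
...     net.partial_fit(X_chunk, y_chunk)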

predict(X)[source]

Where applicable, return class labels for samples in X.

If the module’s forward method returns multiple outputs as a tuple, it is assumed that the first output contains the relevant information and the other values are ignored. If all values are relevant, consider using forward() instead.

Parameters:
X : input data, compatible with skorch.dataset.Dataset

By default, you should be able to pass:

  • numpy arrays
  • torch tensors
  • pandas DataFrame or Series
  • scipy sparse CSR matrices
  • a dictionary of the former three
  • a list/tuple of the former three
  • a Dataset

If this doesn’t work with your data, you have to pass a Dataset that can deal with the data.

Returns:
y_pred : numpy ndarray

predict_proba(X)[source]

Return the output of the module’s forward method as a numpy array.

If the module’s forward method returns multiple outputs as a tuple, it is assumed that the first output contains the relevant information and the other values are ignored. If all values are relevant, consider using forward() instead.

Parameters:
X : input data, compatible with skorch.dataset.Dataset

By default, you should be able to pass:

  • numpy arrays
  • torch tensors
  • pandas DataFrame or Series
  • scipy sparse CSR matrices
  • a dictionary of the former three
  • a list/tuple of the former three
  • a Dataset

If this doesn’t work with your data, you have to pass a Dataset that can deal with the data.

Returns:
y_proba : numpy ndarray

save_params(f_params=None, f_optimizer=None, f_history=None)[source]

Saves the module’s parameters, history, and optimizer, not the whole object.

To save the whole object, use pickle.

f_params and f_optimizer use PyTorch’s save().

Parameters:
f_params : file-like object, str, None (default=None)

Path of module parameters. Pass None to not save.

f_optimizer : file-like object, str, None (default=None)

Path of optimizer. Pass None to not save.

f_history : file-like object, str, None (default=None)

Path to history. Pass None to not save.

Examples

>>> before = NeuralNetClassifier(mymodule)
>>> before.save_params(f_params='model.pkl',
...                    f_optimizer='optimizer.pkl',
...                    f_history='history.json')
>>> after = NeuralNetClassifier(mymodule).initialize()
>>> after.load_params(f_params='model.pkl',
...                   f_optimizer='optimizer.pkl',
...                   f_history='history.json')

set_params(**kwargs)[source]

Set the parameters of this class.

Valid parameter keys can be listed with get_params().

Returns:
self

train_step(Xi, yi, **fit_params)[source]

Prepares a loss function callable and passes it to the optimizer, hence performing one optimization step.

A loss function callable is required by some optimizers (and accepted by all of them): https://pytorch.org/docs/master/optim.html#optimizer-step-closure

The module is set to be in train mode (e.g. dropout is applied).

Parameters:
Xi : input data

A batch of the input data.

yi : target data

A batch of the target data.

**fit_params : dict

Additional parameters passed to the forward method of the module and to the train_split call.

train_step_single(Xi, yi, **fit_params)[source]

Compute y_pred, loss value, and update net’s gradients.

The module is set to be in train mode (e.g. dropout is applied).

Parameters:
Xi : input data

A batch of the input data.

yi : target data

A batch of the target data.

**fit_params : dict

Additional parameters passed to the forward method of the module and to the self.train_split call.

validation_step(Xi, yi, **fit_params)[source]

Perform a forward step using batched data and return the resulting loss.

The module is set to be in evaluation mode (e.g. dropout is not applied).

Parameters:
Xi : input data

A batch of the input data.

yi : target data

A batch of the target data.

**fit_params : dict

Additional parameters passed to the forward method of the module and to the self.train_split call.