skorch.net

Neural net base class

This is the most flexible class, not making assumptions on the kind of task being performed. Subclass this to create more specialized and sklearn-conforming classes like NeuralNetClassifier.

class skorch.net.NeuralNet(module, criterion, optimizer=<class 'torch.optim.sgd.SGD'>, lr=0.01, max_epochs=10, batch_size=128, iterator_train=<class 'torch.utils.data.dataloader.DataLoader'>, iterator_valid=<class 'torch.utils.data.dataloader.DataLoader'>, dataset=<class 'skorch.dataset.Dataset'>, train_split=<skorch.dataset.ValidSplit object>, callbacks=None, predict_nonlinearity='auto', warm_start=False, verbose=1, device='cpu', compile=False, use_caching='auto', **kwargs)[source]

NeuralNet base class.

The base class covers more generic cases. Depending on your use case, you might want to use NeuralNetClassifier or NeuralNetRegressor.

In addition to the parameters listed below, there are parameters with specific prefixes that are handled separately. To illustrate this, here is an example:

>>> net = NeuralNet(
...    ...,
...    optimizer=torch.optim.SGD,
...    optimizer__momentum=0.95,
...)

This way, when optimizer is initialized, NeuralNet will take care of setting the momentum parameter to 0.95.

(Note that the double underscore notation in optimizer__momentum means that the parameter momentum should be set on the object optimizer. These are the same semantics as used by sklearn.)

Furthermore, this makes it possible to change those parameters later:

net.set_params(optimizer__momentum=0.99)

This can be useful when you want to change certain parameters using a callback, when using the net in an sklearn grid search, etc.
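
For instance, the dunder parameters plug directly into sklearn hyperparameter searches. The following is only a sketch: MyModule, X, and y are placeholders, and module__num_units assumes MyModule accepts a num_units argument.

from sklearn.model_selection import GridSearchCV

net = NeuralNet(MyModule, criterion=torch.nn.MSELoss)

params = {
    'lr': [0.01, 0.05],
    'optimizer__momentum': [0.9, 0.95],  # routed to the optimizer via the dunder notation
    'module__num_units': [10, 20],       # assumes MyModule accepts a num_units argument
}
# the plain NeuralNet defines no score method, so an explicit scoring is passed
gs = GridSearchCV(net, params, cv=3, scoring='neg_mean_squared_error')
gs.fit(X, y)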

By default an EpochTimer, BatchScoring (for both training and validation datasets), and PrintLog callbacks are added for convenience.

Parameters
module : torch module (class or instance)

A PyTorch Module. In general, the uninstantiated class should be passed, although instantiated modules will also work.

criterion : torch criterion (class)

The uninitialized criterion (loss) used to optimize the module.

optimizer : torch optim (class, default=torch.optim.SGD)

The uninitialized optimizer (update rule) used to optimize the module.

lr : float (default=0.01)

Learning rate passed to the optimizer. You may use lr instead of using optimizer__lr, which would result in the same outcome.

max_epochs : int (default=10)

The number of epochs to train for each fit call. Note that you may keyboard-interrupt training at any time.

batch_size : int (default=128)

Mini-batch size. Use this instead of setting iterator_train__batch_size and iterator_valid__batch_size, which would result in the same outcome. If batch_size is -1, a single batch with all the data will be used during training and validation.

iterator_train : torch DataLoader

The default PyTorch DataLoader used for training data.

iterator_valid : torch DataLoader

The default PyTorch DataLoader used for validation and test data, i.e. during inference.

dataset : torch Dataset (default=skorch.dataset.Dataset)

The dataset is necessary for the incoming data to work with pytorch’s DataLoader. It has to implement the __len__ and __getitem__ methods. The provided dataset should be capable of dealing with a lot of data types out of the box, so only change this if your data is not supported. You should generally pass the uninitialized Dataset class and define additional arguments to X and y by prefixing them with dataset__. It is also possible to pass an initialized Dataset, in which case no additional arguments may be passed.

train_split : None or callable (default=skorch.dataset.ValidSplit(5))

If None, there is no train/validation split. Else, train_split should be a function or callable that is called with X and y data and should return the tuple dataset_train, dataset_valid. The validation data may be None.
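
For illustration, here is a sketch of two common choices: a larger random validation split via ValidSplit, and a fixed, user-provided validation set via skorch.helper.predefined_split. MyModule, X_valid, and y_valid are placeholders.

from skorch.dataset import Dataset, ValidSplit
from skorch.helper import predefined_split

# use a 20% random validation split instead of the default
net = NeuralNet(MyModule, criterion=torch.nn.CrossEntropyLoss,
                train_split=ValidSplit(0.2))

# or always validate on a fixed dataset
valid_ds = Dataset(X_valid, y_valid)
net = NeuralNet(MyModule, criterion=torch.nn.CrossEntropyLoss,
                train_split=predefined_split(valid_ds))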

callbacks : None, “disable”, or list of Callback instances (default=None)

Which callbacks to enable. There are three possible values:

If callbacks=None, only use default callbacks, those returned by get_default_callbacks.

If callbacks="disable", disable all callbacks, i.e. do not run any of the callbacks, not even the default callbacks.

If callbacks is a list of callbacks, use those callbacks in addition to the default callbacks. Each callback should be an instance of Callback.

Callback names are inferred from the class name. Name conflicts are resolved by appending a count suffix starting with 1, e.g. EpochScoring_1. Alternatively, a tuple (name, callback) can be passed, where name should be unique. Callbacks may or may not be instantiated. The callback name can be used to set parameters on specific callbacks (e.g., for the callback with name 'print_log', use net.set_params(callbacks__print_log__keys_ignored=['epoch', 'train_loss'])).
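
As a sketch of both styles (MyModule is a placeholder), a callback can be passed bare, with its name inferred, or as a (name, callback) tuple, and the name can later be used with set_params:

from skorch.callbacks import EpochScoring

net = NeuralNet(
    MyModule,
    criterion=torch.nn.CrossEntropyLoss,
    callbacks=[
        EpochScoring('f1', lower_is_better=False),  # name inferred from the class: 'EpochScoring'
        ('train_acc', EpochScoring('accuracy', on_train=True, lower_is_better=False)),
    ],
)
# later, address the named callback through set_params
net.set_params(callbacks__train_acc__scoring='balanced_accuracy')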

predict_nonlinearity : callable, None, or ‘auto’ (default=’auto’)

The nonlinearity to be applied to the prediction. When set to ‘auto’, infers the correct nonlinearity based on the criterion (softmax for CrossEntropyLoss and sigmoid for BCEWithLogitsLoss). If it cannot be inferred or if the parameter is None, just use the identity function. Don’t pass a lambda function if you want the net to be pickleable.

In case a callable is passed, it should accept the output of the module (the first output if there is more than one), which is a PyTorch tensor, and return the transformed PyTorch tensor.

This can be useful, e.g., when predict_proba() should return probabilities but a criterion is used that does not expect probabilities. In that case, the module can return whatever is required by the criterion and the predict_nonlinearity transforms this output into probabilities.

The nonlinearity is applied only when calling predict() or predict_proba() but not anywhere else – notably, the loss is unaffected by this nonlinearity.
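
As an illustrative sketch of a custom callable (MyModule is a placeholder module that returns log-probabilities, as expected by NLLLoss, for which no nonlinearity can be inferred automatically):

import torch

def log_proba_to_proba(y_log_proba):
    # applied only in predict()/predict_proba(), never when computing the loss
    return torch.exp(y_log_proba)

net = NeuralNet(MyModule, criterion=torch.nn.NLLLoss,
                predict_nonlinearity=log_proba_to_proba)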

warm_start : bool (default=False)

Whether each fit call should lead to a re-initialization of the module (cold start) or whether the module should be trained further (warm start).

verbose : int (default=1)

This parameter controls how much print output is generated by the net and its callbacks. By setting this value to 0, e.g. the summary scores at the end of each epoch are no longer printed. This can be useful when running a hyperparameter search. The summary scores are always logged in the history attribute, regardless of the verbose setting.
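
For example, even with verbose=0, per-epoch results can still be read from the history afterwards (a sketch; the available keys depend on the active callbacks, and MyModule, X, and y are placeholders):

net = NeuralNet(MyModule, criterion=torch.nn.CrossEntropyLoss, verbose=0)
net.fit(X, y)

print(net.history[-1, 'train_loss'])  # train loss of the last epoch
print(net.history[:, 'valid_loss'])   # validation losses of all epochs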

device : str, torch.device, or None (default=’cpu’)

The compute device to be used. If set to ‘cuda’ in order to use GPU acceleration, data in torch tensors will be pushed to cuda tensors before being sent to the module. If set to None, then all compute devices will be left unmodified.

compile : bool (default=False)

If set to True, compile all modules using torch.compile. For this to work, the installed torch version has to support torch.compile. Compiled modules should work identically to non-compiled modules but should run faster on new GPU architectures (Volta and Ampere for instance). Additional arguments for torch.compile can be passed using the dunder notation, e.g. when initializing the net with compile__dynamic=True, torch.compile will be called with dynamic=True.

use_caching : bool or ‘auto’ (default=’auto’)

Optionally override the caching behavior of scoring callbacks. Callbacks such as EpochScoring and BatchScoring allow caching of the inference call to save time when calculating scores during training, at the expense of memory. In certain situations, e.g. when memory is tight, you may want to disable caching. As it is cumbersome to change the setting on each callback individually, this parameter allows you to override their behavior globally. By default ('auto'), the callbacks will determine if caching is used or not. If this argument is set to False, caching will be disabled on all callbacks. If set to True, caching will be enabled on all callbacks. Implementation note: It is the job of the callbacks to honor this setting.

Attributes
prefixes_ : list of str

Contains the prefixes to special parameters. E.g., since there is the 'optimizer' prefix, it is possible to set parameters like so: NeuralNet(..., optimizer__momentum=0.95). Some prefixes are populated dynamically, based on what modules and criteria are defined.

cuda_dependent_attributes_ : list of str

Contains a list of all attribute prefixes whose values depend on a CUDA device. If a NeuralNet trained with a CUDA-enabled device is unpickled on a machine without CUDA or with CUDA disabled, the listed attributes are mapped to CPU. Expand this list if you want to add other cuda-dependent attributes.

initialized_ : bool

Whether the NeuralNet was initialized.

module_ : torch module (instance)

The instantiated module.

criterion_ : torch criterion (instance)

The instantiated criterion.

callbacks_ : list of tuples

The complete list of initialized callbacks (i.e. default and user-provided), each stored as a tuple of a unique name and the callback instance.

_modules : list of str

List of names of all modules that are torch modules. This list is collected dynamically when the net is initialized. Typically, there is no reason for a user to modify this list.

_criteria : list of str

List of names of all criteria that are torch modules. This list is collected dynamically when the net is initialized. Typically, there is no reason for a user to modify this list.

_optimizers : list of str

List of names of all optimizers. This list is collected dynamically when the net is initialized. Typically, there is no reason for a user to modify this list.

Methods

check_is_fitted([attributes])

Checks whether the net is initialized

check_training_readiness()

Check that the net is ready to train

evaluation_step(batch[, training])

Perform a forward step to produce the output used for prediction and scoring.

fit(X[, y])

Initialize and fit the module.

fit_loop(X[, y, epochs])

The proper fit loop.

forward(X[, training, device])

Gather and concatenate the output from forward call with input data.

forward_iter(X[, training, device])

Yield outputs of module forward calls on each batch of data.

get_all_learnable_params()

Yield the learnable parameters of all modules

get_dataset(X[, y])

Get a dataset that contains the input data and is passed to the iterator.

get_iterator(dataset[, training])

Get an iterator that allows looping over the batches of the given data.

get_loss(y_pred, y_true[, X, training])

Return the loss for this batch.

get_params_for(prefix)

Collect and return init parameters for an attribute.

get_params_for_optimizer(prefix, ...)

Collect and return init parameters for an optimizer.

get_split_datasets(X[, y])

Get internal train and validation datasets.

get_train_step_accumulator()

Return the train step accumulator.

infer(x, **fit_params)

Perform a single inference step on a batch of data.

initialize()

Initializes all of its components and returns self.

initialize_callbacks()

Initializes all callbacks and saves the result in the callbacks_ attribute.

initialize_criterion()

Initializes the criterion.

initialize_history()

Initializes the history.

initialize_module()

Initializes the module.

initialize_optimizer([triggered_directly])

Initialize the model optimizer.

initialized_instance(instance_or_cls, kwargs)

Return an instance initialized with the given parameters

load_params([f_params, f_optimizer, ...])

Loads the module's parameters, history, and optimizer, not the whole object.

notify(method_name, **cb_kwargs)

Call the callback method specified in method_name with parameters specified in cb_kwargs.

on_batch_begin(net[, batch, training])

on_epoch_begin(net[, dataset_train, ...])

on_epoch_end(net[, dataset_train, dataset_valid])

on_train_begin(net[, X, y])

on_train_end(net[, X, y])

partial_fit(X[, y, classes])

Fit the module.

predict(X)

Where applicable, return class labels for samples in X.

predict_proba(X)

Return the output of the module's forward method as a numpy array.

run_single_epoch(iterator, training, prefix, ...)

Compute a single epoch of train or validation.

save_params([f_params, f_optimizer, ...])

Saves the module's parameters, history, and optimizer, not the whole object.

set_params(**kwargs)

Set the parameters of this class.

torch_compile(module, name)

Compile torch modules

train_step(batch, **fit_params)

Prepares a loss function callable and passes it to the optimizer, hence performing one optimization step.

train_step_single(batch, **fit_params)

Compute y_pred, loss value, and update net's gradients.

trim_for_prediction()

Remove all attributes not required for prediction.

validation_step(batch, **fit_params)

Perform a forward step using batched data and return the resulting loss.

check_data

get_default_callbacks

get_params

initialize_virtual_params

on_batch_end

on_grad_computed

check_is_fitted(attributes=None, *args, **kwargs)[source]

Checks whether the net is initialized

Parameters
attributes : iterable of str or None (default=None)

All the attributes that are strictly required of a fitted net. By default, this is the module_ attribute.

Other arguments as in sklearn.utils.validation.check_is_fitted.

Raises
skorch.exceptions.NotInitializedError

When the given attributes are not present.

check_training_readiness()[source]

Check that the net is ready to train

evaluation_step(batch, training=False)[source]

Perform a forward step to produce the output used for prediction and scoring.

The module is therefore set to evaluation mode by default beforehand; this can be overridden by setting training=True, e.g. to re-enable features like dropout.

Parameters
batch

A single batch returned by the data loader.

training : bool (default=False)

Whether to set the module to train mode or not.

Returns
y_infer

The prediction generated by the module.

fit(X, y=None, **fit_params)[source]

Initialize and fit the module.

If the module was already initialized, calling fit will re-initialize it (unless warm_start is True).

Parameters
X : input data, compatible with skorch.dataset.Dataset

By default, you should be able to pass:

  • numpy arrays

  • torch tensors

  • pandas DataFrame or Series

  • scipy sparse CSR matrices

  • a dictionary of the former three

  • a list/tuple of the former three

  • a Dataset

If this doesn’t work with your data, you have to pass a Dataset that can deal with the data.

y : target data, compatible with skorch.dataset.Dataset

The same data types as for X are supported. If your X is a Dataset that contains the target, y may be set to None.

**fit_params : dict

Additional parameters passed to the forward method of the module and to the self.train_split call.
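
A minimal end-to-end sketch (MyModule is a placeholder module that returns logits of shape (n, 2); the random data only serves to show the expected dtypes):

import numpy as np
import torch
from skorch import NeuralNet

X = np.random.randn(100, 20).astype('float32')
y = np.random.randint(0, 2, size=100).astype('int64')

net = NeuralNet(
    MyModule,
    criterion=torch.nn.CrossEntropyLoss,
    max_epochs=10,
    lr=0.05,
)
net.fit(X, y)
# with predict_nonlinearity='auto' and CrossEntropyLoss, softmax is applied here
proba = net.predict_proba(X)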

fit_loop(X, y=None, epochs=None, **fit_params)[source]

The proper fit loop.

Contains the logic of what actually happens during the fit loop.

Parameters
X : input data, compatible with skorch.dataset.Dataset

By default, you should be able to pass:

  • numpy arrays

  • torch tensors

  • pandas DataFrame or Series

  • scipy sparse CSR matrices

  • a dictionary of the former three

  • a list/tuple of the former three

  • a Dataset

If this doesn’t work with your data, you have to pass a Dataset that can deal with the data.

y : target data, compatible with skorch.dataset.Dataset

The same data types as for X are supported. If your X is a Dataset that contains the target, y may be set to None.

epochs : int or None (default=None)

If int, train for this number of epochs; if None, use self.max_epochs.

**fit_params : dict

Additional parameters passed to the forward method of the module and to the self.train_split call.

forward(X, training=False, device='cpu')[source]

Gather and concatenate the output from forward call with input data.

The outputs from self.module_.forward are gathered on the compute device specified by device and then concatenated using PyTorch cat(). If multiple outputs are returned by self.module_.forward, each one of them must be able to be concatenated this way.

Parameters
X : input data, compatible with skorch.dataset.Dataset

By default, you should be able to pass:

  • numpy arrays

  • torch tensors

  • pandas DataFrame or Series

  • scipy sparse CSR matrices

  • a dictionary of the former three

  • a list/tuple of the former three

  • a Dataset

If this doesn’t work with your data, you have to pass a Dataset that can deal with the data.

training : bool (default=False)

Whether to set the module to train mode or not.

device : string (default=’cpu’)

The device to store each inference result on. This defaults to CPU memory since there is generally more memory available there. For performance reasons this might be changed to a specific CUDA device, e.g. ‘cuda:0’.

Returns
y_infer : torch tensor

The result from the forward step.

forward_iter(X, training=False, device='cpu')[source]

Yield outputs of module forward calls on each batch of data. The storage device of the yielded tensors is determined by the device parameter.

Parameters
X : input data, compatible with skorch.dataset.Dataset

By default, you should be able to pass:

  • numpy arrays

  • torch tensors

  • pandas DataFrame or Series

  • scipy sparse CSR matrices

  • a dictionary of the former three

  • a list/tuple of the former three

  • a Dataset

If this doesn’t work with your data, you have to pass a Dataset that can deal with the data.

training : bool (default=False)

Whether to set the module to train mode or not.

device : string (default=’cpu’)

The device to store each inference result on. This defaults to CPU memory since there is generally more memory available there. For performance reasons this might be changed to a specific CUDA device, e.g. ‘cuda:0’.

Yields
yp : torch tensor

Result from a forward call on an individual batch.
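
A sketch of consuming forward_iter() batch by batch, e.g. to post-process predictions without first concatenating all outputs (assumes a fitted net whose module returns class logits):

import torch

preds = []
for yp in net.forward_iter(X):
    # if the module returns a tuple, the relevant output is assumed to be the first element
    yp = yp[0] if isinstance(yp, tuple) else yp
    preds.append(yp.argmax(dim=-1).cpu())
y_pred = torch.cat(preds).numpy()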

get_all_learnable_params()[source]

Yield the learnable parameters of all modules

Typically, this will yield the named_parameters of the standard module of the net. However, if you add custom modules or if your criterion has learnable parameters, these are returned as well.

If you want your optimizer to only update the parameters of some but not all modules, you should override initialize_optimizer() and match the corresponding modules and optimizers there:

class MyNet(NeuralNet):

    def initialize_optimizer(self, *args, **kwargs):
        # first initialize the normal optimizer
        named_params = self.module_.named_parameters()
        args, kwargs = self.get_params_for_optimizer('optimizer', named_params)
        self.optimizer_ = self.optimizer(*args, **kwargs)

        # next add another optimizer called 'optimizer2_' that is
        # only responsible for training 'module2_'
        named_params = self.module2_.named_parameters()
        args, kwargs = self.get_params_for_optimizer('optimizer2', named_params)
        self.optimizer2_ = torch.optim.SGD(*args, **kwargs)
        return self
Yields
named_parameters : generator of parameter name and parameter

A generator over all module parameters, yielding both the name of the parameter as well as the parameter itself. Use this, for instance, to pass the named parameters to get_params_for_optimizer().

get_dataset(X, y=None)[source]

Get a dataset that contains the input data and is passed to the iterator.

Override this if you want to initialize your dataset differently.

Parameters
X : input data, compatible with skorch.dataset.Dataset

By default, you should be able to pass:

  • numpy arrays

  • torch tensors

  • pandas DataFrame or Series

  • scipy sparse CSR matrices

  • a dictionary of the former three

  • a list/tuple of the former three

  • a Dataset

If this doesn’t work with your data, you have to pass a Dataset that can deal with the data.

y : target data, compatible with skorch.dataset.Dataset

The same data types as for X are supported. If your X is a Dataset that contains the target, y may be set to None.

Returns
dataset

The initialized dataset.

get_iterator(dataset, training=False)[source]

Get an iterator that allows looping over the batches of the given data.

If self.iterator_train__batch_size and/or self.iterator_valid__batch_size are not set, use self.batch_size instead.

Parameters
dataset : torch Dataset (default=skorch.dataset.Dataset)

Usually, self.dataset, initialized with the corresponding data, is passed to get_iterator.

training : bool (default=False)

Whether to use iterator_train or iterator_valid.

Returns
iterator

An instantiated iterator that allows looping over the mini-batches.

get_loss(y_pred, y_true, X=None, training=False)[source]

Return the loss for this batch.

Parameters
y_pred : torch tensor

Predicted target values.

y_true : torch tensor

True target values.

X : input data, compatible with skorch.dataset.Dataset

By default, you should be able to pass:

  • numpy arrays

  • torch tensors

  • pandas DataFrame or Series

  • scipy sparse CSR matrices

  • a dictionary of the former three

  • a list/tuple of the former three

  • a Dataset

If this doesn’t work with your data, you have to pass a Dataset that can deal with the data.

training : bool (default=False)

Whether train mode should be used or not.
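
A common reason to customize the loss is to add a penalty on top of the criterion's value. The following is a sketch, not part of skorch itself; the weight 1e-4 is arbitrary:

class L1RegularizedNet(NeuralNet):
    def get_loss(self, y_pred, y_true, X=None, training=False):
        loss = super().get_loss(y_pred, y_true, X=X, training=training)
        # add an L1 penalty over all module parameters
        l1 = sum(p.abs().sum() for p in self.module_.parameters())
        return loss + 1e-4 * l1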

get_params_for(prefix)[source]

Collect and return init parameters for an attribute.

Attributes could be, for instance, pytorch modules, criteria, or data loaders (for optimizers, use get_params_for_optimizer() instead). Use the returned arguments to initialize the given attribute like this:

# inside initialize_module method
kwargs = self.get_params_for('module')
self.module_ = self.module(**kwargs)

Proceed analogously for the criterion etc.

The reason to use this method is so that it’s possible to change the init parameters with set_params(), which in turn makes grid search and other similar things work.

Note that in general, as a user, you never have to deal with this method because initialize_module() etc. are already taking care of this. You only need to deal with this if you override initialize_module() (or similar methods) because you have some custom code that requires it.

Parameters
prefix : str

The name of the attribute whose arguments should be returned. E.g. for the module, it should be 'module'.

Returns
kwargs : dict

Keyword arguments to be used as init parameters.

get_params_for_optimizer(prefix, named_parameters)[source]

Collect and return init parameters for an optimizer.

Parse kwargs configuration for the optimizer identified by the given prefix. Supports param group assignment using wildcards:

optimizer__lr=0.05,
optimizer__param_groups=[
    ('rnn*.period', {'lr': 0.3, 'momentum': 0}),
    ('rnn0', {'lr': 0.1}),
]

Generally, use this method like this:

# inside initialize_optimizer method
named_params = self.module_.named_parameters()
pgroups, kwargs = self.get_params_for_optimizer('optimizer', named_params)
if 'lr' not in kwargs:
    kwargs['lr'] = self.lr
self.optimizer_ = self.optimizer(*pgroups, **kwargs)

The reason to use this method is so that it’s possible to change the init parameters with set_params(), which in turn makes grid search and other similar things work.

Note that in general, as a user, you never have to deal with this method because initialize_optimizer() is already taking care of this. You only need to deal with this if you override initialize_optimizer() because you have some custom code that requires it.

Parameters
prefix : str

The name of the optimizer whose arguments should be returned. Typically, this should just be 'optimizer'. There can be exceptions, however, e.g. if you want to use more than one optimizer.

named_parameters : iterator

Iterator over the parameters of the module that is intended to be optimized. It’s the return value of my_module.named_parameters().

Returns
args : tuple

All positional arguments for this optimizer (right now only one, the parameter groups).

kwargs : dict

All other parameters for this optimizer, e.g. the learning rate.
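
At the net level, the wildcard syntax shown above can simply be passed as init arguments; a sketch where 'embedding*' and 'head*' are hypothetical parameter-name patterns of the placeholder MyModule:

net = NeuralNet(
    MyModule,
    criterion=torch.nn.CrossEntropyLoss,
    optimizer__lr=0.05,
    optimizer__param_groups=[
        ('embedding*', {'lr': 0.005}),  # smaller learning rate for matching parameters
        ('head*', {'lr': 0.1}),
    ],
)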

get_split_datasets(X, y=None, **fit_params)[source]

Get internal train and validation datasets.

The validation dataset can be None if self.train_split is set to None; then internal validation will be skipped.

Override this if you want to change how the net splits incoming data into train and validation part.

Parameters
X : input data, compatible with skorch.dataset.Dataset

By default, you should be able to pass:

  • numpy arrays

  • torch tensors

  • pandas DataFrame or Series

  • scipy sparse CSR matrices

  • a dictionary of the former three

  • a list/tuple of the former three

  • a Dataset

If this doesn’t work with your data, you have to pass a Dataset that can deal with the data.

y : target data, compatible with skorch.dataset.Dataset

The same data types as for X are supported. If your X is a Dataset that contains the target, y may be set to None.

**fit_params : dict

Additional parameters passed to the self.train_split call.

Returns
dataset_train

The initialized training dataset.

dataset_valid

The initialized validation dataset, or None.

get_train_step_accumulator()[source]

Return the train step accumulator.

By default, the accumulator stores and retrieves the first value from the optimizer call. Most optimizers make only one call, so the first value is at the same time the only value.

In case of some optimizers, e.g. LBFGS, train_step_calc_gradient is called multiple times, as the loss function is evaluated multiple times per optimizer call. If you don’t want to return the first value in that case, override this method to return your custom accumulator.

infer(x, **fit_params)[source]

Perform a single inference step on a batch of data.

Parameters
x : input data

A batch of the input data.

**fit_params : dict

Additional parameters passed to the forward method of the module and to the self.train_split call.

initialize()[source]

Initializes all of its components and returns self.

initialize_callbacks()[source]

Initializes all callbacks and saves the result in the callbacks_ attribute.

Both default_callbacks and callbacks are used (in that order). Callbacks may either be initialized or not, and if they don’t have a name, the name is inferred from the class name. The initialize method is called on all callbacks.

The final result will be a list of tuples, where each tuple consists of a name and an initialized callback. If names are not unique, a ValueError is raised.

initialize_criterion()[source]

Initializes the criterion.

If the criterion is already initialized and no parameter was changed, it will be left as is.

initialize_history()[source]

Initializes the history.

initialize_module()[source]

Initializes the module.

If the module is already initialized and no parameter was changed, it will be left as is.

initialize_optimizer(triggered_directly=None)[source]

Initialize the model optimizer. If self.optimizer__lr is not set, use self.lr instead.

Parameters
triggered_directly

Deprecated, don’t use it anymore.

initialized_instance(instance_or_cls, kwargs)[source]

Return an instance initialized with the given parameters

This is a helper method that deals with several possibilities for a component that might need to be initialized:

  • It is already an instance that’s good to go

  • It is an instance but it needs to be re-initialized

  • It’s not an instance and needs to be initialized

For the majority of use cases, this comes down to just initializing the class with its arguments.

Parameters
instance_or_cls

The instance or class or callable to be initialized, e.g. self.module.

kwargs : dict

The keyword arguments to initialize the instance or class. Can be an empty dict.

Returns
instance

The initialized component.

load_params(f_params=None, f_optimizer=None, f_criterion=None, f_history=None, checkpoint=None, use_safetensors=False, **kwargs)[source]

Loads the module’s parameters, history, and optimizer, not the whole object.

To save and load the whole object, use pickle.

f_params, f_optimizer, etc. use PyTorch’s load().

If you’ve created a custom module, e.g. net.mymodule_, you can save that as well by passing f_mymodule.

Parameters
f_params : file-like object, str, None (default=None)

Path of module parameters. Pass None to not load.

f_optimizer : file-like object, str, None (default=None)

Path of optimizer. Pass None to not load.

f_criterion : file-like object, str, None (default=None)

Path of criterion. Pass None to not load.

f_history : file-like object, str, None (default=None)

Path to history. Pass None to not load.

checkpoint : Checkpoint, None (default=None)

Checkpoint to load params from. If both a checkpoint and an f_* path are passed in, the f_* will be loaded. Pass None to not load.

use_safetensors : bool (default=False)

Whether to use the safetensors library to load the state. By default, PyTorch is used, which in turn uses pickle under the hood. If the state was saved with use_safetensors=True when skorch.net.NeuralNet.save_params() was called, it should be set to True here as well.

Examples

>>> before = NeuralNetClassifier(mymodule)
>>> before.save_params(f_params='model.pkl',
...                    f_optimizer='optimizer.pkl',
...                    f_history='history.json')
>>> after = NeuralNetClassifier(mymodule).initialize()
>>> after.load_params(f_params='model.pkl',
...                   f_optimizer='optimizer.pkl',
...                   f_history='history.json')

notify(method_name, **cb_kwargs)[source]

Call the callback method specified in method_name with parameters specified in cb_kwargs.

Method names can be one of:

  • on_train_begin

  • on_train_end

  • on_epoch_begin

  • on_epoch_end

  • on_batch_begin

  • on_batch_end

partial_fit(X, y=None, classes=None, **fit_params)[source]

Fit the module.

If the module is initialized, it is not re-initialized, which means that this method should be used if you want to continue training a model (warm start).

Parameters
X : input data, compatible with skorch.dataset.Dataset

By default, you should be able to pass:

  • numpy arrays

  • torch tensors

  • pandas DataFrame or Series

  • scipy sparse CSR matrices

  • a dictionary of the former three

  • a list/tuple of the former three

  • a Dataset

If this doesn’t work with your data, you have to pass a Dataset that can deal with the data.

y : target data, compatible with skorch.dataset.Dataset

The same data types as for X are supported. If your X is a Dataset that contains the target, y may be set to None.

classes : array, shape (n_classes,)

Solely for sklearn compatibility, currently unused.

**fit_params : dict

Additional parameters passed to the forward method of the module and to the self.train_split call.
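
A sketch of incremental training with partial_fit(), which keeps the already-initialized module instead of re-initializing it (MyModule, X, and y are placeholders):

net = NeuralNet(MyModule, criterion=torch.nn.CrossEntropyLoss, max_epochs=5)

net.partial_fit(X, y)  # initializes (if necessary) and trains for 5 epochs
net.partial_fit(X, y)  # trains for 5 more epochs, continuing from the current weights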

predict(X)[source]

Where applicable, return class labels for samples in X.

If the module’s forward method returns multiple outputs as a tuple, it is assumed that the first output contains the relevant information and the other values are ignored. If all values are relevant, consider using forward() instead.

Parameters
X : input data, compatible with skorch.dataset.Dataset

By default, you should be able to pass:

  • numpy arrays

  • torch tensors

  • pandas DataFrame or Series

  • scipy sparse CSR matrices

  • a dictionary of the former three

  • a list/tuple of the former three

  • a Dataset

If this doesn’t work with your data, you have to pass a Dataset that can deal with the data.

Returns
y_pred : numpy ndarray

predict_proba(X)[source]

Return the output of the module’s forward method as a numpy array.

If the module’s forward method returns multiple outputs as a tuple, it is assumed that the first output contains the relevant information and the other values are ignored. If all values are relevant, consider using forward() instead.

Parameters
X : input data, compatible with skorch.dataset.Dataset

By default, you should be able to pass:

  • numpy arrays

  • torch tensors

  • pandas DataFrame or Series

  • scipy sparse CSR matrices

  • a dictionary of the former three

  • a list/tuple of the former three

  • a Dataset

If this doesn’t work with your data, you have to pass a Dataset that can deal with the data.

Returns
y_proba : numpy ndarray

run_single_epoch(iterator, training, prefix, step_fn, **fit_params)[source]

Compute a single epoch of train or validation.

Parameters
iterator : torch DataLoader or None

The initialized DataLoader to loop over. If None, skip this step.

training : bool

Whether to set the module to train mode or not.

prefix : str

Prefix to use when saving to the history.

step_fn : callable

Function to call for each batch.

**fit_params : dict

Additional parameters passed to the step_fn.

save_params(f_params=None, f_optimizer=None, f_criterion=None, f_history=None, use_safetensors=False, **kwargs)[source]

Saves the module’s parameters, history, and optimizer, not the whole object.

To save the whole object, use pickle. This is necessary when you need additional learned attributes on the net, e.g. the classes_ attribute on skorch.classifier.NeuralNetClassifier.

f_params, f_optimizer, etc. use PyTorch’s save().

If you’ve created a custom module, e.g. net.mymodule_, you can save that as well by passing f_mymodule.

Parameters
f_params : file-like object, str, None (default=None)

Path of module parameters. Pass None to not save.

f_optimizer : file-like object, str, None (default=None)

Path of optimizer. Pass None to not save.

f_criterion : file-like object, str, None (default=None)

Path of criterion. Pass None to not save.

f_history : file-like object, str, None (default=None)

Path to history. Pass None to not save.

use_safetensors : bool (default=False)

Whether to use the safetensors library to persist the state. By default, PyTorch is used, which in turn uses pickle under the hood. When enabling safetensors, be aware that only PyTorch tensors can be stored. Therefore, certain attributes like the optimizer cannot be saved.

Examples

>>> before = NeuralNetClassifier(mymodule)
>>> before.save_params(f_params='model.pkl',
...                    f_optimizer='optimizer.pkl',
...                    f_history='history.json')
>>> after = NeuralNetClassifier(mymodule).initialize()
>>> after.load_params(f_params='model.pkl',
...                   f_optimizer='optimizer.pkl',
...                   f_history='history.json')
set_params(**kwargs)[source]

Set the parameters of this class.

Valid parameter keys can be listed with get_params().

Returns
self

torch_compile(module, name)[source]

Compile torch modules

If compile=True was set, compile all torch modules of the net. Those typically are module_ and criterion_, but custom modules are also included if defined.

Parameters
module : torch.nn.Module

The torch module to be compiled.

name : str

The name of the module. This argument is not used but provided for convenience. You could use it, e.g., to skip compilation for specific modules.

Returns
module : torch.nn.Module or torch._dynamo.OptimizedModule

The compiled module if compile=True, otherwise the uncompiled module.

Raises
ValueError

If compile=True but torch.compile is not available, raise an error.

Notes

Make sure that the installed PyTorch version supports compiling (v1.14, v2.0 and higher).
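
Since name is passed in, the method can be overridden to skip compilation for selected modules. A hypothetical sketch (the name 'criterion' is an assumption about how the criterion is registered):

class MyNet(NeuralNet):
    def torch_compile(self, module, name):
        if name == 'criterion':
            # leave the criterion uncompiled, compile everything else
            return module
        return super().torch_compile(module, name)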

train_step(batch, **fit_params)[source]

Prepares a loss function callable and passes it to the optimizer, hence performing one optimization step.

Loss function callable as required by some optimizers (and accepted by all of them): https://pytorch.org/docs/master/optim.html#optimizer-step-closure

The module is set to be in train mode (e.g. dropout is applied).

Parameters
batch

A single batch returned by the data loader.

**fit_params : dict

Additional parameters passed to the forward method of the module and to the train_split call.

Returns
step : dict

A dictionary {'loss': loss, 'y_pred': y_pred}, where the float loss is the result of the loss function and y_pred the prediction generated by the PyTorch module.

train_step_single(batch, **fit_params)[source]

Compute y_pred, loss value, and update net’s gradients.

The module is set to be in train mode (e.g. dropout is applied).

Parameters
batch

A single batch returned by the data loader.

**fit_params : dict

Additional parameters passed to the forward method of the module and to the self.train_split call.

Returns
step : dict

A dictionary {'loss': loss, 'y_pred': y_pred}, where the float loss is the result of the loss function and y_pred the prediction generated by the PyTorch module.

trim_for_prediction()[source]

Remove all attributes not required for prediction.

Use this method after you finished training your net, with the goal of reducing its size. All attributes only required during training (e.g. the optimizer) are set to None. This can lead to a considerable decrease in memory footprint. It also makes it more likely that the net can be loaded with different library versions.

After calling this function, the net can only be used for prediction (e.g. net.predict or net.predict_proba) but no longer for training (e.g. net.fit(X, y) will raise an exception).

This operation is irreversible. Once the net has been trimmed for prediction, it is no longer possible to restore the original state. Moreover, this operation mutates the net. If you need the unmodified net, create a deepcopy before trimming:

from copy import deepcopy
net = NeuralNet(...)
net.fit(X, y)
# training finished
net_original = deepcopy(net)
net.trim_for_prediction()
net.predict(X)
validation_step(batch, **fit_params)[source]

Perform a forward step using batched data and return the resulting loss.

The module is set to be in evaluation mode (e.g. dropout is not applied).

Parameters
batch

A single batch returned by the data loader.

**fit_params : dict

Additional parameters passed to the forward method of the module and to the self.train_split call.