skorch.net
Neural net base class
This is the most flexible class, not making assumptions on the kind of task being performed. Subclass this to create more specialized and sklearn-conforming classes like NeuralNetClassifier.
- class skorch.net.NeuralNet(module, criterion, optimizer=<class 'torch.optim.sgd.SGD'>, lr=0.01, max_epochs=10, batch_size=128, iterator_train=<class 'torch.utils.data.dataloader.DataLoader'>, iterator_valid=<class 'torch.utils.data.dataloader.DataLoader'>, dataset=<class 'skorch.dataset.Dataset'>, train_split=<skorch.dataset.ValidSplit object>, callbacks=None, predict_nonlinearity='auto', warm_start=False, verbose=1, device='cpu', compile=False, use_caching='auto', torch_load_kwargs=None, **kwargs)
NeuralNet base class.
The base class covers more generic cases. Depending on your use case, you might want to use
NeuralNetClassifier or NeuralNetRegressor.
In addition to the parameters listed below, there are parameters with specific prefixes that are handled separately. To illustrate this, here is an example:
    >>> net = NeuralNet(
    ...     ...,
    ...     optimizer=torch.optim.SGD,
    ...     optimizer__momentum=0.95,
    ... )
This way, when optimizer is initialized, NeuralNet will take care of setting the momentum parameter to 0.95.
(Note that the double underscore notation in optimizer__momentum means that the parameter momentum should be set on the object optimizer. This is the same semantics as used by sklearn.)
Furthermore, this allows you to change those parameters later:
    net.set_params(optimizer__momentum=0.99)
This can be useful when you want to change certain parameters using a callback, when using the net in an sklearn grid search, etc.
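The prefix mechanism can be pictured without torch or skorch. The following is a minimal sketch, not skorch's implementation: the hypothetical helper `group_by_prefix` splits each key at its first double underscore and gathers the values per prefix, so that `optimizer__momentum=0.95` ends up as `momentum=0.95` under the `optimizer` prefix, while nested keys keep their remainder for the next level.

```python
from collections import defaultdict


def group_by_prefix(params):
    """Group ``prefix__name`` keys by prefix (hypothetical helper).

    Keys without a double underscore are top-level parameters and are
    collected under the empty string.
    """
    grouped = defaultdict(dict)
    for key, value in params.items():
        # split only at the first "__" so nested names such as
        # "callbacks__print_log__keys_ignored" keep their remainder
        prefix, sep, rest = key.partition("__")
        if sep:
            grouped[prefix][rest] = value
        else:
            grouped[""][key] = value
    return dict(grouped)


params = {
    "lr": 0.01,
    "optimizer__momentum": 0.95,
    "callbacks__print_log__keys_ignored": ["epoch", "train_loss"],
}
grouped = group_by_prefix(params)
print(grouped["optimizer"])  # {'momentum': 0.95}
print(grouped["callbacks"])  # {'print_log__keys_ignored': ['epoch', 'train_loss']}
```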
By default, an EpochTimer, BatchScoring (for both training and validation datasets), and PrintLog callbacks are added for convenience.
- Parameters
- module : torch module (class or instance)
A PyTorch Module. In general, the uninstantiated class should be passed, although instantiated modules will also work.
- criterion : torch criterion (class)
The uninitialized criterion (loss) used to optimize the module.
- optimizer : torch optim (class, default=torch.optim.SGD)
The uninitialized optimizer (update rule) used to optimize the module.
- lr : float (default=0.01)
Learning rate passed to the optimizer. You may use lr instead of optimizer__lr; both have the same effect.
- max_epochs : int (default=10)
The number of epochs to train for each fit call. Note that you may keyboard-interrupt training at any time.
- batch_size : int (default=128)
Mini-batch size. Use this instead of setting iterator_train__batch_size and iterator_valid__batch_size, which would have the same effect. If batch_size is -1, a single batch containing all the data will be used during training and validation.
- iterator_train : torch DataLoader
The default PyTorch DataLoader used for training data.
- iterator_valid : torch DataLoader
The default PyTorch DataLoader used for validation and test data, i.e. during inference.
- dataset : torch Dataset (default=skorch.dataset.Dataset)
The dataset is necessary for the incoming data to work with PyTorch's DataLoader. It has to implement the __len__ and __getitem__ methods. The provided dataset should be capable of dealing with a lot of data types out of the box, so only change this if your data is not supported. You should generally pass the uninitialized Dataset class and define additional arguments to X and y by prefixing them with dataset__. It is also possible to pass an initialized Dataset, in which case no additional arguments may be passed.
- train_split : None or callable (default=skorch.dataset.ValidSplit(5))
If None, there is no train/validation split. Otherwise, train_split should be a function or callable that is called with X and y data and returns the tuple dataset_train, dataset_valid. The validation data may be None.
- callbacks : None, "disable", or list of Callback instances (default=None)
Which callbacks to enable. There are three possible values:
If callbacks=None, only use default callbacks, those returned by get_default_callbacks.
If callbacks="disable", disable all callbacks, i.e. do not run any of the callbacks, not even the default callbacks.
If callbacks is a list of callbacks, use those callbacks in addition to the default callbacks. Each callback should be an instance of Callback.
Callback names are inferred from the class name. Name conflicts are resolved by appending a count suffix starting with 1, e.g. EpochScoring_1. Alternatively, a tuple (name, callback) can be passed, where name should be unique. Callbacks may or may not be instantiated. The callback name can be used to set parameters on specific callbacks (e.g., for the callback with name 'print_log', use net.set_params(callbacks__print_log__keys_ignored=['epoch', 'train_loss'])).
- predict_nonlinearity : callable, None, or 'auto' (default='auto')
The nonlinearity to be applied to the prediction. When set to 'auto', infers the correct nonlinearity based on the criterion (softmax for CrossEntropyLoss and sigmoid for BCEWithLogitsLoss). If it cannot be inferred or if the parameter is None, just use the identity function. Don't pass a lambda function if you want the net to be pickleable.
In case a callable is passed, it should accept the output of the module (the first output if there is more than one), which is a PyTorch tensor, and return the transformed PyTorch tensor.
This can be useful, e.g., when predict_proba() should return probabilities but a criterion is used that does not expect probabilities. In that case, the module can return whatever is required by the criterion and the predict_nonlinearity transforms this output into probabilities.
The nonlinearity is applied only when calling predict() or predict_proba() but not anywhere else; notably, the loss is unaffected by this nonlinearity.
- warm_start : bool (default=False)
Whether each fit call should lead to a re-initialization of the module (cold start) or whether the module should be trained further (warm start).
- verbose : int (default=1)
This parameter controls how much print output is generated by the net and its callbacks. For example, setting this value to 0 means that the summary scores at the end of each epoch are no longer printed. This can be useful when running a hyperparameter search. The summary scores are always logged in the history attribute, regardless of the verbose setting.
- device : str, torch.device, or None (default='cpu')
The compute device to be used. If set to ‘cuda’ in order to use GPU acceleration, data in torch tensors will be pushed to cuda tensors before being sent to the module. If set to None, then all compute devices will be left unmodified.
- compile : bool (default=False)
If set to True, compile all modules using torch.compile. For this to work, the installed torch version has to support torch.compile. Compiled modules should work identically to non-compiled modules but should run faster on new GPU architectures (Volta and Ampere for instance). Additional arguments for torch.compile can be passed using the dunder notation, e.g. when initializing the net with compile__dynamic=True, torch.compile will be called with dynamic=True.
- use_caching : bool or 'auto' (default='auto')
Optionally override the caching behavior of scoring callbacks. Callbacks such as EpochScoring and BatchScoring allow you to cache the inference call to save time when calculating scores during training, at the expense of memory. In certain situations, e.g. when memory is tight, you may want to disable caching. As it is cumbersome to change the setting on each callback individually, this parameter allows you to override their behavior globally. By default ('auto'), the callbacks determine whether caching is used. If this argument is set to False, caching will be disabled on all callbacks. If set to True, caching will be enabled on all callbacks. Implementation note: it is the job of the callbacks to honor this setting.
- torch_load_kwargs : dict or None (default=None)
Additional arguments that will be passed to torch.load when loading pickled parameters.
In particular, this is important because PyTorch will switch (probably in version 2.6.0) to only allow weights to be loaded for security reasons (i.e. weights_only switches from False to True). As a consequence, loading pickled parameters may raise an error after upgrading torch because some types are used that are considered insecure. In skorch, we will also make that switch at the same time. To resolve the error, follow the instructions in the torch error message to designate the offending types as secure. Only do this if you trust the source of the file.
If you want to keep loading non-weight types the same way as before, please pass:
    torch_load_kwargs={'weights_only': False}
You should be aware that this is considered insecure and should only be used if you trust the source of the file. However, this does not introduce new insecurities; rather, it corresponds to the status quo from before torch made the switch.
Another way to avoid this issue is to pass use_safetensors=True when calling save_params and load_params. This avoids using pickle in favor of the safetensors format, which is secure by design.
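For the default predict_nonlinearity='auto' with CrossEntropyLoss, the module returns unnormalized scores (logits), and a softmax turns each row into class probabilities at prediction time while the loss keeps working on the raw logits. The pure-Python sketch below only illustrates that transformation on nested lists with made-up numbers; the helper `softmax_rows` is hypothetical, and skorch itself works on torch tensors instead.

```python
import math


def softmax_rows(logits):
    """Apply a numerically stable softmax to each row of a 2D list."""
    probas = []
    for row in logits:
        # subtracting the row maximum does not change the result but
        # avoids overflow in math.exp for large logits
        m = max(row)
        exps = [math.exp(x - m) for x in row]
        total = sum(exps)
        probas.append([e / total for e in exps])
    return probas


# raw module output for two samples and three classes (made-up numbers)
logits = [[2.0, 1.0, 0.1], [0.0, 0.0, 0.0]]
probas = softmax_rows(logits)
print([round(p, 3) for p in probas[1]])  # [0.333, 0.333, 0.333]
```

Each row of the result sums to 1, which is what predict_proba() is expected to return for a classifier, whereas the criterion never sees these probabilities.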
- Attributes
- prefixes_ : list of str
Contains the prefixes to special parameters. E.g., since there is the 'optimizer' prefix, it is possible to set parameters like so: NeuralNet(..., optimizer__momentum=0.95). Some prefixes are populated dynamically, based on what modules and criteria are defined.
- cuda_dependent_attributes_ : list of str
Contains a list of all attribute prefixes whose values depend on a CUDA device. If a NeuralNet trained with a CUDA-enabled device is unpickled on a machine without CUDA or with CUDA disabled, the listed attributes are mapped to CPU. Expand this list if you want to add other CUDA-dependent attributes.
- initialized_ : bool
Whether the NeuralNet was initialized.
- module_ : torch module (instance)
The instantiated module.
- criterion_ : torch criterion (instance)
The instantiated criterion.
- callbacks_ : list of tuples
The complete (i.e. default and other) initialized callbacks, each in a tuple with its unique name.
- _modules : list of str
List of names of all modules that are torch modules. This list is collected dynamically when the net is initialized. Typically, there is no reason for a user to modify this list.
- _criteria : list of str
List of names of all criteria that are torch modules. This list is collected dynamically when the net is initialized. Typically, there is no reason for a user to modify this list.
- _optimizers : list of str
List of names of all optimizers. This list is collected dynamically when the net is initialized. Typically, there is no reason for a user to modify this list.
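The naming rule for callbacks_ described under the callbacks parameter can be sketched in plain Python. This is a simplified, hypothetical helper (`unique_callback_names`) with stand-in classes, not skorch's code: it infers names from the class name and, assuming every member of a conflicting group is renamed, appends a count suffix starting with 1. It ignores user-supplied (name, callback) tuples, for which non-unique names raise a ValueError according to initialize_callbacks().

```python
from collections import Counter


class EpochScoring:  # stand-in for a callback class, not skorch's
    pass


class PrintLog:  # stand-in for a callback class, not skorch's
    pass


def unique_callback_names(callbacks):
    """Return (name, callback) tuples with conflicting names suffixed.

    Names are inferred from the class name; when several callbacks share
    a name, each gets a suffix _1, _2, ... in order of appearance.
    """
    names = [type(cb).__name__ for cb in callbacks]
    counts = Counter(names)
    seen = Counter()
    named = []
    for name, cb in zip(names, callbacks):
        if counts[name] > 1:
            seen[name] += 1
            name = f"{name}_{seen[name]}"
        named.append((name, cb))
    return named


cbs = [EpochScoring(), PrintLog(), EpochScoring()]
print([name for name, _ in unique_callback_names(cbs)])
# ['EpochScoring_1', 'PrintLog', 'EpochScoring_2']
```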
Methods
check_is_fitted([attributes]): Checks whether the net is initialized.
check_training_readiness(): Check that the net is ready to train.
evaluation_step(batch[, training]): Perform a forward step to produce the output used for prediction and scoring.
fit(X[, y]): Initialize and fit the module.
fit_loop(X[, y, epochs, _routing_method]): The proper fit loop.
forward(X[, training, device]): Gather and concatenate the output from forward call with input data.
forward_iter(X[, training, device]): Yield outputs of module forward calls on each batch of data.
get_all_learnable_params(): Yield the learnable parameters of all modules.
get_dataset(X[, y]): Get a dataset that contains the input data and is passed to the iterator.
get_iterator(dataset[, training]): Get an iterator that allows looping over the batches of the given data.
get_loss(y_pred, y_true[, X, training]): Return the loss for this batch.
get_metadata_routing(): Get metadata routing of this object.
get_params([deep]): Get parameters for this estimator.
get_params_for(prefix): Collect and return init parameters for an attribute.
get_params_for_optimizer(prefix, ...): Collect and return init parameters for an optimizer.
get_split_datasets(X[, y]): Get internal train and validation datasets.
get_train_step_accumulator(): Return the train step accumulator.
infer(x, **fit_params): Perform a single inference step on a batch of data.
initialize(): Initializes all of its components and returns self.
initialize_callbacks(): Initializes all callbacks and saves the result in the callbacks_ attribute.
initialize_criterion(): Initializes the criterion.
initialize_history(): Initializes the history.
initialize_module(): Initializes the module.
initialize_optimizer([triggered_directly]): Initialize the model optimizer.
initialized_instance(instance_or_cls, kwargs): Return an instance initialized with the given parameters.
load_params([f_params, f_optimizer, ...]): Loads the module's parameters, history, and optimizer, not the whole object.
notify(method_name, **cb_kwargs): Call the callback method specified in method_name with parameters specified in cb_kwargs.
on_batch_begin(net[, batch, training])
on_epoch_begin(net[, dataset_train, ...])
on_epoch_end(net[, dataset_train, dataset_valid])
on_train_begin(net[, X, y])
on_train_end(net[, X, y])
partial_fit(X[, y, classes]): Fit the module.
predict(X): Where applicable, return class labels for samples in X.
predict_proba(X): Return the output of the module's forward method as a numpy array.
run_single_epoch(iterator, training, prefix, ...): Compute a single epoch of train or validation.
save_params([f_params, f_optimizer, ...]): Saves the module's parameters, history, and optimizer, not the whole object.
set_fit_request(**kwargs): Set requested parameters by the fit method.
set_params(**kwargs): Set the parameters of this class.
set_partial_fit_request(**kwargs): Set requested parameters by the partial_fit method.
torch_compile(module, name): Compile torch modules.
train_step(batch, **fit_params): Prepares a loss function callable and passes it to the optimizer, hence performing one optimization step.
train_step_single(batch, **fit_params): Compute y_pred, loss value, and update net's gradients.
trim_for_prediction(): Remove all attributes not required for prediction.
validation_step(batch, **fit_params): Perform a forward step using batched data and return the resulting loss.
check_data
get_default_callbacks
initialize_virtual_params
on_batch_end
on_grad_computed
- check_is_fitted(attributes=None, *args, **kwargs)
Checks whether the net is initialized
- Parameters
- attributes : iterable of str or None (default=None)
All the attributes that are strictly required of a fitted net. By default, this is the module_ attribute.
- Other arguments are as in sklearn.utils.validation.check_is_fitted.
- Raises
- skorch.exceptions.NotInitializedError
When the given attributes are not present.
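At its core, such a check tests that the required attributes exist on the net. Below is a simplified, framework-free sketch: `NotInitializedError` is a local stand-in for skorch.exceptions.NotInitializedError, and `check_initialized` and `DummyNet` are hypothetical names. Following the description above, it requires module_ by default.

```python
class NotInitializedError(Exception):
    """Local stand-in for skorch.exceptions.NotInitializedError."""


def check_initialized(net, attributes=None):
    """Raise NotInitializedError if a required attribute is missing."""
    if attributes is None:
        attributes = ["module_"]  # the documented default
    missing = [attr for attr in attributes if not hasattr(net, attr)]
    if missing:
        raise NotInitializedError(
            f"{type(net).__name__} is not initialized; missing {missing}"
        )


class DummyNet:
    pass


net = DummyNet()
try:
    check_initialized(net)
except NotInitializedError as exc:
    print(exc)  # DummyNet is not initialized; missing ['module_']

net.module_ = object()  # pretend initialize() created the module
check_initialized(net)  # passes silently now
```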
- evaluation_step(batch, training=False)
Perform a forward step to produce the output used for prediction and scoring.
For this reason, the module is set to evaluation mode beforehand by default; this can be overridden by setting training=True, e.g. to re-enable features like dropout.
- Parameters
- batch
A single batch returned by the data loader.
- training : bool (default=False)
Whether to set the module to train mode or not.
- Returns
- y_infer
The prediction generated by the module.
- fit(X, y=None, **fit_params)
Initialize and fit the module.
If the module was already initialized, calling fit will re-initialize it (unless warm_start is True).
- Parameters
- X : input data, compatible with skorch.dataset.Dataset
By default, you should be able to pass:
numpy arrays
torch tensors
pandas DataFrame or Series
scipy sparse CSR matrices
a dictionary of the former three
a list/tuple of the former three
a Dataset
If this doesn't work with your data, you have to pass a Dataset that can deal with the data.
- y : target data, compatible with skorch.dataset.Dataset
The same data types as for X are supported. If your X is a Dataset that contains the target, y may be set to None.
- **fit_params : dict
Additional parameters passed to the forward method of the module and to the self.train_split call.
- fit_loop(X, y=None, epochs=None, *, _routing_method='fit', **fit_params)
The proper fit loop.
Contains the logic of what actually happens during the fit loop.
- Parameters
- X : input data, compatible with skorch.dataset.Dataset
By default, you should be able to pass:
numpy arrays
torch tensors
pandas DataFrame or Series
scipy sparse CSR matrices
a dictionary of the former three
a list/tuple of the former three
a Dataset
If this doesn't work with your data, you have to pass a Dataset that can deal with the data.
- y : target data, compatible with skorch.dataset.Dataset
The same data types as for X are supported. If your X is a Dataset that contains the target, y may be set to None.
- epochs : int or None (default=None)
If int, train for this number of epochs; if None, use self.max_epochs.
- **fit_params : dict
Additional parameters passed to the forward method of the module and to the self.train_split call.
- forward(X, training=False, device='cpu')
Gather and concatenate the output from forward call with input data.
The outputs from self.module_.forward are gathered on the compute device specified by device and then concatenated using PyTorch cat(). If multiple outputs are returned by self.module_.forward, each one of them must be able to be concatenated this way.
- Parameters
- X : input data, compatible with skorch.dataset.Dataset
By default, you should be able to pass:
numpy arrays
torch tensors
pandas DataFrame or Series
scipy sparse CSR matrices
a dictionary of the former three
a list/tuple of the former three
a Dataset
If this doesn't work with your data, you have to pass a Dataset that can deal with the data.
- training : bool (default=False)
Whether to set the module to train mode or not.
- device : string (default='cpu')
The device to store each inference result on. This defaults to CPU memory since there is generally more memory available there. For performance reasons this might be changed to a specific CUDA device, e.g. 'cuda:0'.
- Returns
- y_infer : torch tensor
The result from the forward step.
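The gather-and-concatenate step can be pictured with plain lists standing in for tensors. In this hypothetical sketch, each item of `batch_outputs` is the module output for one batch; when the module returns a tuple, the i-th outputs of all batches are concatenated separately, which is why each output must support concatenation. skorch uses PyTorch's cat() on tensors; lists only keep the example dependency-free.

```python
def concat_outputs(batch_outputs):
    """Concatenate per-batch outputs along the first (sample) axis.

    Each item is either a list (one output) or a tuple of lists
    (several outputs); tuples are concatenated element-wise.
    """
    first = batch_outputs[0]
    if isinstance(first, tuple):
        # concatenate the i-th output of every batch separately
        return tuple(
            concat_outputs([out[i] for out in batch_outputs])
            for i in range(len(first))
        )
    result = []
    for out in batch_outputs:
        result.extend(out)
    return result


# two batches; the module returns (predictions, auxiliary output)
batches = [([0.1, 0.2], [1, 1]), ([0.3], [0])]
preds, aux = concat_outputs(batches)
print(preds)  # [0.1, 0.2, 0.3]
print(aux)    # [1, 1, 0]
```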
- forward_iter(X, training=False, device='cpu')
Yield outputs of module forward calls on each batch of data. The storage device of the yielded tensors is determined by the device parameter.
- Parameters
- X : input data, compatible with skorch.dataset.Dataset
By default, you should be able to pass:
numpy arrays
torch tensors
pandas DataFrame or Series
scipy sparse CSR matrices
a dictionary of the former three
a list/tuple of the former three
a Dataset
If this doesn't work with your data, you have to pass a Dataset that can deal with the data.
- training : bool (default=False)
Whether to set the module to train mode or not.
- device : string (default='cpu')
The device to store each inference result on. This defaults to CPU memory since there is generally more memory available there. For performance reasons this might be changed to a specific CUDA device, e.g. 'cuda:0'.
- Yields
- yp : torch tensor
Result from a forward call on an individual batch.
- get_all_learnable_params()
Yield the learnable parameters of all modules.
Typically, this will yield the named_parameters of the standard module of the net. However, if you add custom modules or if your criterion has learnable parameters, these are returned as well.
If you want your optimizer to only update the parameters of some but not all modules, you should override initialize_optimizer() and match the corresponding modules and optimizers there:
    import torch
    from skorch import NeuralNet

    class MyNet(NeuralNet):
        def initialize_optimizer(self, *args, **kwargs):
            # first initialize the normal optimizer
            named_params = self.module_.named_parameters()
            args, kwargs = self.get_params_for_optimizer('optimizer', named_params)
            self.optimizer_ = self.optimizer(*args, **kwargs)

            # next add another optimizer called 'optimizer2_' that is
            # only responsible for training 'module2_'
            named_params = self.module2_.named_parameters()
            args, kwargs = self.get_params_for_optimizer('optimizer2', named_params)
            self.optimizer2_ = torch.optim.SGD(*args, **kwargs)
            return self
- Yields
- named_parameters : generator of parameter name and parameter
A generator over all module parameters, yielding both the name of the parameter as well as the parameter itself. Use this, for instance, to pass the named parameters to get_params_for_optimizer().
- get_dataset(X, y=None)
Get a dataset that contains the input data and is passed to the iterator.
Override this if you want to initialize your dataset differently.
- Parameters
- X : input data, compatible with skorch.dataset.Dataset
By default, you should be able to pass:
numpy arrays
torch tensors
pandas DataFrame or Series
scipy sparse CSR matrices
a dictionary of the former three
a list/tuple of the former three
a Dataset
If this doesn't work with your data, you have to pass a Dataset that can deal with the data.
- y : target data, compatible with skorch.dataset.Dataset
The same data types as for X are supported. If your X is a Dataset that contains the target, y may be set to None.
- Returns
- dataset
The initialized dataset.
- get_iterator(dataset, training=False)
Get an iterator that allows looping over the batches of the given data.
If self.iterator_train__batch_size and/or self.iterator_valid__batch_size are not set, use self.batch_size instead.
- Parameters
- dataset : torch Dataset (default=skorch.dataset.Dataset)
Usually, self.dataset, initialized with the corresponding data, is passed to get_iterator.
- training : bool (default=False)
Whether to use iterator_train or iterator_valid.
- Returns
- iterator
An instantiated iterator that allows to loop over the mini-batches.
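The batch size rule can be written as a small function. This sketch restates the documented behavior and is not skorch's code: `resolve_batch_size` is hypothetical, and `iterator_kwargs` stands for the parameters given with the iterator_train__ or iterator_valid__ prefix. It also covers batch_size=-1, which the batch_size parameter describes as a single batch with all the data.

```python
def resolve_batch_size(iterator_kwargs, batch_size, n_samples):
    """Return iterator kwargs with a concrete batch size.

    A prefixed value such as iterator_train__batch_size wins; otherwise
    the net-wide batch_size is used. -1 means one batch with all samples.
    """
    kwargs = dict(iterator_kwargs)  # do not mutate the caller's dict
    kwargs.setdefault("batch_size", batch_size)
    if kwargs["batch_size"] == -1:
        kwargs["batch_size"] = n_samples
    return kwargs


print(resolve_batch_size({}, 128, 1000))                  # {'batch_size': 128}
print(resolve_batch_size({"batch_size": 32}, 128, 1000))  # {'batch_size': 32}
print(resolve_batch_size({"shuffle": True}, -1, 1000))    # {'shuffle': True, 'batch_size': 1000}
```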
- get_loss(y_pred, y_true, X=None, training=False)
Return the loss for this batch.
- Parameters
- y_pred : torch tensor
Predicted target values.
- y_true : torch tensor
True target values.
- X : input data, compatible with skorch.dataset.Dataset
By default, you should be able to pass:
numpy arrays
torch tensors
pandas DataFrame or Series
scipy sparse CSR matrices
a dictionary of the former three
a list/tuple of the former three
a Dataset
If this doesn't work with your data, you have to pass a Dataset that can deal with the data.
- training : bool (default=False)
Whether train mode should be used or not.
- get_metadata_routing()
Get metadata routing of this object.
NeuralNet is both a consumer (its module's forward method accepts arbitrary metadata) and a router (it routes metadata like groups to its internal CV splitter).
- Returns
- routing : MetadataRouter
A MetadataRouter encapsulating routing information.
- get_params(deep=True, **kwargs)
Get parameters for this estimator.
- Parameters
- deep : bool, default=True
If True, will return the parameters for this estimator and contained subobjects that are estimators.
- Returns
- params : dict
Parameter names mapped to their values.
- get_params_for(prefix)
Collect and return init parameters for an attribute.
Attributes could be, for instance, PyTorch modules, criteria, or data loaders (for optimizers, use get_params_for_optimizer() instead). Use the returned arguments to initialize the given attribute like this:
    # inside initialize_module method
    kwargs = self.get_params_for('module')
    self.module_ = self.module(**kwargs)
Proceed analogously for the criterion etc.
The reason to use this method is so that it's possible to change the init parameters with set_params(), which in turn makes grid search and other similar things work.
Note that in general, as a user, you never have to deal with this method because initialize_module() etc. are already taking care of this. You only need to deal with this if you override initialize_module() (or similar methods) because you have some custom code that requires it.
- Parameters
- prefix : str
The name of the attribute whose arguments should be returned. E.g. for the module, it should be 'module'.
- Returns
- kwargs : dict
Keyword arguments to be used as init parameters.
- get_params_for_optimizer(prefix, named_parameters)
Collect and return init parameters for an optimizer.
Parse kwargs configuration for the optimizer identified by the given prefix. Supports param group assignment using wildcards:
    optimizer__lr=0.05,
    optimizer__param_groups=[
        ('rnn*.period', {'lr': 0.3, 'momentum': 0}),
        ('rnn0', {'lr': 0.1}),
    ]
Generally, use this method like this:
    # inside initialize_optimizer method
    named_params = self.module_.named_parameters()
    pgroups, kwargs = self.get_params_for_optimizer('optimizer', named_params)
    if 'lr' not in kwargs:
        kwargs['lr'] = self.lr
    self.optimizer_ = self.optimizer(*pgroups, **kwargs)
The reason to use this method is so that it's possible to change the init parameters with set_params(), which in turn makes grid search and other similar things work.
Note that in general, as a user, you never have to deal with this method because initialize_optimizer() is already taking care of this. You only need to deal with this if you override initialize_optimizer() because you have some custom code that requires it.
- Parameters
- prefix : str
The name of the optimizer whose arguments should be returned. Typically, this should just be 'optimizer'. There can be exceptions, however, e.g. if you want to use more than one optimizer.
- named_parameters : iterator
Iterator over the parameters of the module that is intended to be optimized. It's the return value of my_module.named_parameters().
- Returns
- args : tuple
All positional arguments for this optimizer (right now only one, the parameter groups).
- kwargs : dict
All other parameters for this optimizer, e.g. the learning rate.
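The wildcard matching behind param_groups can be illustrated with Python's fnmatch module and made-up parameter names. In this sketch, each (pattern, options) entry claims the not-yet-assigned parameters whose names match, and the leftovers form a final group that uses the optimizer's default settings. That first-match-wins rule and the helper name `build_param_groups` are assumptions for the illustration, and strings stand in for parameter tensors.

```python
from fnmatch import fnmatch


def build_param_groups(named_parameters, param_groups):
    """Assign parameters to optimizer groups via wildcard patterns."""
    remaining = list(named_parameters)
    groups = []
    for pattern, options in param_groups:
        matched = [(n, p) for n, p in remaining if fnmatch(n, pattern)]
        if matched:
            groups.append({"params": [p for _, p in matched], **options})
            # a parameter can only belong to one group
            remaining = [(n, p) for n, p in remaining if not fnmatch(n, pattern)]
    if remaining:
        # unmatched parameters fall back to the optimizer defaults
        groups.append({"params": [p for _, p in remaining]})
    return groups


named = [
    ("rnn0.weight", "W0"),
    ("rnn0.period", "P0"),
    ("rnn1.period", "P1"),
    ("out.weight", "W_out"),
]
groups = build_param_groups(
    named,
    [("rnn*.period", {"lr": 0.3, "momentum": 0}), ("rnn0.*", {"lr": 0.1})],
)
for g in groups:
    print(g)
# {'params': ['P0', 'P1'], 'lr': 0.3, 'momentum': 0}
# {'params': ['W0'], 'lr': 0.1}
# {'params': ['W_out']}
```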
- get_split_datasets(X, y=None, **fit_params)
Get internal train and validation datasets.
The validation dataset can be None if self.train_split is set to None; then internal validation will be skipped.
Override this if you want to change how the net splits incoming data into train and validation parts.
- Parameters
- X : input data, compatible with skorch.dataset.Dataset
By default, you should be able to pass:
numpy arrays
torch tensors
pandas DataFrame or Series
scipy sparse CSR matrices
a dictionary of the former three
a list/tuple of the former three
a Dataset
If this doesn't work with your data, you have to pass a Dataset that can deal with the data.
- y : target data, compatible with skorch.dataset.Dataset
The same data types as for X are supported. If your X is a Dataset that contains the target, y may be set to None.
- **fit_params : dict
Additional parameters passed to the self.train_split call.
- Returns
- dataset_train
The initialized training dataset.
- dataset_valid
The initialized validation dataset or None.
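The contract of train_split and get_split_datasets can be shown with plain lists. Both helpers are hypothetical: `holdout_split` keeps the last len(X) // cv samples for validation, loosely in the spirit of the default ValidSplit(5) but without any shuffling or stratification, and `split_datasets` returns (train, None) when no split is configured, so that validation is skipped.

```python
def holdout_split(X, y, cv=5):
    """Hold out the last len(X) // cv samples for validation (no shuffling)."""
    n_valid = len(X) // cv
    n_train = len(X) - n_valid
    train = (X[:n_train], y[:n_train])
    valid = (X[n_train:], y[n_train:])
    return train, valid


def split_datasets(X, y, train_split=None):
    """Return (dataset_train, dataset_valid); dataset_valid may be None."""
    if train_split is None:
        # no split configured: train on everything, skip validation
        return (X, y), None
    return train_split(X, y)


X = list(range(10))
y = [i % 2 for i in range(10)]
train, valid = split_datasets(X, y, train_split=holdout_split)
print(len(train[0]), len(valid[0]))  # 8 2
print(split_datasets(X, y)[1])       # None
```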
- get_train_step_accumulator()
Return the train step accumulator.
By default, the accumulator stores and retrieves the first value from the optimizer call. Most optimizers make only one call, so the first value is at the same time the only value.
In case of some optimizers, e.g. LBFGS, train_step_single is called multiple times, as the loss function is evaluated multiple times per optimizer call. If you don't want to return the first value in that case, override this method to return your custom accumulator.
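The accumulator idea can be sketched without torch. Here `optimizer_step` stands in for an optimizer such as LBFGS that evaluates its closure several times per step; `FirstValueAccumulator` mirrors the default behavior described above, and `LastValueAccumulator` is the kind of custom accumulator an overridden get_train_step_accumulator() might return. All class and method names (including store_step and get_step) are assumptions of this sketch.

```python
class FirstValueAccumulator:
    """Keep only the first value stored (mirrors the default behavior)."""

    def __init__(self):
        self.step = None

    def store_step(self, step):
        if self.step is None:
            self.step = step

    def get_step(self):
        return self.step


class LastValueAccumulator(FirstValueAccumulator):
    """Keep the most recent value instead."""

    def store_step(self, step):
        self.step = step


def optimizer_step(closure, n_evals=3):
    """Stand-in for an optimizer that calls the closure several times."""
    for _ in range(n_evals):
        closure()


def train_step(accumulator, losses):
    """Run one 'optimizer step' and return the accumulated value."""
    it = iter(losses)

    def closure():
        loss = next(it)  # pretend this is the loss of one evaluation
        accumulator.store_step({"loss": loss})
        return loss

    optimizer_step(closure, n_evals=len(losses))
    return accumulator.get_step()


losses = [0.9, 0.7, 0.6]
print(train_step(FirstValueAccumulator(), losses))  # {'loss': 0.9}
print(train_step(LastValueAccumulator(), losses))   # {'loss': 0.6}
```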
- infer(x, **fit_params)
Perform a single inference step on a batch of data.
- Parameters
- x : input data
A batch of the input data.
- **fit_params : dict
Additional parameters passed to the forward method of the module and to the self.train_split call.
- initialize_callbacks()
Initializes all callbacks and saves the result in the callbacks_ attribute.
Both default_callbacks and callbacks are used (in that order). Callbacks may either be initialized or not, and if they don't have a name, the name is inferred from the class name. The initialize method is called on all callbacks.
The final result will be a list of tuples, where each tuple consists of a name and an initialized callback. If names are not unique, a ValueError is raised.
- initialize_criterion()
Initializes the criterion.
If the criterion is already initialized and no parameter was changed, it will be left as is.
- initialize_module()
Initializes the module.
If the module is already initialized and no parameter was changed, it will be left as is.
- initialize_optimizer(triggered_directly=None)
Initialize the model optimizer. If self.optimizer__lr is not set, use self.lr instead.
- Parameters
- triggered_directly
Deprecated, don’t use it anymore.
- initialized_instance(instance_or_cls, kwargs)
Return an instance initialized with the given parameters
This is a helper method that deals with several possibilities for a component that might need to be initialized:
It is already an instance that’s good to go
It is an instance but it needs to be re-initialized
It’s not an instance and needs to be initialized
For the majority of use cases, this comes down to just initializing the class with its arguments.
- Parameters
- instance_or_cls
The instance or class or callable to be initialized, e.g. self.module.
- kwargs : dict
The keyword arguments to initialize the instance or class. Can be an empty dict.
- Returns
- instance
The initialized component.
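The three cases above fit in a few lines of plain Python. This simplified sketch treats anything that is not a class as an existing instance (skorch's own check may be stricter, e.g. testing for torch modules, and other callables are not handled here); the helper `init_component` and the stand-in class `Linear` are hypothetical.

```python
def init_component(instance_or_cls, kwargs):
    """Return a ready-to-use instance following the three cases above."""
    is_instance = not isinstance(instance_or_cls, type)
    if is_instance and not kwargs:
        # case 1: already an instance and nothing to change
        return instance_or_cls
    if is_instance:
        # case 2: an instance whose parameters changed, so re-create it
        return type(instance_or_cls)(**kwargs)
    # case 3: a class that still needs to be initialized
    return instance_or_cls(**kwargs)


class Linear:  # hypothetical stand-in for a torch module
    def __init__(self, in_features=1, out_features=1):
        self.in_features = in_features
        self.out_features = out_features


existing = Linear(2, 3)
print(init_component(existing, {}) is existing)                  # True
print(init_component(existing, {"in_features": 4}).in_features)  # 4
print(init_component(Linear, {"out_features": 5}).out_features)  # 5
```

Note that case 2 re-creates the object from the given kwargs alone, so constructor arguments that are not passed again fall back to their defaults.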
- load_params(f_params=None, f_optimizer=None, f_criterion=None, f_history=None, checkpoint=None, use_safetensors=False, **kwargs)
Loads the module's parameters, history, and optimizer, not the whole object.
To save and load the whole object, use pickle.
f_params, f_optimizer, etc. use PyTorch's load().
If you've created a custom module, e.g. net.mymodule_, you can load that as well by passing f_mymodule.
- Parameters
- f_params : file-like object, str, None (default=None)
Path of module parameters. Pass None to not load.
- f_optimizer : file-like object, str, None (default=None)
Path of optimizer. Pass None to not load.
- f_criterion : file-like object, str, None (default=None)
Path of criterion. Pass None to not load.
- f_history : file-like object, str, None (default=None)
Path to history. Pass None to not load.
- checkpoint : Checkpoint, None (default=None)
Checkpoint to load params from. If both a checkpoint and an f_* path are passed in, the f_* path will be loaded. Pass None to not load.
- use_safetensors : bool (default=False)
Whether to use the safetensors library to load the state. By default, PyTorch is used, which in turn uses pickle under the hood. If the state was saved with use_safetensors=True when skorch.net.NeuralNet.save_params() was called, this should be set to True here as well.
Examples
    >>> before = NeuralNetClassifier(mymodule)
    >>> before.fit(X, y)  # the net must be initialized or fitted before saving
    >>> before.save_params(f_params='model.pkl',
    ...                    f_optimizer='optimizer.pkl',
    ...                    f_history='history.json')
    >>> after = NeuralNetClassifier(mymodule).initialize()
    >>> after.load_params(f_params='model.pkl',
    ...                   f_optimizer='optimizer.pkl',
    ...                   f_history='history.json')
- notify(method_name, **cb_kwargs)
Call the callback method specified in method_name with parameters specified in cb_kwargs.
Method names can be one of:
* on_train_begin
* on_train_end
* on_epoch_begin
* on_epoch_end
* on_batch_begin
* on_batch_end
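notify dispatches by method name. The sketch below shows that pattern with getattr on plain objects: a hypothetical `MiniNet` calls its own hook first and then the hook of the same name on every (name, callback) pair, passing itself as the first argument. The ordering and all class names are assumptions of this illustration.

```python
class Callback:
    """Minimal base class: every hook is a no-op by default."""

    def on_train_begin(self, net, **kwargs):
        pass

    def on_epoch_end(self, net, **kwargs):
        pass


class EpochCounter(Callback):
    def __init__(self):
        self.epochs_seen = 0

    def on_epoch_end(self, net, **kwargs):
        self.epochs_seen += 1


class MiniNet:
    def __init__(self, callbacks):
        self.callbacks_ = callbacks  # list of (name, callback) tuples

    def on_train_begin(self, net, **kwargs):
        pass  # the net may define its own hooks as well

    def on_epoch_end(self, net, **kwargs):
        pass

    def notify(self, method_name, **cb_kwargs):
        # the net's own hook first, then every callback's hook
        getattr(self, method_name)(self, **cb_kwargs)
        for _, cb in self.callbacks_:
            getattr(cb, method_name)(self, **cb_kwargs)


counter = EpochCounter()
net = MiniNet([("epoch_counter", counter)])
for _ in range(3):
    net.notify("on_epoch_end", dataset_train=None, dataset_valid=None)
print(counter.epochs_seen)  # 3
```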
- partial_fit(X, y=None, classes=None, **fit_params)
Fit the module.
If the module is initialized, it is not re-initialized, which means that this method should be used if you want to continue training a model (warm start).
- Parameters
- X : input data, compatible with skorch.dataset.Dataset
By default, you should be able to pass:
numpy arrays
torch tensors
pandas DataFrame or Series
scipy sparse CSR matrices
a dictionary of the former three
a list/tuple of the former three
a Dataset
If this doesn't work with your data, you have to pass a Dataset that can deal with the data.
- y : target data, compatible with skorch.dataset.Dataset
The same data types as for X are supported. If your X is a Dataset that contains the target, y may be set to None.
- classes : array, shape (n_classes,)
Solely for sklearn compatibility, currently unused.
- **fit_params : dict
Additional parameters passed to the forward method of the module and to the self.train_split call.
- predict(X)
Where applicable, return class labels for samples in X.
If the module's forward method returns multiple outputs as a tuple, it is assumed that the first output contains the relevant information and the other values are ignored. If all values are relevant, consider using forward() instead.
- Parameters
- X : input data, compatible with skorch.dataset.Dataset
By default, you should be able to pass:
numpy arrays
torch tensors
pandas DataFrame or Series
scipy sparse CSR matrices
a dictionary of the former three
a list/tuple of the former three
a Dataset
If this doesn't work with your data, you have to pass a Dataset that can deal with the data.
- Returns
- y_prednumpy ndarray
- predict_proba(X)
Return the output of the module’s forward method as a numpy array.
If the module’s forward method returns multiple outputs as a tuple, it is assumed that the first output contains the relevant information and the other values are ignored. If all values are relevant, consider using forward() instead.
- Parameters
- X : input data, compatible with skorch.dataset.Dataset
By default, you should be able to pass:
numpy arrays
torch tensors
pandas DataFrame or Series
scipy sparse CSR matrices
a dictionary of the former three
a list/tuple of the former three
a Dataset
If this doesn’t work with your data, you have to pass a Dataset that can deal with the data.
- Returns
- y_proba : numpy ndarray
- run_single_epoch(iterator, training, prefix, step_fn, **fit_params)
Compute a single epoch of training or validation.
- Parameters
- iterator : torch DataLoader or None
The initialized DataLoader to loop over. If None, skip this step.
- training : bool
Whether to set the module to train mode or not.
- prefix : str
Prefix to use when saving to the history.
- step_fn : callable
Function to call for each batch.
- **fit_params : dict
Additional parameters passed to the step_fn.
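The loop described above can be sketched without skorch. In this hypothetical stand-alone version, history is a plain dict of lists standing in for the net's history, fake_step stands in for a real step_fn, and the recorded key names (prefix + "_loss" and so on) are assumptions chosen to show what prefix is for.

```python
from typing import Callable, Dict, Iterable, List, Optional


def run_single_epoch(
    iterator: Optional[Iterable],
    training: bool,
    prefix: str,
    step_fn: Callable[..., Dict[str, float]],
    history: Dict[str, List[float]],
    **fit_params,
) -> None:
    """Hypothetical sketch of one pass over the data.

    `training` is not used here; in a real net it would control train/eval mode.
    """
    if iterator is None:  # e.g. no validation data: skip this step
        return
    batch_count = 0
    for batch in iterator:
        step = step_fn(batch, **fit_params)  # fit_params are forwarded to step_fn
        history.setdefault(prefix + "_loss", []).append(step["loss"])
        history.setdefault(prefix + "_batch_size", []).append(len(batch))
        batch_count += 1
    history[prefix + "_batch_count"] = [batch_count]


def fake_step(batch, scale=1.0):
    # pretend the loss is the scaled mean of the batch values
    return {"loss": scale * sum(batch) / len(batch)}


history = {}
run_single_epoch([[1.0, 2.0], [3.0, 5.0]], True, "train", fake_step, history, scale=2.0)
print(history)
# {'train_loss': [3.0, 8.0], 'train_batch_size': [2, 2], 'train_batch_count': [2]}
```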
- save_params(f_params=None, f_optimizer=None, f_criterion=None, f_history=None, use_safetensors=False, **kwargs)
Saves the module’s parameters, history, and optimizer, not the whole object.
To save the whole object, use pickle. This is necessary when you need additional learned attributes on the net, e.g. the classes_ attribute on skorch.classifier.NeuralNetClassifier.
f_params, f_optimizer, etc. use PyTorch’s save().
If you’ve created a custom module, e.g. net.mymodule_, you can save that as well by passing f_mymodule.
- Parameters
- f_params : file-like object, str, None (default=None)
Path of module parameters. Pass None to not save.
- f_optimizer : file-like object, str, None (default=None)
Path of optimizer. Pass None to not save.
- f_criterion : file-like object, str, None (default=None)
Path of criterion. Pass None to not save.
- f_history : file-like object, str, None (default=None)
Path to history. Pass None to not save.
- use_safetensors : bool (default=False)
Whether to use the safetensors library to persist the state. By default, PyTorch is used, which in turn uses pickle under the hood. When enabling safetensors, be aware that only PyTorch tensors can be stored. Therefore, certain attributes like the optimizer cannot be saved.
Examples
>>> before = NeuralNetClassifier(mymodule).fit(X, y)
>>> before.save_params(f_params='model.pkl',
...                    f_optimizer='optimizer.pkl',
...                    f_history='history.json')
>>> after = NeuralNetClassifier(mymodule).initialize()
>>> after.load_params(f_params='model.pkl',
...                   f_optimizer='optimizer.pkl',
...                   f_history='history.json')
- set_fit_request(**kwargs)
Set the parameters requested by the fit method.
Please see Metadata Routing for more details.
Since NeuralNet.fit accepts arbitrary **fit_params that are passed to the module’s forward method, metadata names cannot be inferred from the signature and must be declared explicitly using this method.
- Parameters
- **kwargs : dict
Arguments should be of the form param_name=alias, where alias can be one of {True, False, None, str}.
- Returns
- self : object
The updated object.
- set_params(**kwargs)
Set the parameters of this class.
Valid parameter keys can be listed with get_params().
- Returns
- self
- set_partial_fit_request(**kwargs)
Set the parameters requested by the partial_fit method.
Please see Metadata Routing for more details.
Since NeuralNet.partial_fit accepts arbitrary **fit_params that are passed to the module’s forward method, metadata names cannot be inferred from the signature and must be declared explicitly using this method.
- Parameters
- **kwargs : dict
Arguments should be of the form param_name=alias, where alias can be one of {True, False, None, str}.
- Returns
- self : object
The updated object.
- torch_compile(module, name)
Compile torch modules.
If compile=True was set, compile all torch modules of the net. Those typically are module_ and criterion_, but custom modules are also included if defined.
- Parameters
- module : torch.nn.Module
The torch module to be compiled.
- name : str
The name of the module. This argument is not used but provided for convenience. You could use it, e.g., to skip compilation for specific modules.
- Returns
- module : torch.nn.Module or torch._dynamo.OptimizedModule
The compiled module if compile=True, otherwise the uncompiled module.
Notes
Make sure that the installed PyTorch version supports compiling (v1.14, v2.0 and higher).
- train_step(batch, **fit_params)
Prepare a loss function callable and pass it to the optimizer, thereby performing one optimization step.
Some optimizers require such a loss function callable (a closure), and all of them accept one: https://pytorch.org/docs/master/optim.html#optimizer-step-closure
The module is set to be in train mode (e.g. dropout is applied).
- Parameters
- batch
A single batch returned by the data loader.
- **fit_params : dict
Additional parameters passed to the forward method of the module and to the train_split call.
- Returns
- step : dict
A dictionary {'loss': loss, 'y_pred': y_pred}, where loss is the result of the loss function (a scalar tensor) and y_pred the prediction generated by the PyTorch module.
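The closure mechanism can be illustrated in plain PyTorch, without skorch. This is not skorch's implementation; model, criterion, and train_step are illustrative names. It uses torch.optim.LBFGS with line_search_fn="strong_wolfe", an optimizer that requires a closure and may evaluate it several times per step. The sketch keeps the result of the first evaluation, which is our assumption about how a single step result is reported.

```python
import torch
from torch import nn

torch.manual_seed(0)
model = nn.Linear(3, 1)
criterion = nn.MSELoss()
# LBFGS requires a closure and may call it several times within one step() call
optimizer = torch.optim.LBFGS(model.parameters(), line_search_fn="strong_wolfe")

X = torch.randn(32, 3)
y = X.sum(dim=1, keepdim=True)


def train_step(Xi, yi):
    """Illustrative closure-based optimization step (not skorch's code)."""
    first_step = {}

    def step_fn():
        # the loss function callable: recompute predictions, loss, and gradients
        optimizer.zero_grad()
        y_pred = model(Xi)
        loss = criterion(y_pred, yi)
        loss.backward()
        if not first_step:  # keep only the first evaluation's results
            first_step.update(loss=loss.detach(), y_pred=y_pred.detach())
        return loss

    model.train()  # train mode (relevant for layers such as dropout)
    optimizer.step(step_fn)
    return first_step


losses = [train_step(X, y)["loss"].item() for _ in range(5)]
print(losses[-1] < losses[0])  # True
```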
- train_step_single(batch, **fit_params)
Compute y_pred and the loss value, and compute the net’s gradients via backpropagation (the optimizer step itself is performed by train_step).
The module is set to be in train mode (e.g. dropout is applied).
- Parameters
- batch
A single batch returned by the data loader.
- **fit_params : dict
Additional parameters passed to the forward method of the module and to the self.train_split call.
- Returns
- step : dict
A dictionary {'loss': loss, 'y_pred': y_pred}, where loss is the result of the loss function (a scalar tensor) and y_pred the prediction generated by the PyTorch module.
- trim_for_prediction()
Remove all attributes not required for prediction.
Use this method after you finished training your net, with the goal of reducing its size. All attributes only required during training (e.g. the optimizer) are set to None. This can lead to a considerable decrease in memory footprint. It also makes it more likely that the net can be loaded with different library versions.
After calling this function, the net can only be used for prediction (e.g. net.predict or net.predict_proba) but no longer for training (e.g. net.fit(X, y) will raise an exception).
This operation is irreversible. Once the net has been trimmed for prediction, it is no longer possible to restore the original state. Moreover, this operation mutates the net. If you need the unmodified net, create a deepcopy before trimming:
from copy import deepcopy
net = NeuralNet(...)
net.fit(X, y)  # training finished
net_original = deepcopy(net)
net.trim_for_prediction()
net.predict(X)
- validation_step(batch, **fit_params)
Perform a forward step using batched data and return the resulting loss.
The module is set to be in evaluation mode (e.g. dropout is not applied).
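What evaluation mode means in practice can be shown in plain PyTorch. In this illustrative sketch (not skorch's code; model, criterion, and validation_step are hypothetical names), the module contains dropout, yet two evaluation passes give identical predictions, and no gradients are tracked because of torch.no_grad(). Returning a {'loss', 'y_pred'} dictionary mirrors the step dictionary documented for train_step; we assume the real method returns something similar.

```python
import torch
from torch import nn

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 16), nn.Dropout(p=0.5), nn.Linear(16, 1))
criterion = nn.MSELoss()


def validation_step(Xi, yi):
    """Illustrative evaluation step (not skorch's implementation)."""
    model.eval()           # evaluation mode: dropout becomes a no-op
    with torch.no_grad():  # no autograd graph is built
        y_pred = model(Xi)
        loss = criterion(y_pred, yi)
    return {"loss": loss, "y_pred": y_pred}


X = torch.randn(8, 4)
y = torch.randn(8, 1)
first, second = validation_step(X, y), validation_step(X, y)
print(torch.equal(first["y_pred"], second["y_pred"]))  # True: no dropout randomness
print(first["loss"].requires_grad)                     # False
```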
- Parameters
- batch
A single batch returned by the data loader.
- **fit_params : dict
Additional parameters passed to the forward method of the module and to the self.train_split call.