Saving and Loading¶
General approach¶
Pickling the whole net¶
skorch provides several ways to persist your model. First, it is possible to store the model using Python's pickle module. This saves the whole model, including hyperparameters, which is useful when you don't want to initialize your model before loading its parameters, or when your NeuralNet is part of an sklearn Pipeline:
import pickle

import torch
from sklearn.pipeline import Pipeline
from skorch import NeuralNet

net = NeuralNet(
    module=MyModule,
    criterion=torch.nn.NLLLoss,
)
model = Pipeline([
    ('my-features', get_features()),
    ('net', net),
])

model.fit(X, y)

# saving
with open('some-file.pkl', 'wb') as f:
    pickle.dump(model, f)

# loading
with open('some-file.pkl', 'rb') as f:
    model = pickle.load(f)
The disadvantage of pickling is that if your underlying code changes, unpickling might raise errors. Also, some Python code (e.g. lambda functions) cannot be pickled.
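For illustration, here is a minimal sketch of the second limitation (independent of skorch): pickling a lambda raises an error, and the same error surfaces if a lambda is stored anywhere on the net:

import pickle

try:
    pickle.dumps(lambda x: x)
except (pickle.PicklingError, AttributeError) as exc:
    # e.g. "Can't pickle <function <lambda> ...>"; the exact message depends on the Python version
    print(exc)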
Pickling specific parts of the net¶
For this reason, we provide a second method for persisting your model. To use it, call the save_params() and load_params() methods on NeuralNet. Under the hood, this saves the module's state_dict, i.e. only the weights and biases of the module. This is more robust to changes in the code but requires you to initialize the NeuralNet before loading the parameters again:
net = NeuralNet(
    module=MyModule,
    criterion=torch.nn.NLLLoss,
)
model = Pipeline([
    ('my-features', get_features()),
    ('net', net),
])

model.fit(X, y)
net.save_params(f_params='some-file.pkl')

new_net = NeuralNet(
    module=MyModule,
    criterion=torch.nn.NLLLoss,
)
new_net.initialize()  # This is important!
new_net.load_params(f_params='some-file.pkl')
In addition to saving the model parameters, the history and optimizer state can be saved by passing the f_history and f_optimizer keyword arguments to save_params() and load_params() on NeuralNet. This feature can be used to continue training:
net = NeuralNet(
    module=MyModule,
    criterion=torch.nn.NLLLoss,
)
net.fit(X, y, epochs=2)  # Train for 2 epochs

net.save_params(
    f_params='model.pkl', f_optimizer='opt.pkl', f_history='history.json')

new_net = NeuralNet(
    module=MyModule,
    criterion=torch.nn.NLLLoss,
)
new_net.initialize()  # This is important!
new_net.load_params(
    f_params='model.pkl', f_optimizer='opt.pkl', f_history='history.json')

new_net.fit(X, y, epochs=2)  # Train for another 2 epochs
Note
In order to use this feature, the history must only contain JSON encodable Python data structures. Numpy and PyTorch types should not be in the history.
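For example, if you record custom values through a callback, cast them to plain Python types first. A minimal sketch, assuming a hypothetical metric function my_metric() that returns a numpy scalar:

from skorch.callbacks import Callback

class RecordMyMetric(Callback):
    """Hypothetical callback illustrating how to keep the history JSON encodable."""
    def on_epoch_end(self, net, **kwargs):
        value = my_metric(net)  # assumed to return e.g. np.float64
        net.history.record('my_metric', float(value))  # cast to a plain Python float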
Note
save_params() does not store learned attributes on the net. E.g., skorch.classifier.NeuralNetClassifier remembers the classes it encountered during training in the classes_ attribute. This attribute will be missing after load_params(). Therefore, if you need it, you should pickle.dump() the whole net.
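A short sketch, assuming net is a fitted NeuralNetClassifier:

import pickle

# pickling the whole net keeps learned attributes such as classes_
with open('net.pkl', 'wb') as f:
    pickle.dump(net, f)

with open('net.pkl', 'rb') as f:
    restored = pickle.load(f)

print(restored.classes_)  # still available, unlike after load_params()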
Using safetensors¶
skorch also supports storing tensors using the safetensors library. There are a few advantages to using safetensors, which are documented on their website; most notably, it is secure because it does not use pickle.
To get started, first install the library in your virtual environment if it's not already installed:
python -m pip install safetensors
When using save_params() and load_params(), torch.save() and torch.load() are used under the hood, which rely on pickle and are thus unsafe. If security is a concern (or you care about any of the other advantages of safetensors), you should save the parameters using the use_safetensors=True option:
net = NeuralNet(
    module=MyModule,
    criterion=torch.nn.NLLLoss,
)
net.fit(X, y)

net.save_params(f_params='model.safetensors', use_safetensors=True)

new_net = NeuralNet(
    module=MyModule,
    criterion=torch.nn.NLLLoss,
)
new_net.initialize()  # This is important!
new_net.load_params(f_params='model.safetensors', use_safetensors=True)
One disadvantage of using safetensors is that it can only serialize torch tensors, nothing else. Therefore, it cannot serialize components whose state_dict contains other values, e.g. the optimizer. If you absolutely need to save the optimizer, you have to fall back on pickle, as described in the previous sections.
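If that is a concern, one possible workaround is to mix both approaches: store the module parameters with safetensors and the optimizer state with the pickle-based default. A sketch:

# weights via safetensors, optimizer state via the pickle-based torch.save
net.save_params(f_params='model.safetensors', use_safetensors=True)
net.save_params(f_optimizer='opt.pt')

new_net = NeuralNet(
    module=MyModule,
    criterion=torch.nn.NLLLoss,
)
new_net.initialize()
new_net.load_params(f_params='model.safetensors', use_safetensors=True)
new_net.load_params(f_optimizer='opt.pt')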
Trimming for prediction¶
If you know that after loading, the saved model will only be used for prediction and not for further training, you can get rid of several components of the net. E.g., since the optimizer, criterion, and callbacks only affect training, they are no longer needed and can be removed. skorch provides a convenience method to achieve this, called trim_for_prediction().
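A minimal sketch, reusing the net from the earlier examples:

new_net = NeuralNet(
    module=MyModule,
    criterion=torch.nn.NLLLoss,
)
new_net.initialize()
new_net.load_params(f_params='some-file.pkl')

# drop components that are only needed for training; note that this cannot be undone
new_net.trim_for_prediction()
y_pred = new_net.predict(X)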
Using callbacks¶
skorch provides Checkpoint, TrainEndCheckpoint, and LoadInitState callbacks to handle saving and loading models during training. To demonstrate these features, we generate a dataset and create a simple module:
import numpy as np
from sklearn.datasets import make_classification
from torch import nn

X, y = make_classification(1000, 10, n_informative=5, random_state=0)
X = X.astype(np.float32)
y = y.astype(np.int64)

class MyModule(nn.Sequential):
    def __init__(self, num_units=10):
        super().__init__(
            nn.Linear(10, num_units),
            nn.ReLU(inplace=True),
            nn.Dropout(0.2),
            nn.Linear(num_units, 10),
            nn.Linear(10, 2),
            nn.Softmax(dim=-1),
        )
Then we create two different checkpoint callbacks and configure them to save the model parameters, optimizer, and history into a directory named 'exp1':
# First run
from skorch.callbacks import Checkpoint, TrainEndCheckpoint
from skorch import NeuralNetClassifier

cp = Checkpoint(dirname='exp1')
train_end_cp = TrainEndCheckpoint(dirname='exp1')
net = NeuralNetClassifier(
    MyModule, lr=0.5, callbacks=[cp, train_end_cp]
)

_ = net.fit(X, y)
# prints
epoch train_loss valid_acc valid_loss cp dur
------- ------------ ----------- ------------ ---- ------
1 0.6200 0.8209 0.4765 + 0.0232
2 0.3644 0.8557 0.3474 + 0.0238
3 0.2875 0.8806 0.3201 + 0.0214
4 0.2514 0.8905 0.3080 + 0.0237
5 0.2333 0.9154 0.2844 + 0.0203
6 0.2177 0.9403 0.2164 + 0.0215
7 0.2194 0.9403 0.2159 + 0.0220
8 0.2027 0.9403 0.2299 0.0202
9 0.1864 0.9254 0.2313 0.0196
10 0.2024 0.9353 0.2333 0.0221
By default, Checkpoint monitors the valid_loss metric and saves the model when the metric improves. This is indicated by the + mark in the cp column of the logs.
On our first run, the validation loss did not improve after the 7th epoch. We can lower the learning rate and continue training from this checkpoint by using LoadInitState:
from skorch.callbacks import LoadInitState

cp = Checkpoint(dirname='exp1')
load_state = LoadInitState(cp)
net = NeuralNetClassifier(
    MyModule, lr=0.1, callbacks=[cp, load_state]
)

_ = net.fit(X, y)
# prints
epoch train_loss valid_acc valid_loss cp dur
------- ------------ ----------- ------------ ---- ------
8 0.1939 0.9055 0.2626 + 0.0238
9 0.2055 0.9353 0.2031 + 0.0239
10 0.1992 0.9453 0.2101 0.0182
11 0.2033 0.9453 0.1947 + 0.0211
12 0.1825 0.9104 0.2515 0.0185
13 0.2010 0.9453 0.1927 + 0.0187
14 0.1508 0.9453 0.1952 0.0198
15 0.1679 0.9502 0.1905 + 0.0181
16 0.1516 0.9453 0.1864 + 0.0192
17 0.1576 0.9453 0.1804 + 0.0184
The LoadInitState callback is executed once at the beginning of the training procedure and initializes the model, history, and optimizer parameters from the specified checkpoint (if it exists). In our case, the previous checkpoint was created at the end of epoch 7, so the second run resumes from epoch 8. With a lower learning rate, the validation loss was able to improve!
Notice that in the first run we included a TrainEndCheckpoint in the list of callbacks. As its name suggests, this callback creates a checkpoint at the end of training. As before, we can pass it to LoadInitState to continue training:
cp_from_final = Checkpoint(dirname='exp1', fn_prefix='from_train_end_')
load_state = LoadInitState(train_end_cp)
net = NeuralNetClassifier(
    MyModule, lr=0.1, callbacks=[cp_from_final, load_state]
)

_ = net.fit(X, y)
# prints
epoch train_loss valid_acc valid_loss cp dur
------- ------------ ----------- ------------ ---- ------
11 0.1663 0.9453 0.2166 + 0.0282
12 0.1880 0.9403 0.2237 0.0178
13 0.1813 0.9353 0.1993 + 0.0161
14 0.1744 0.9353 0.1955 + 0.0150
15 0.1538 0.9303 0.2053 0.0077
16 0.1473 0.9403 0.1947 + 0.0078
17 0.1563 0.9254 0.1989 0.0074
18 0.1558 0.9403 0.1877 + 0.0075
19 0.1534 0.9254 0.2318 0.0074
20 0.1779 0.9453 0.1814 + 0.0074
In this run, training started at epoch 11, continuing from the end of the first run, which ended at epoch 10. We created a new Checkpoint callback with fn_prefix set to 'from_train_end_' so that the saved filenames are prefixed accordingly, making sure this checkpoint does not overwrite the checkpoint from the previous run.
Since our MyModule class allows num_units to be adjusted, we can start a new experiment by changing the dirname:
cp = Checkpoint(dirname='exp2')
load_state = LoadInitState(cp)
net = NeuralNetClassifier(
    MyModule, lr=0.5,
    callbacks=[cp, load_state],
    module__num_units=20,
)

_ = net.fit(X, y)
# prints
epoch train_loss valid_acc valid_loss cp dur
------- ------------ ----------- ------------ ---- ------
1 0.5256 0.8856 0.3624 + 0.0181
2 0.2956 0.8756 0.3416 + 0.0222
3 0.2280 0.9453 0.2299 + 0.0211
4 0.1948 0.9303 0.2136 + 0.0232
5 0.1800 0.9055 0.2696 0.0223
6 0.1605 0.9403 0.1906 + 0.0190
7 0.1594 0.9403 0.2027 0.0184
8 0.1319 0.9303 0.1910 0.0220
9 0.1558 0.9254 0.1923 0.0189
10 0.1432 0.9303 0.2219 0.0192
This stores the model into the 'exp2' directory. Since this is the first run, the LoadInitState callback does not do anything. If we were to run the above script again, the LoadInitState callback would load the model from the checkpoint.
In the run above, the last checkpoint was created at epoch 6; we can load this checkpoint and use it for prediction:
net = NeuralNetClassifier(
    MyModule, lr=0.5, module__num_units=20,
)
net.initialize()
net.load_params(checkpoint=cp)

y_pred = net.predict(X)
In this case, it is important to initialize the neural net before running NeuralNet.load_params().
In general, all these callbacks also support saving with safetensors instead of pickle by passing the argument use_safetensors=True. The same caveat applies: safetensors can only serialize torch tensors, so the optimizer cannot be stored. Therefore, pass f_optimizer=None if you want to use safetensors.
Remember that if you want to serialize with safetensors and you use both Checkpoint and LoadInitState, both callbacks need to be initialized with use_safetensors=True.
Saving on Hugging Face Hub¶
Checkpoint and TrainEndCheckpoint can also be used to store models on the Hugging Face Hub. For this to work, instead of passing a file name for the component to be stored, pass a skorch.hf.HfHubStorage instance.