Saving and Loading

skorch provides the callbacks Checkpoint, TrainEndCheckpoint, and LoadInitState to save and load models during training. To demonstrate these features, we generate a dataset and create a simple module:

import numpy as np
from sklearn.datasets import make_classification
from torch import nn

from skorch import NeuralNetClassifier

X, y = make_classification(1000, 10, n_informative=5, random_state=0)
X = X.astype(np.float32)
y = y.astype(np.int64)

class MyModule(nn.Sequential):
    def __init__(self, num_units=10):
        super().__init__(
            nn.Linear(10, num_units),
            nn.ReLU(inplace=True),
            nn.Dropout(0.2),
            nn.Linear(num_units, 10),
            nn.Linear(10, 2),
            nn.Softmax(dim=-1)
        )

We create a checkpoint with dirname set to 'exp1'. This configures the checkpoint to save the model parameters, the optimizer state, and the history into a directory named 'exp1'.

from skorch.callbacks import Checkpoint, TrainEndCheckpoint

cp = Checkpoint(dirname='exp1')
final_cp = TrainEndCheckpoint(dirname='exp1')
net = NeuralNetClassifier(
    MyModule, lr=0.5, callbacks=[cp, final_cp]
)

_ = net.fit(X, y)

# prints
  epoch    train_loss    valid_acc    valid_loss    cp     dur
-------  ------------  -----------  ------------  ----  ------
      1        0.6200       0.8209        0.4765     +  0.0232
      2        0.3644       0.8557        0.3474     +  0.0238
      3        0.2875       0.8806        0.3201     +  0.0214
      4        0.2514       0.8905        0.3080     +  0.0237
      5        0.2333       0.9154        0.2844     +  0.0203
      6        0.2177       0.9403        0.2164     +  0.0215
      7        0.2194       0.9403        0.2159     +  0.0220
      8        0.2027       0.9403        0.2299        0.0202
      9        0.1864       0.9254        0.2313        0.0196
     10        0.2024       0.9353        0.2333        0.0221

By default, the checkpoint monitors valid_loss and saves the model whenever valid_loss reaches a new low. This can be seen from the + marks in the cp column of the logs.
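Internally, the default is equivalent to Checkpoint(monitor='valid_loss_best'): each epoch, skorch's scoring callbacks record a boolean "best so far" flag in the history, and the checkpoint saves whenever that flag is set (passing, e.g., monitor='valid_acc_best' would checkpoint on validation accuracy instead). A minimal sketch of that triggering logic over a hand-written history:

```python
# Hand-written history rows mimicking skorch's per-epoch records;
# 'valid_loss_best' is the flag the default Checkpoint monitors.
history = [
    {'epoch': 1, 'valid_loss': 0.48, 'valid_loss_best': True},
    {'epoch': 2, 'valid_loss': 0.35, 'valid_loss_best': True},
    {'epoch': 3, 'valid_loss': 0.37, 'valid_loss_best': False},
]

monitor = 'valid_loss_best'
# The checkpoint saves on every epoch whose monitored flag is True.
saved_epochs = [row['epoch'] for row in history if row[monitor]]
```

This matches the log above: a + appears exactly in those epochs where the monitored metric improved.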

On our first run, the validation loss did not improve after the 7th epoch. We can continue training from this checkpoint with a lower learning rate by using LoadInitState:

from skorch.callbacks import LoadInitState

cp = Checkpoint(dirname='exp1')
load_state = LoadInitState(cp)
net = NeuralNetClassifier(
    MyModule, lr=0.1, callbacks=[cp, load_state]
)

_ = net.fit(X, y)

# prints

  epoch    train_loss    valid_acc    valid_loss    cp     dur
-------  ------------  -----------  ------------  ----  ------
      8        0.1939       0.9055        0.2626     +  0.0238
      9        0.2055       0.9353        0.2031     +  0.0239
     10        0.1992       0.9453        0.2101        0.0182
     11        0.2033       0.9453        0.1947     +  0.0211
     12        0.1825       0.9104        0.2515        0.0185
     13        0.2010       0.9453        0.1927     +  0.0187
     14        0.1508       0.9453        0.1952        0.0198
     15        0.1679       0.9502        0.1905     +  0.0181
     16        0.1516       0.9453        0.1864     +  0.0192
     17        0.1576       0.9453        0.1804     +  0.0184

Since we started from the previous checkpoint which ended at epoch 7, the second run starts at epoch 8, continuing from the first checkpoint. With a lower learning rate, the validation loss was able to improve!

Notice that in the first run we included a TrainEndCheckpoint in the callbacks. This callback saves the model at the end of training and can also be passed to LoadInitState to continue training:

cp_from_final = Checkpoint(dirname='exp1', fn_prefix='from_final_')
load_state = LoadInitState(final_cp)
net = NeuralNetClassifier(
    MyModule, lr=0.1, callbacks=[cp_from_final, load_state]
)

_ = net.fit(X, y)

In this run, training started at epoch 11, continuing from the end of the first run, which ended at epoch 10. We created a new checkpoint with fn_prefix set to 'from_final_' to prefix the saved filenames with 'from_final_', making sure this checkpoint does not overwrite the validation checkpoint.
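Assuming skorch's default file names (f_params='params.pt', f_optimizer='optimizer.pt', f_history='history.json'), the fn_prefix is simply prepended to each name, so the two checkpoints can share a directory without colliding. A quick sketch of the resulting paths:

```python
import os

def checkpoint_paths(dirname, fn_prefix=''):
    # Assumed default file names; fn_prefix is prepended to each one.
    names = ('params.pt', 'optimizer.pt', 'history.json')
    return [os.path.join(dirname, fn_prefix + name) for name in names]

plain_paths = checkpoint_paths('exp1')
prefixed_paths = checkpoint_paths('exp1', fn_prefix='from_final_')
```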

Since our MyModule class allows num_units to be adjusted, we can start a new experiment by changing the dirname:

cp = Checkpoint(dirname='exp2')
load_state = LoadInitState(cp)
net = NeuralNetClassifier(
    MyModule, lr=0.5,
    callbacks=[cp, load_state],
    module__num_units=20,
)

_ = net.fit(X, y)

# prints

  epoch    train_loss    valid_acc    valid_loss    cp     dur
-------  ------------  -----------  ------------  ----  ------
      1        0.5256       0.8856        0.3624     +  0.0181
      2        0.2956       0.8756        0.3416     +  0.0222
      3        0.2280       0.9453        0.2299     +  0.0211
      4        0.1948       0.9303        0.2136     +  0.0232
      5        0.1800       0.9055        0.2696        0.0223
      6        0.1605       0.9403        0.1906     +  0.0190
      7        0.1594       0.9403        0.2027        0.0184
      8        0.1319       0.9303        0.1910        0.0220
      9        0.1558       0.9254        0.1923        0.0189
     10        0.1432       0.9303        0.2219        0.0192

This stores the model into the 'exp2' directory. Since this is the first run, the LoadInitState callback does not do anything. If we were to run the above script again, the LoadInitState callback would load the model from the checkpoint.
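LoadInitState is a no-op on the first run precisely because no checkpoint files exist on disk yet. A sketch of that guard (the real callback inspects the checkpoint it was given; the file name here is an assumed default):

```python
import os

def should_load(dirname, f_params='params.pt'):
    # Only load if a previous run has already written parameters to disk.
    return os.path.exists(os.path.join(dirname, f_params))

first_run = should_load('some_dirname_never_written_to')
```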

In the run above, the last checkpoint was created at epoch 6. We can load this checkpoint to predict with it:

net = NeuralNetClassifier(
    MyModule, lr=0.5, module__num_units=20,
)
net.initialize()
net.load_params(checkpoint=cp)

y_pred = net.predict(X)

In this case, it is important to initialize the neural net before running NeuralNet.load_params().
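The reason is that load_params() restores saved state into a module and optimizer that must already exist, and initialize() is what constructs them. A minimal sketch of this ordering constraint, using a plain class and dict as stand-ins for the net and its state dict (not skorch's actual implementation):

```python
class TinyNet:
    """Stand-in for NeuralNetClassifier: load needs an initialized module."""
    def __init__(self):
        self.module_ = None  # created only by initialize()

    def initialize(self):
        self.module_ = {'weight': 0.0}
        return self

    def load_params(self, state):
        if self.module_ is None:
            raise RuntimeError('call initialize() before load_params()')
        self.module_.update(state)

net = TinyNet()
net.initialize()           # builds the module to load state into
net.load_params({'weight': 1.5})
```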