NeuralNet

Using NeuralNet

NeuralNet and its derived classes are the main touch point for the user. They wrap the PyTorch Module while providing an interface that should be familiar to sklearn users.

Define your Module the same way as you always do. Then pass it to NeuralNet, in conjunction with a PyTorch criterion. Finally, you can call fit() and predict(), as with an sklearn estimator. The finished code could look something like this:

import torch
from skorch import NeuralNet

class MyModule(torch.nn.Module):
    ...

net = NeuralNet(
    module=MyModule,
    criterion=torch.nn.NLLLoss,
)
net.fit(X, y)  # X, y: your training data, e.g. numpy arrays
y_pred = net.predict(X_valid)

Let’s see what skorch did for us here:

  • wraps the PyTorch Module in an sklearn interface
  • converts numpy.ndarrays to PyTorch Tensors
  • abstracts away the fit loop
  • takes care of batching the data

You therefore have a lot less boilerplate code, letting you focus on what matters. At the same time, skorch is very flexible and can be extended with ease, getting out of your way as much as possible.

Initialization

In general, when you instantiate the NeuralNet, only the given arguments are stored, exactly as you passed them. For instance, the module remains uninstantiated. This ensures that the arguments you pass are not modified afterwards, which makes it possible, for example, to clone the NeuralNet instance.

The different attributes of the net, such as the module, are only initialized when fit() or initialize() is called. The name of an initialized attribute always ends with an underscore; e.g., the initialized module is called module_. (This is the same nomenclature as sklearn uses.) Therefore, you always know which attributes you set and which ones were created by NeuralNet.

The only exception is the history attribute: it is not set by the user, yet its name does not end with an underscore.
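To make the convention concrete, here is a toy class that mimics the lazy initialization described above. It is a hypothetical sketch, not skorch’s implementation: the constructor only stores its arguments, and initialize() creates module_ from the stored class by delegating to a component-specific initialize_module() method.

```python
class ToyNet:
    """Hypothetical stand-in for NeuralNet's lazy initialization (not skorch code)."""

    def __init__(self, module, max_epochs=10):
        # Arguments are stored exactly as passed; nothing is instantiated yet.
        self.module = module
        self.max_epochs = max_epochs

    def initialize_module(self):
        # Attributes created by the net get a trailing underscore.
        self.module_ = self.module()

    def initialize(self):
        # The real NeuralNet also initializes the criterion, optimizer, etc.
        self.initialize_module()
        return self


class MyModule:
    pass


net = ToyNet(module=MyModule)
print(hasattr(net, "module_"))  # False: only the class is stored so far
net.initialize()
print(net.module is MyModule, isinstance(net.module_, MyModule))  # True True
```

Because module still holds the untouched class, an equivalent net can be re-created from the constructor arguments alone, which is what cloning relies on.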

Most important arguments and methods

A complete explanation of all arguments and methods of NeuralNet can be found in the skorch API documentation. Here we focus on the main ones.

module

This is where you pass your PyTorch Module. Ideally, it should not be instantiated. Instead, the init arguments for your module should be passed to NeuralNet with the module__ prefix. E.g., if your module takes the arguments num_units and dropout, the code would look like this:

class MyModule(torch.nn.Module):
    def __init__(self, num_units, dropout):
        ...

net = NeuralNet(
    module=MyModule,
    module__num_units=100,
    module__dropout=0.5,
    criterion=torch.nn.NLLLoss,
)
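The double-underscore routing can be pictured as filtering the net’s keyword arguments by prefix and stripping that prefix. The helper params_for() below is hypothetical and only illustrates the idea; skorch performs this routing internally.

```python
def params_for(prefix, kwargs):
    """Hypothetical helper: pick out the arguments addressed to one component.

    With prefix 'module', 'module__num_units' becomes 'num_units'.
    """
    start = prefix + "__"
    return {
        key[len(start):]: value
        for key, value in kwargs.items()
        if key.startswith(start)
    }


class MyModule:
    def __init__(self, num_units, dropout):
        self.num_units = num_units
        self.dropout = dropout


net_kwargs = {"module__num_units": 100, "module__dropout": 0.5, "lr": 0.1}
module_kwargs = params_for("module", net_kwargs)
print(module_kwargs)  # {'num_units': 100, 'dropout': 0.5}
module = MyModule(**module_kwargs)  # only now is the module instantiated
```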

It is, however, also possible to pass an instantiated module, e.g. a PyTorch Sequential instance.

Note that skorch does not automatically apply any nonlinearities to the outputs (except internally when computing the PyTorch NLLLoss, see below). That means that if you have a classification task, you should make sure that the final output nonlinearity is a softmax. Otherwise, when you call predict_proba(), you won’t get actual probabilities.
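If your last layer produces raw scores (logits), a softmax turns them into probabilities; in PyTorch you would typically end the module with torch.nn.Softmax(dim=-1). The pure-Python sketch below only shows the arithmetic, including the usual subtraction of the maximum for numerical stability.

```python
import math


def softmax(logits):
    """Turn a list of raw scores (logits) into probabilities."""
    # Subtracting the max avoids overflow in exp(); the result is unchanged.
    shift = max(logits)
    exps = [math.exp(z - shift) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


raw = [2.0, 1.0, -1.0]  # what a final Linear layer might output
probs = softmax(raw)
print(probs)  # positive values that sum to (approximately) 1
```

The ranking of the classes is preserved, so the argmax is the same before and after the softmax.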

criterion

This should be a PyTorch (or PyTorch-compatible) criterion.

When you use the NeuralNetClassifier, the criterion is set to PyTorch NLLLoss by default. Furthermore, if you don’t change the loss to another criterion, NeuralNetClassifier assumes that the module returns probabilities and will automatically apply a logarithm to them (which is what NLLLoss expects).
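The arithmetic behind “take the log, then apply NLLLoss” can be sketched without PyTorch: with the default mean reduction, the loss is the average of -log(p), where p is the probability assigned to each sample’s true class. This illustrates the math only; skorch works on tensors, and its actual code path may differ in details.

```python
import math


def nll_from_probabilities(probs_batch, targets):
    """Mean negative log-likelihood computed from probabilities.

    Equivalent in spirit to log() followed by NLLLoss with mean reduction:
    -mean(log(p[i][target_i])).
    """
    logs = [math.log(row[t]) for row, t in zip(probs_batch, targets)]
    return -sum(logs) / len(logs)


probs_batch = [
    [0.7, 0.2, 0.1],  # sample 0, true class 0
    [0.1, 0.8, 0.1],  # sample 1, true class 1
]
targets = [0, 1]
loss = nll_from_probabilities(probs_batch, targets)
print(round(loss, 4))  # 0.2899
```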

For NeuralNetRegressor, the default criterion is PyTorch MSELoss.

After initializing the NeuralNet, the initialized criterion will be stored in the criterion_ attribute.

optimizer

This should be a PyTorch optimizer, e.g. SGD. After initializing the NeuralNet, the initialized optimizer will be stored in the optimizer_ attribute. During initialization, you can define param groups, for example to set different learning rates for certain parameters. The parameters are selected by name with support for wildcards (globbing):

optimizer__param_groups=[
    ('embedding.*', {'lr': 0.0}),
    ('linear0.bias', {'lr': 1}),
]
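Glob-style matching of parameter names can be sketched with Python’s fnmatch module. The parameter names below are hypothetical (they are what named_parameters() might yield for a module with the attributes embedding, linear0 and linear1), and the rule that the first matching pattern wins is an assumption of this sketch, not documented skorch behavior.

```python
from fnmatch import fnmatchcase

# Hypothetical parameter names, as Module.named_parameters() might yield them.
param_names = [
    "embedding.weight",
    "linear0.weight",
    "linear0.bias",
    "linear1.weight",
]

param_groups = [
    ("embedding.*", {"lr": 0.0}),
    ("linear0.bias", {"lr": 1}),
]


def assign_groups(names, groups):
    """Map each parameter name to the options of the first matching pattern.

    Names matching no pattern keep the optimizer defaults (empty dict).
    """
    assignment = {}
    for name in names:
        assignment[name] = {}
        for pattern, options in groups:
            if fnmatchcase(name, pattern):
                assignment[name] = options
                break
    return assignment


for name, options in assign_groups(param_names, param_groups).items():
    print(name, options)
```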

Your use case may require an optimizer whose signature differs from a default PyTorch optimizer’s signature. In that case, you can define a custom function that reroutes the arguments as needed and pass it to the optimizer parameter:

# factory that wraps an arbitrary optimizer class (here Adam) in Lookahead;
# Lookahead is not part of torch.optim and has to be imported from elsewhere
def make_lookahead(parameters, optimizer_cls, k, alpha, **kwargs):
    optimizer = optimizer_cls(parameters, **kwargs)
    return Lookahead(optimizer=optimizer, k=k, alpha=alpha)


net = NeuralNetClassifier(
    ...,
    optimizer=make_lookahead,
    optimizer__optimizer_cls=torch.optim.Adam,
    optimizer__weight_decay=1e-2,
    optimizer__k=5,
    optimizer__alpha=0.5,
    lr=1e-3,
)

lr

The learning rate. This argument exists for convenience, since it could also be set via optimizer__lr. However, it is used so often that we provide this shortcut. If you set both lr and optimizer__lr, the latter takes precedence.

max_epochs

The maximum number of epochs to train with each fit() call. When you call fit(), the net will train for this many epochs, unless training is interrupted before the end (e.g. by an early stopping callback or manually with ctrl+c).

If you want to change the number of epochs to train, you can either set a different value for max_epochs or call fit_loop() instead of fit() and pass the desired number of epochs explicitly:

net.fit_loop(X, y, epochs=20)

batch_size

This argument controls the batch size for iterator_train and iterator_valid at the same time. batch_size=128 is thus a convenient shortcut for explicitly typing iterator_train__batch_size=128 and iterator_valid__batch_size=128. If you set all three arguments, the latter two will have precedence.
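The precedence rule can be sketched as a small expansion step. iterator_kwargs() is a hypothetical helper, not part of skorch: batch_size fills in both iterator settings unless a more specific value was passed. The lr/optimizer__lr shortcut follows the same pattern.

```python
def iterator_kwargs(batch_size, **kwargs):
    """Hypothetical expansion of the batch_size shortcut.

    An explicit iterator_train__batch_size or iterator_valid__batch_size
    wins over the shared batch_size.
    """
    resolved = {}
    for name in ("iterator_train", "iterator_valid"):
        key = name + "__batch_size"
        resolved[key] = kwargs.get(key, batch_size)
    return resolved


print(iterator_kwargs(128))
# {'iterator_train__batch_size': 128, 'iterator_valid__batch_size': 128}
print(iterator_kwargs(128, iterator_valid__batch_size=512))
# {'iterator_train__batch_size': 128, 'iterator_valid__batch_size': 512}
```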

train_split

This determines the NeuralNet’s internal train/validation split. By default, 20% of the incoming data is reserved for validation. If you set this value to None, all the data is used for training.
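A random 80/20 split of sample indices looks roughly like the sketch below. train_valid_indices() is a hypothetical helper; skorch’s own default splitter may differ in details, e.g. by stratifying the split for classifiers.

```python
import random


def train_valid_indices(n_samples, valid_fraction=0.2, seed=0):
    """Shuffle sample indices and hold out a fraction for validation.

    A plain random split, used here for illustration only.
    """
    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)
    n_valid = int(round(n_samples * valid_fraction))
    return indices[n_valid:], indices[:n_valid]


train_idx, valid_idx = train_valid_indices(100)
print(len(train_idx), len(valid_idx))  # 80 20
```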

For more details, please look at dataset.

callbacks

For more details on the callback classes, please look at callbacks.

By default, NeuralNet and its subclasses start with a couple of useful callbacks. Those are defined in the get_default_callbacks() method and include, for instance, callbacks for measuring and printing model performance.

In addition to the default callbacks, you may provide your own callbacks. There are a couple of ways to pass callbacks to the NeuralNet instance. The easiest way is to pass a list of all your callbacks to the callbacks argument:

net = NeuralNet(
    module=MyModule,
    callbacks=[
        MyCallback1(...),
        MyCallback2(...),
    ],
)

Inside the NeuralNet instance, each callback will receive a separate name. Since we provide no names in the example above, the class names will be used, which leads to a name collision if there are two or more callbacks of the same class. This is why it is better to initialize the callbacks with a list of tuples of name and callback instance, like this:

net = NeuralNet(
    module=MyModule,
    callbacks=[
        ('cb1', MyCallback1(...)),
        ('cb2', MyCallback2(...)),
    ],
)

This approach of passing a list of (name, instance) tuples should be familiar to users of sklearn Pipelines and FeatureUnions.
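The naming rule can be sketched as follows. name_callbacks() is a hypothetical helper, and raising an error on duplicate names is a choice made for this sketch; skorch’s own handling of colliding names may differ.

```python
class MyCallback1:
    pass


def name_callbacks(callbacks):
    """Hypothetical naming step: accept instances or (name, instance) tuples.

    Unnamed callbacks fall back to their class name; duplicate names are
    reported instead of silently overwriting each other.
    """
    named = {}
    for item in callbacks:
        if isinstance(item, tuple):
            name, cb = item
        else:
            name, cb = type(item).__name__, item
        if name in named:
            raise ValueError(f"duplicate callback name: {name!r}")
        named[name] = cb
    return named


try:
    name_callbacks([MyCallback1(), MyCallback1()])
except ValueError as exc:
    print(exc)  # duplicate callback name: 'MyCallback1'

print(list(name_callbacks([("cb1", MyCallback1()), ("cb2", MyCallback1())])))
# ['cb1', 'cb2']
```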

An additional advantage of passing callbacks this way is that it allows you to set callback parameters by name (using the double-underscore notation):

net.set_params(callbacks__cb1__foo=123, callbacks__cb2__bar=456)

Use this, for instance, when trying out different callback parameters in a grid search.
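Routing a key such as callbacks__cb1__foo amounts to splitting it at the double underscores and setting the parameter on the callback with that name. The sketch below is hypothetical (set_callback_params() and ToyCallback are not skorch names) and simply uses setattr().

```python
class ToyCallback:
    def __init__(self, foo=0, bar=0):
        self.foo = foo
        self.bar = bar


def set_callback_params(callbacks, **params):
    """Hypothetical routing of 'callbacks__<name>__<param>' keys.

    `callbacks` maps callback names to callback instances.
    """
    for key, value in params.items():
        prefix, name, param = key.split("__", 2)
        if prefix != "callbacks":
            raise ValueError(f"not a callback parameter: {key!r}")
        setattr(callbacks[name], param, value)


callbacks = {"cb1": ToyCallback(), "cb2": ToyCallback()}
set_callback_params(callbacks, callbacks__cb1__foo=123, callbacks__cb2__bar=456)
print(callbacks["cb1"].foo, callbacks["cb2"].bar)  # 123 456
```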

Note: The user-defined callbacks are always called in the same order as they appeared in the list. If there are dependencies between the callbacks, the user has to make sure that the order respects them. Also note that the user-defined callbacks will be called after the default callbacks so that they can make use of the things provided by the default callbacks. The only exception is the default callback PrintLog, which is always called last.
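The resulting call order can be sketched as a small list operation. The default callback names used here are illustrative, and callback_order() is a hypothetical helper.

```python
def callback_order(default_names, user_names, last="print_log"):
    """Defaults first, then user callbacks in list order, `last` at the end."""
    ordered = [n for n in default_names if n != last]
    ordered += list(user_names)
    if last in default_names:
        ordered.append(last)
    return ordered


# Illustrative default names; the actual defaults depend on the net class.
defaults = ["epoch_timer", "train_loss", "valid_loss", "print_log"]
print(callback_order(defaults, ["cb1", "cb2"]))
# ['epoch_timer', 'train_loss', 'valid_loss', 'cb1', 'cb2', 'print_log']
```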

warm_start

This argument determines whether each fit() call leads to a re-initialization of the NeuralNet or not. By default, when calling fit(), the parameters of the net are initialized, so your previous training progress is lost (consistent with the sklearn fit() calls). In contrast, with warm_start=True, each fit() call will continue from the most recent state.

device

As the name suggests, this determines which computation device should be used. If set to cuda, the incoming data will be transferred to CUDA before being passed to the PyTorch Module. The device parameter adheres to the general syntax of the PyTorch device parameter.

initialize()

As mentioned earlier, upon instantiating the NeuralNet instance, the net’s components are not yet initialized. That means, e.g., that the weights and biases of the layers are not yet set. This only happens after the initialize() call. However, when you call fit() and the net is not yet initialized, initialize() is called automatically. You thus rarely need to call it manually.

The initialize() method itself calls a couple of other initialization methods that are specific to each component. E.g., initialize_module() is responsible for initializing the PyTorch module. Therefore, if you have special needs for initializing the module, it is enough to override initialize_module(); you don’t need to override the whole initialize() method.

fit(X, y)

This is one of the main methods you will use. It contains everything required to train the model, be it batching of the data, triggering the callbacks, or handling the internal validation set.

In general, we assume there to be an X and a y. If you have more input data than just one array, it is possible for X to be a list or dictionary of data (see dataset). And if your task does not have an actual y, you may pass y=None.

If you fit with a PyTorch Dataset and don’t explicitly pass y, several components down the line might not work anymore, since sklearn sometimes requires an explicit y (e.g. for scoring). In general, PyTorch Datasets should work, though.

In addition to fit(), there is also the partial_fit() method, known from some sklearn estimators. partial_fit() allows you to continue training from the current state, even if you set warm_start=False. A further use case for partial_fit() is when your data does not fit into memory and you thus need to train on it in several chunks.
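The interplay of fit(), partial_fit(), warm_start and fit_loop() can be modeled with a toy class in which a counter of trained epochs stands in for the learned weights. ToyTrainer is hypothetical and drops the X and y arguments; it only mirrors the re-initialization rules described in this document.

```python
class ToyTrainer:
    """Hypothetical model of fit / partial_fit / warm_start semantics.

    `epochs_trained` stands in for the learned weights.
    """

    def __init__(self, max_epochs=10, warm_start=False):
        self.max_epochs = max_epochs
        self.warm_start = warm_start
        self.epochs_trained = None  # None means "not yet initialized"

    def initialize(self):
        self.epochs_trained = 0  # fresh "weights"
        return self

    def fit_loop(self, epochs=None):
        epochs = self.max_epochs if epochs is None else epochs
        for _ in range(epochs):
            self.epochs_trained += 1  # one epoch of training
        return self

    def partial_fit(self, epochs=None):
        # Continue from the current state; initialize only if needed.
        if self.epochs_trained is None:
            self.initialize()
        return self.fit_loop(epochs)

    def fit(self):
        # Without warm_start, every fit() starts from scratch.
        if not self.warm_start or self.epochs_trained is None:
            self.initialize()
        return self.fit_loop()


net = ToyTrainer(max_epochs=10)
net.fit()
net.fit()
print(net.epochs_trained)  # 10: the second fit() re-initialized
net.partial_fit()
print(net.epochs_trained)  # 20: partial_fit() continued training
```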

Tip: skorch gracefully catches the KeyboardInterrupt exception. Therefore, during a training run, you can send a KeyboardInterrupt signal without the Python process exiting (typically, KeyboardInterrupt can be triggered by ctrl+c or, in a Jupyter notebook, by clicking Kernel -> Interrupt). This way, if your model reaches a good score before max_epochs is reached, you can stop training on the fly.
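Conceptually, catching the interrupt means wrapping the epoch loop in a try/except KeyboardInterrupt block, so that training stops early while the process and the partially trained model survive. The sketch below is a hypothetical illustration, not skorch’s training loop; fake_epoch() simulates pressing ctrl+c.

```python
def train(max_epochs, train_one_epoch):
    """Run epochs until done or until the user hits ctrl+c.

    Catching KeyboardInterrupt ends training early but keeps the progress
    made so far, instead of terminating the Python process.
    """
    completed = 0
    try:
        for epoch in range(max_epochs):
            train_one_epoch(epoch)
            completed += 1
    except KeyboardInterrupt:
        pass  # stop training; the model keeps its current state
    return completed


def fake_epoch(epoch):
    # Simulate the user pressing ctrl+c during the fourth epoch.
    if epoch == 3:
        raise KeyboardInterrupt


print(train(10, fake_epoch))  # 3
```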

predict(X) and predict_proba(X)

These methods perform an inference step on the input data and return numpy.ndarrays. By default, predict_proba() will return whatever it is that the module’s forward() method returns, cast to a numpy.ndarray. If forward() returns multiple outputs as a tuple, only the first output is used, the rest is discarded.

If the forward() output cannot be cast to a numpy.ndarray, or if you need access to all outputs in the multiple-outputs case, consider using the forward() or forward_iter() method to generate outputs from the module. Alternatively, you may directly call net.module_(X).

In case of NeuralNetClassifier, the predict() method tries to return the class labels by applying the argmax over the last axis of the result of predict_proba(). Obviously, this only makes sense if predict_proba() returns class probabilities. If this is not true, you should just use predict_proba().
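With numpy, this step is proba.argmax(axis=-1). The dependency-free sketch below does the same for a list of rows; predict_from_proba() is a hypothetical helper. Like numpy, it picks the first index when there is a tie.

```python
def predict_from_proba(proba):
    """Class labels as the argmax over the last axis of 2D probabilities."""
    return [max(range(len(row)), key=row.__getitem__) for row in proba]


proba = [
    [0.1, 0.7, 0.2],
    [0.6, 0.3, 0.1],
]
print(predict_from_proba(proba))  # [1, 0]
```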

score(X, y)

This method returns the mean accuracy on the given data and labels for classifiers, and the coefficient of determination R^2 of the prediction for regressors. The base NeuralNet class, however, has no score method. If you need one, you have to implement it yourself.
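Should you implement score() yourself, the two metrics are straightforward to compute; sklearn.metrics provides them as accuracy_score and r2_score. The plain-Python versions below only illustrate the formulas, and r2() assumes y_true is not constant (otherwise SS_tot is zero).

```python
def accuracy(y_true, y_pred):
    """Fraction of exactly matching labels."""
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def r2(y_true, y_pred):
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    mean = sum(y_true) / len(y_true)
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    ss_tot = sum((t - mean) ** 2 for t in y_true)  # assumes y_true not constant
    return 1 - ss_res / ss_tot


print(accuracy([0, 1, 1, 2], [0, 1, 2, 2]))  # 0.75
print(r2([1.0, 2.0, 3.0], [2.0, 2.0, 2.0]))  # 0.0 (no better than the mean)
```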

model persistence

In general there are different ways of saving and loading models, each with their own advantages and disadvantages. More details and usage examples can be found here: Saving and Loading.

If you would like to use pickle (the default way when using scikit-learn models), this is possible with skorch nets. This saves the whole net, including hyperparameters etc. The advantage is that you can restore everything to exactly the state it was in before. The disadvantage is that changes to your code can more easily break your old saves.

Additionally, it is possible to save and load specific attributes of the net, such as the module, optimizer, or history, by calling save_params() and load_params(). This is useful if you’re only interested in saving a particular part of your model, and is more robust to code changes.

Finally, it is also possible to use callbacks to save and load models, e.g. Checkpoint. Those should be used if you need to have your model saved or loaded at specific times, e.g. at the start or end of the training process.

Input data

Regular data

skorch supports numerous input types for data. Regular input types that should just work are numpy arrays, torch tensors, scipy sparse CSR matrices, and pandas DataFrames (see also DataFrameTransformer).

Typically, your task should involve an X and a y. If you’re dealing with a task that doesn’t require a target (say, training an autoencoder), you can just pass y=None. Make sure your loss function deals with this appropriately.

Datasets

Datasets are also supported, with the requirement that they should return exactly two items (X and y). For more information on that, take a look at the Dataset documentation.

Many PyTorch libraries, like torchvision, implement their own Datasets. These usually work seamlessly with skorch, as long as their __getitem__ methods return two outputs. If they don’t, consider overriding the __getitem__ method and re-arranging the outputs so that __getitem__ returns exactly two elements. If the original implementation returns more than two elements, take a look at the next section to get an idea of how to deal with that.
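A minimal sketch of such an override, assuming a hypothetical third-party dataset whose items are (x1, x2, y) triples. The two inputs are packed into a dict, in the spirit of the multiple-inputs approach of the next section. Real code would subclass torch.utils.data.Dataset; plain Python classes are used here so the example runs without PyTorch.

```python
class ThreeItemDataset:
    """Hypothetical third-party dataset whose items are (x1, x2, y)."""

    def __init__(self, samples):
        self.samples = samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, i):
        return self.samples[i]


class TwoItemDataset(ThreeItemDataset):
    """Override __getitem__ so that each item becomes exactly (X, y)."""

    def __getitem__(self, i):
        x1, x2, y = super().__getitem__(i)
        # Pack both inputs into X; the keys would have to match
        # the argument names of the module's forward() method.
        return {"x1": x1, "x2": x2}, y


ds = TwoItemDataset([([1, 2], 5, 0), ([3, 4], 7, 1)])
X0, y0 = ds[0]
print(X0, y0)  # {'x1': [1, 2], 'x2': 5} 0
```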

Multiple input arguments

In some cases, the input actually consists of multiple inputs. E.g., in a text classification task, you might have an array that contains the integers representing the tokens for each sample, and another array containing the number of tokens of each sample. skorch has you covered here as well.

You could supply a list or tuple with all your inputs (net.fit([tokens, num_tokens], y)), but we actually recommend another approach. The best way is to pass the different arguments as a dictionary. Then the keys of that dictionary have to correspond to the argument names of your module’s forward method. Below is an example:

from torch import nn

X_dict = {'tokens': tokens, 'num_tokens': num_tokens}

class MyModule(nn.Module):
    def forward(self, tokens, num_tokens):  # <- same names as in your dict
        ...

net = NeuralNet(MyModule, ...)
net.fit(X_dict, y)

As you can see, the forward method takes arguments with exactly the same name as the keys in the dictionary. This is how the different inputs are matched. To make this work with GridSearchCV, please use SliceDict.
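The matching works because the dict entries end up as keyword arguments to forward(). The sketch below mimics this without PyTorch; call_with_dict() is a hypothetical helper, and its explicit signature check via inspect is an extra added for illustration.

```python
import inspect


class MyModule:
    """Stand-in for an nn.Module; only forward()'s signature matters here."""

    def forward(self, tokens, num_tokens):
        # Toy computation: truncate each token list to its stated length.
        return [seq[:n] for seq, n in zip(tokens, num_tokens)]


def call_with_dict(module, X_dict):
    """Check that the dict keys match forward()'s parameters, then call it."""
    expected = set(inspect.signature(module.forward).parameters)
    if set(X_dict) != expected:
        raise TypeError(f"keys {sorted(X_dict)} != forward args {sorted(expected)}")
    return module.forward(**X_dict)


X_dict = {"tokens": [[5, 8, 2], [7, 1, 0]], "num_tokens": [3, 2]}
print(call_with_dict(MyModule(), X_dict))  # [[5, 8, 2], [7, 1]]
```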

Using a dict should cover most use cases that involve multiple inputs. However, it will fail if your inputs have different sizes. E.g., if your array of tokens has 1000 elements but your array of number of tokens has 2000 elements, this would fail. The main reason for this is batching: How can we know which elements of the two arrays belong in the same batch?
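The batching problem shows up in a small sketch: a batch is built by slicing every array with the same index range, which only makes sense if all arrays have the same length. iter_dict_batches() is hypothetical, not skorch’s data loading code.

```python
def iter_dict_batches(X_dict, batch_size):
    """Yield dicts of aligned slices; sample i is element i of every array."""
    lengths = {key: len(values) for key, values in X_dict.items()}
    if len(set(lengths.values())) != 1:
        # No way to tell which elements belong to the same sample.
        raise ValueError(f"inconsistent lengths: {lengths}")
    n = next(iter(lengths.values()))
    for start in range(0, n, batch_size):
        yield {key: values[start:start + batch_size] for key, values in X_dict.items()}


X_ok = {"tokens": list(range(5)), "num_tokens": list(range(5))}
for batch in iter_dict_batches(X_ok, batch_size=2):
    print(batch)

X_bad = {"tokens": list(range(1000)), "num_tokens": list(range(2000))}
try:
    next(iter_dict_batches(X_bad, batch_size=2))
except ValueError as exc:
    print(exc)  # inconsistent lengths: {'tokens': 1000, 'num_tokens': 2000}
```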

If your input consists of multiple inputs with different sizes, your best bet is to implement your own dataset class. That class should know how it deals with the different inputs, i.e. which elements belong to the same sample. Again, please refer to the Dataset section for more details.

Special arguments

In addition to the arguments explicitly listed for NeuralNet, there are some arguments with special prefixes, as shown below:

class MyModule(torch.nn.Module):
    def __init__(self, num_units, dropout):
        ...

net = NeuralNet(
    module=MyModule,
    module__num_units=100,
    module__dropout=0.5,
    criterion=torch.nn.NLLLoss,
    criterion__weight=weight,
    optimizer=torch.optim.SGD,
    optimizer__momentum=0.9,
)

Those arguments are used to initialize your module, criterion, etc. They are not fixed because we cannot know them in advance; in fact, you can define any parameter for your module or other components.