skorch.helper

Helper functions and classes for users.

They should not be used in skorch directly.

class skorch.helper.AccelerateMixin(*args, accelerator, device=None, callbacks__print_log__sink='auto', **kwargs)[source]

Mixin class to add support for Hugging Face accelerate.

This is an experimental feature.

Use this mixin class with one of the neural net classes (e.g. NeuralNet, NeuralNetClassifier, or NeuralNetRegressor) and pass an instance of Accelerator for mixed precision, multi-GPU, or TPU training.

Install the accelerate library using:

python -m pip install accelerate

skorch does not itself provide any facilities to enable these training features. A lot of them can still be implemented by the user with a little bit of extra work but it can be a daunting task. That is why this helper class was added: Using this mixin in conjunction with the accelerate library should cover a lot of common use cases.

Note

Under the hood, accelerate uses GradScaler, which does not support passing the training step as a closure. Therefore, if your optimizer requires that (e.g. torch.optim.LBFGS), you cannot use accelerate.

Warning

Since accelerate is still quite young and changes that break backwards compatibility might be introduced, we treat its integration as an experimental feature. When accelerate’s API stabilizes, we will consider adding it to skorch proper.

Parameters:
accelerator : accelerate.Accelerator

In addition to the usual parameters, pass an instance of accelerate.Accelerator with the desired settings.

device : str, torch.device, or None (default=None)

The compute device to be used. When using accelerate, it is recommended to leave device handling to accelerate. Therefore, it is best to leave this argument as None, which means that skorch does not set the device.

callbacks__print_log__sink : ‘auto’ or callable

If ‘auto’, uses the print function of the accelerator, if it has one. This avoids printing the same output multiple times when training concurrently on multiple machines. If the accelerator does not have a print function, use Python’s print function instead.

Examples

>>> from skorch import NeuralNetClassifier
>>> from skorch.helper import AccelerateMixin
>>> from accelerate import Accelerator
>>>
>>> class AcceleratedNet(AccelerateMixin, NeuralNetClassifier):
...     '''NeuralNetClassifier with accelerate support'''
>>>
>>> accelerator = Accelerator(...)
>>> net = AcceleratedNet(MyModule, accelerator=accelerator)
>>> net.fit(X, y)

The same approach works with all the other skorch net classes.

Methods

on_train_end(net[, X, y])
get_iterator  
train_step_single  
class skorch.helper.DataFrameTransformer(treat_int_as_categorical=False, float_dtype=<class 'numpy.float32'>, int_dtype=<class 'numpy.int64'>)[source]

Transform a DataFrame into a dict useful for working with skorch.

Transforms cardinal data to floats and categorical data to vectors of ints so that they can be embedded.

Although skorch can deal with pandas DataFrames, the default behavior is often not very useful. Use this transformer to transform the DataFrame into a dict with all float columns concatenated using the key “X” and all categorical values encoded as integers, using their respective column names as keys.

Your module must have a matching signature for this to work. It must accept an argument X for all cardinal values. Additionally, for all categorical values, it must accept an argument with the same name as the corresponding column (see example below). If you need help with the required signature, use the describe_signature method of this class and pass it your data.

You can choose whether you want to treat int columns the same as float columns (default) or as categorical values.

To one-hot encode categorical features, initialize their corresponding embedding layers using the identity matrix.
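
The identity-matrix initialization mentioned above can be sketched with plain PyTorch. This is only an illustration: the number of categories and the choice to freeze the weights are assumptions, not something skorch prescribes.

```python
import torch
from torch import nn

n_categories = 3  # hypothetical number of unique categories in one column

# Row i of the identity matrix is the one-hot vector for category i.
# freeze=True keeps the encoding fixed, so it is not updated during training.
one_hot = nn.Embedding.from_pretrained(torch.eye(n_categories), freeze=True)

codes = torch.tensor([0, 2, 1])  # integer codes of a categorical column
encoded = one_hot(codes)         # shape: (3, n_categories)
print(encoded)
```

With this layer in place of a learned embedding, the module receives one-hot vectors for the categorical argument while the transformer output stays unchanged.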

Parameters:
treat_int_as_categorical : bool (default=False)

Whether to treat integers as categorical values or as cardinal values, i.e. the same as floats.

float_dtype : numpy dtype or None (default=np.float32)

The dtype to cast the cardinal values to. If None, don’t change them.

int_dtype : numpy dtype or None (default=np.int64)

The dtype to cast the categorical values to. If None, don’t change them. If you do this, it can happen that the categorical values will have different dtypes, reflecting the number of unique categories.

Notes

The value of X will always be 2-dimensional, even if it only contains 1 column.

Examples

>>> import numpy as np
>>> import pandas as pd
>>> import torch
>>> from torch import nn
>>> from sklearn.pipeline import Pipeline
>>> from skorch import NeuralNetClassifier
>>> from skorch.helper import DataFrameTransformer
>>>
>>> df = pd.DataFrame({
...     'col_floats': np.linspace(0, 1, 12),
...     'col_ints': [11, 11, 10] * 4,
...     'col_cats': ['a', 'b', 'a'] * 4,
... })
>>> # cast to category dtype to later learn embeddings
>>> df['col_cats'] = df['col_cats'].astype('category')
>>> y = np.asarray([0, 1, 0] * 4)
>>> class MyModule(nn.Module):
...     def __init__(self):
...         super().__init__()
...         self.reset_params()
...
...     def reset_params(self):
...         # 2 unique categories in col_cats, embedded into 10 dimensions
...         self.embedding = nn.Embedding(2, 10)
...         # 2 cardinal columns (col_floats and col_ints)
...         self.linear = nn.Linear(2, 10)
...         self.out = nn.Linear(20, 2)
...         self.nonlin = nn.Softmax(dim=-1)
...
...     def forward(self, X, col_cats):
...         # "X" contains the values from col_floats and col_ints
...         # "col_cats" contains the values from "col_cats"
...         X_lin = self.linear(X)
...         X_cat = self.embedding(col_cats)
...         X_concat = torch.cat((X_lin, X_cat), dim=1)
...         return self.nonlin(self.out(X_concat))
>>> net = NeuralNetClassifier(MyModule)
>>> pipe = Pipeline([
...     ('transform', DataFrameTransformer()),
...     ('net', net),
... ])
>>> pipe.fit(df, y)

Methods

describe_signature(df) Describe the signature required for the given data.
fit(df[, y])
fit_transform(X[, y]) Fit to data, then transform it.
get_params([deep]) Get parameters for this estimator.
set_params(**params) Set the parameters of this estimator.
transform(df) Transform DataFrame to become a dict that works well with skorch.
describe_signature(df)[source]

Describe the signature required for the given data.

Pass the DataFrame to receive a description of the signature required for the module’s forward method. The description consists of three parts:

1. The names of the arguments that the forward method needs.
2. The dtypes of the torch tensors passed to forward.
3. The number of input units that are required for the corresponding argument. For the float parameter, this is just the number of dimensions of the tensor. For categorical parameters, it is the number of unique elements.

Returns:
signature : dict

Returns a dict with each key corresponding to one key required for the forward method. The values are dictionaries of two elements. The key “dtype” describes the torch dtype of the resulting tensor, the key “input_units” describes the required number of input units.

transform(df)[source]

Transform DataFrame to become a dict that works well with skorch.

Parameters:
df : pd.DataFrame

Incoming DataFrame.

Returns:
X_dict: dict

Dictionary with all floats concatenated using the key “X” and all categorical values encoded as integers, using their respective column names as keys.

class skorch.helper.SliceDataset(dataset, idx=0, indices=None)[source]

Helper class that wraps a torch dataset to make it work with sklearn.

Sometimes, sklearn will touch the input data, e.g. when splitting the data for a grid search. This will fail when the input data is a torch dataset. To prevent this, use this wrapper class for your dataset.

Note: This class will only return the X value by default (i.e. the first value returned by indexing the original dataset). Sklearn, and hence skorch, always require 2 values, X and y. Therefore, you still need to provide the y data separately.

Note: This class behaves similarly to a PyTorch Subset when it is indexed by a slice or numpy array: It will return another SliceDataset that references the subset instead of the actual values. Only when it is indexed by an int does it return the actual values. The reason for this is to avoid loading all data into memory when sklearn, for instance, creates a train/validation split on the dataset. Data will only be loaded in batches during the fit loop.

Parameters:
dataset : torch.utils.data.Dataset

A valid torch dataset.

idx : int (default=0)

Indicates which element of the dataset should be returned. Typically, the dataset returns both X and y values. SliceDataset can only return 1 value. If you want to get X, choose idx=0 (default); if you want y, choose idx=1.

indices : list, np.ndarray, or None (default=None)

If you only want to return a subset of the dataset, indicate which subset that is by passing this argument. Typically, this can be left as None, which returns all the data. See also Subset.

Examples

>>> X = MyCustomDataset()
>>> search = GridSearchCV(net, params, ...)
>>> search.fit(X, y)  # raises error
>>> ds = SliceDataset(X)
>>> search.fit(ds, y)  # works
Attributes:
shape

Methods

count(value)
index(value, [start, [stop]]) Raises ValueError if the value is not present.
transform(data) Additional transformations on data.
transform(data)[source]

Additional transformations on data.

Note: If you use this in conjunction with PyTorch DataLoader, the latter will call the dataset for each row separately, which means that the incoming data is a single row.

class skorch.helper.SliceDict(**kwargs)[source]

Wrapper for Python dict that makes it sliceable across values.

Use this if your input data is a dictionary and you have problems with sklearn not being able to slice it. Wrap your dict with SliceDict and it should usually work.

Note:

  • SliceDict cannot be indexed by integers. If you want one row, say row 3, use [3:4].
  • SliceDict accepts numpy arrays and torch tensors as values.

Examples

>>> X = {'key0': val0, 'key1': val1}
>>> search = GridSearchCV(net, params, ...)
>>> search.fit(X, y)  # raises error
>>> Xs = SliceDict(key0=val0, key1=val1)  # or Xs = SliceDict(**X)
>>> search.fit(Xs, y)  # works
Attributes:
shape

Methods

clear()
copy()
fromkeys(*args, **kwargs) fromkeys method makes no sense with SliceDict and is thus not supported.
get(key[, default]) Return the value for key if key is in the dictionary, else default.
items()
keys()
pop(k[,d]) Remove the specified key and return the corresponding value. If key is not found, d is returned if given, otherwise KeyError is raised.
popitem() Remove and return some (key, value) pair as a 2-tuple; raise KeyError if D is empty.
setdefault(key[, default]) Insert key with a value of default if key is not in the dictionary.
update([E, ]**F) If E is present and has a .keys() method, then does: for k in E: D[k] = E[k]. If E is present and lacks a .keys() method, then does: for k, v in E: D[k] = v. In either case, this is followed by: for k in F: D[k] = F[k].
values()
copy() → a shallow copy of D[source]
fromkeys(*args, **kwargs)[source]

fromkeys method makes no sense with SliceDict and is thus not supported.

update([E, ]**F) → None. Update D from dict/iterable E and F.[source]

If E is present and has a .keys() method, then does: for k in E: D[k] = E[k]. If E is present and lacks a .keys() method, then does: for k, v in E: D[k] = v. In either case, this is followed by: for k in F: D[k] = F[k].

skorch.helper.predefined_split(dataset)[source]

Uses dataset for validation in NeuralNet.

Parameters:
dataset: torch Dataset

Validation dataset

Examples

>>> valid_ds = skorch.dataset.Dataset(X, y)
>>> net = NeuralNet(..., train_split=predefined_split(valid_ds))