skorch.helper

Helper functions and classes for users.

They are intended for user code and should not be used within skorch itself.

class skorch.helper.SliceDataset(dataset, idx=0, indices=None)

Helper class that wraps a torch dataset to make it work with sklearn.

Sometimes, sklearn will touch the input data, e.g. when splitting the data for a grid search. This will fail when the input data is a torch dataset. To prevent this, use this wrapper class for your dataset.

Note: This class will only return the X value by default (i.e. the first value returned by indexing the original dataset). Sklearn, and hence skorch, always require 2 values, X and y. Therefore, you still need to provide the y data separately.

Note: This class behaves similarly to a PyTorch Subset when it is indexed by a slice or numpy array: It will return another SliceDataset that references the subset instead of the actual values. Only when it is indexed by an int does it return the actual values. The reason for this is to avoid loading all data into memory when sklearn, for instance, creates a train/validation split on the dataset. Data will only be loaded in batches during the fit loop.
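The lazy indexing described in the note above can be illustrated without torch or skorch. What follows is a minimal pure-Python sketch, not skorch's implementation: `ToyDataset` and `LazySlice` are hypothetical names, and the `loads` counter exists only to show when rows are actually read.

```python
# Hypothetical, pure-Python illustration; not skorch's implementation.
class ToyDataset:
    """Stand-in for a torch dataset: each row is an (X, y) pair."""

    def __init__(self, n):
        self.n = n
        self.loads = 0  # counts how many rows were actually loaded

    def __len__(self):
        return self.n

    def __getitem__(self, i):
        self.loads += 1
        return (i * 10, i % 2)  # (X, y)


class LazySlice:
    """Returns another view for slices and a real value only for ints."""

    def __init__(self, dataset, idx=0, indices=None):
        self.dataset = dataset
        self.idx = idx
        if indices is None:
            indices = range(len(dataset))
        self.indices = list(indices)

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, i):
        if isinstance(i, int):
            # Only here is a row loaded from the underlying dataset.
            return self.dataset[self.indices[i]][self.idx]
        if isinstance(i, slice):
            picked = self.indices[i]
        else:
            picked = [self.indices[j] for j in i]
        # Slices and index lists only narrow down the stored indices.
        return LazySlice(self.dataset, idx=self.idx, indices=picked)


ds = ToyDataset(100)
X = LazySlice(ds)           # idx=0 selects the X value
y = LazySlice(ds, idx=1)    # idx=1 selects the y value

train = X[:80]              # another LazySlice, nothing loaded yet
loads_after_slicing = ds.loads
first_value = train[3]      # int index: loads exactly one row
```

In this sketch, slicing only copies index positions, so a train/validation split touches none of the underlying rows.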

Parameters:
dataset : torch.utils.data.Dataset

A valid torch dataset.

idx : int (default=0)

Indicates which element of the dataset should be returned. Typically, the dataset returns both X and y values, but SliceDataset can only return one value. To get X, choose idx=0 (the default); to get y, choose idx=1.

indices : list, np.ndarray, or None (default=None)

If you only want to return a subset of the dataset, pass the indices of that subset with this argument. Typically, this can be left as None, which returns all the data. See also Subset.

Examples

>>> X = MyCustomDataset()
>>> search = GridSearchCV(net, params, ...)
>>> search.fit(X, y)  # raises error
>>> ds = SliceDataset(X)
>>> search.fit(ds, y)  # works
Attributes:
shape

Methods

count(value)
index(value, [start, [stop]]) Raises ValueError if the value is not present.
transform(data) Additional transformations on data.
transform(data)

Additional transformations on data.

Note: If you use this in conjunction with the PyTorch DataLoader, the latter will call the dataset for each row separately, which means that the incoming data is a single row.
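As a rough illustration of the per-row behaviour noted above, the following pure-Python sketch overrides a `transform` hook in a subclass. `RowView` and `ScaledView` are hypothetical stand-ins, not skorch classes; the sketch only assumes that the hook is applied to whatever a single integer index returns.

```python
# Hypothetical stand-ins; not skorch's SliceDataset.
class RowView:
    def __init__(self, rows):
        self.rows = rows

    def transform(self, data):
        """Hook for extra per-row processing; the default is a no-op."""
        return data

    def __len__(self):
        return len(self.rows)

    def __getitem__(self, i):
        # Called once per row, so `data` is a single row, not a batch.
        return self.transform(self.rows[i])


class ScaledView(RowView):
    def transform(self, data):
        # Scale one row of 8-bit values into the range [0, 1].
        return [value / 255 for value in data]


rows = [[0, 255], [51, 102]]
view = ScaledView(rows)
row = view[1]  # transform receives [51, 102], never the full list of rows
```

A transform written this way should therefore not assume a leading batch dimension.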

class skorch.helper.SliceDict(**kwargs)

Wrapper for Python dict that makes it sliceable across values.

Use this if your input data is a dictionary and you have problems with sklearn not being able to slice it. Wrap your dict with SliceDict and it should usually work.

Note:
  • SliceDict cannot be indexed by integers. If you want one row, say row 3, use [3:4].
  • SliceDict accepts numpy arrays and torch tensors as values.
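The idea of slicing "across values" can be sketched in plain Python. This is a simplified, hypothetical illustration (`ToySliceDict` is not a skorch class). It uses lists as values for brevity, and the exception type raised for integer indices is an assumption made for this sketch.

```python
# Hypothetical, simplified illustration using plain lists as values.
class ToySliceDict(dict):
    def __getitem__(self, key):
        if isinstance(key, str):
            # Normal dict access by key.
            return super().__getitem__(key)
        if isinstance(key, int):
            # A single int would drop the row dimension, so reject it.
            raise ValueError("Use a slice such as [3:4] to select one row.")
        # Apply the same slice to every value.
        return ToySliceDict((k, v[key]) for k, v in self.items())


X = ToySliceDict(key0=[0, 1, 2, 3, 4], key1=["a", "b", "c", "d", "e"])
row3 = X[3:4]  # one row, still a ToySliceDict with the same keys

try:
    X[3]
    int_rejected = False
except ValueError:
    int_rejected = True
```

Because the result of slicing keeps the same keys, it can be passed on wherever the original dictionary was accepted.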

Examples

>>> X = {'key0': val0, 'key1': val1}
>>> search = GridSearchCV(net, params, ...)
>>> search.fit(X, y)  # raises error
>>> Xs = SliceDict(key0=val0, key1=val1)  # or Xs = SliceDict(**X)
>>> search.fit(Xs, y)  # works
Attributes:
shape

Methods

clear()
copy()
fromkeys(*args, **kwargs) fromkeys method makes no sense with SliceDict and is thus not supported.
get(key[, default]) Return the value for key if key is in the dictionary, else default.
items()
keys()
pop(k[,d]) Remove the specified key and return the corresponding value. If key is not found, d is returned if given, otherwise KeyError is raised.
popitem() Remove and return a (key, value) pair as a 2-tuple; raise KeyError if D is empty.
setdefault(key[, default]) Insert key with a value of default if key is not in the dictionary.
update([E, ]**F) If E is present and has a .keys() method, then does: for k in E: D[k] = E[k]. If E is present and lacks a .keys() method, then does: for k, v in E: D[k] = v. In either case, this is followed by: for k in F: D[k] = F[k].
values()
copy() → a shallow copy of D
fromkeys(*args, **kwargs)

fromkeys method makes no sense with SliceDict and is thus not supported.

update([E, ]**F) → None. Update D from dict/iterable E and F.

If E is present and has a .keys() method, then does: for k in E: D[k] = E[k]. If E is present and lacks a .keys() method, then does: for k, v in E: D[k] = v. In either case, this is followed by: for k in F: D[k] = F[k].
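The three cases in the description of update can be seen on a plain Python dict, from which this docstring is inherited. The sketch below uses a built-in dict only. SliceDict may apply additional checks to the values, which this example does not model.

```python
D = {"a": 1}

# E has a .keys() method: D[k] = E[k] for each key in E.
D.update({"b": 2})

# E lacks .keys(): it is iterated as (k, v) pairs.
D.update([("c", 3)])

# E is applied first, then the keyword arguments F.
D.update({"a": 10}, d=4)
```

After these calls, the existing key "a" has been overwritten and the keys "b", "c" and "d" have been added.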

skorch.helper.predefined_split(dataset)

Uses dataset for validation in NeuralNet.

Parameters:
dataset : torch.utils.data.Dataset

Validation dataset.

Examples

>>> valid_ds = skorch.dataset.Dataset(X, y)
>>> net = NeuralNet(..., train_split=predefined_split(valid_ds))
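One way to picture what a predefined split does is a callable that ignores any splitting logic and always hands back the same validation set. The sketch below is a hypothetical illustration, not skorch's code. It assumes that `train_split` is called with the training dataset and is expected to return a (train, valid) pair; `toy_predefined_split` and `_toy_make_split` are made-up names.

```python
from functools import partial


def _toy_make_split(train_ds, y=None, valid_ds=None, **kwargs):
    """Leave the training data untouched and return the fixed validation set."""
    return train_ds, valid_ds


def toy_predefined_split(valid_ds):
    # Bind the validation dataset now; the training data arrives later.
    return partial(_toy_make_split, valid_ds=valid_ds)


# Plain lists stand in for datasets in this sketch.
train_ds = [("x0", 0), ("x1", 1), ("x2", 0)]
valid_ds = [("x3", 1)]

split = toy_predefined_split(valid_ds)
train_out, valid_out = split(train_ds)
```

Under these assumptions, none of the training data is held out, because the validation data comes entirely from the dataset passed in up front.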