Contains custom skorch Dataset and ValidSplit.

class skorch.dataset.CVSplit(*args, **kwargs)


Methods:

  • __call__(dataset[, y, groups]): Call self as a function.
  • check_cv(y): Resolve which cross-validation strategy is used.

class skorch.dataset.Dataset(X, y=None, length=None)

General dataset wrapper that can be used in conjunction with PyTorch DataLoader.

The dataset always yields a tuple of two values: the first from the data (X) and the second from the target (y). The target may be None; in that case, Dataset currently returns a dummy tensor in its place, since DataLoader does not work with None values.

Dataset currently works with the following data types:

  • numpy arrays
  • PyTorch Tensors
  • scipy sparse CSR matrices
  • pandas NDFrame
  • a dictionary of the former three
  • a list/tuple of the former three

Parameters:

X : see above

Everything pertaining to the input data.

y : see above or None (default=None)

Everything pertaining to the target, if there is anything.

length : int or None (default=None)

If not None, determines the length (len) of the data. Should usually be left at None, in which case the length is determined by the data itself.


Methods:

  • transform(X, y): Additional transformations on X and y.

transform(X, y)

Additional transformations on X and y.

By default, they are cast to PyTorch Tensors. Override this method if you want different behavior.

Note: If you use this in conjunction with a PyTorch DataLoader, the latter calls the dataset once per row, which means that the incoming X and y are each single rows.

class skorch.dataset.ValidSplit(cv=5, stratified=False, random_state=None)

Class that performs the internal train/valid split on a dataset.

The cv argument here works similarly to the regular sklearn cv parameter in, e.g., GridSearchCV. However, instead of cycling through all splits, only one fixed split (the first one) is used. To cycle through all splits, don’t use NeuralNet’s internal validation; use the corresponding sklearn functions instead (e.g. cross_val_score).

We additionally support a float, similar to sklearn’s train_test_split.

Parameters:

cv : int, float, cross-validation generator or iterable (default=5)

(Refer to sklearn’s User Guide on cross-validation for the various strategies that can be used here.)

Determines the cross-validation splitting strategy. Possible inputs for cv are:

  • None, to use sklearn’s default number of folds (5 in recent sklearn versions, 3 in older ones);
  • an integer, to specify the number of folds in a (Stratified)KFold;
  • a float, to represent the proportion of the dataset to include in the validation split;
  • an object to be used as a cross-validation generator;
  • an iterable yielding (train, validation) splits.

stratified : bool (default=False)

Whether the split should be stratified. Only works if y is a binary or multiclass classification target.

random_state : int, RandomState instance, or None (default=None)

Controls the random state when (Stratified)ShuffleSplit is used, i.e. when a float is passed to cv. For more information, see the sklearn documentation of (Stratified)ShuffleSplit.


Methods:

  • __call__(dataset[, y, groups]): Call self as a function.
  • check_cv(y): Resolve which cross-validation strategy is used.



Unpack data returned by the net’s iterator into a 2-tuple.

If the wrong number of items is returned, raise a helpful error message.