skorch.helper
Helper functions and classes for users.
They are intended to be used by end users but should not be depended upon for skorch-internal usage.
- class skorch.helper.DataFrameTransformer(treat_int_as_categorical=False, float_dtype=<class 'numpy.float32'>, int_dtype=<class 'numpy.int64'>)
Transform a DataFrame into a dict useful for working with skorch.
Transforms cardinal data to floats and categorical data to vectors of ints so that they can be embedded.
Although skorch can deal with pandas DataFrames, the default behavior is often not very useful. Use this transformer to transform the DataFrame into a dict with all float columns concatenated using the key “X” and all categorical values encoded as integers, using their respective column names as keys.
Your module must have a matching signature for this to work. It must accept an argument X for all cardinal values. Additionally, for all categorical values, it must accept an argument with the same name as the corresponding column (see example below). If you need help with the required signature, use the describe_signature method of this class and pass it your data.
You can choose whether you want to treat int columns the same as float columns (default) or as categorical values.
To one-hot encode categorical features, initialize their corresponding embedding layers using the identity matrix.
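The identity-matrix trick works because an embedding lookup is plain row indexing: looking up category code k in an n-by-n identity matrix returns the k-th unit vector, i.e. the one-hot encoding of k. The NumPy sketch below illustrates this; the helper name one_hot_via_identity is hypothetical. In PyTorch, the analogous layer would presumably be nn.Embedding.from_pretrained(torch.eye(n), freeze=True), assuming you want the encoding to stay fixed during training.

```python
import numpy as np


def one_hot_via_identity(codes, n_categories):
    """Hypothetical helper: mimic an embedding layer whose weights are the identity matrix."""
    weights = np.eye(n_categories, dtype=np.float32)  # one row per category
    return weights[np.asarray(codes)]  # embedding lookup is just row indexing


codes = np.array([0, 1, 0, 2])  # integer-encoded categories, as produced by the transformer
print(one_hot_via_identity(codes, n_categories=3))
```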
- Parameters
- treat_int_as_categorical : bool (default=False)
Whether to treat integers as categorical values or as cardinal values, i.e. the same as floats.
- float_dtype : numpy dtype or None (default=np.float32)
The dtype to cast the cardinal values to. If None, don’t change them.
- int_dtype : numpy dtype or None (default=np.int64)
The dtype to cast the categorical values to. If None, don’t change them. In that case, different categorical columns may end up with different dtypes, depending on the number of unique categories in each.
Notes
The value of X will always be 2-dimensional, even if it only contains 1 column.
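To make the output format concrete, here is a minimal pandas/NumPy sketch of the transformation described above. It is not skorch’s implementation: the function name to_skorch_dict is hypothetical, and it assumes that category-dtype columns (plus int columns when treat_int_as_categorical=True) are the only categorical ones and that every other column is numeric.

```python
import numpy as np
import pandas as pd


def to_skorch_dict(df, treat_int_as_categorical=False,
                   float_dtype=np.float32, int_dtype=np.int64):
    """Hypothetical sketch: cardinal columns -> 2-D "X", categorical columns -> int codes."""
    X_cols, X_dict = [], {}
    for col in df.columns:
        series = df[col]
        is_cat = isinstance(series.dtype, pd.CategoricalDtype)
        is_int = pd.api.types.is_integer_dtype(series.dtype)
        if is_cat or (is_int and treat_int_as_categorical):
            # categorical column: integer codes stored under its own column name
            codes = series.astype('category').cat.codes.to_numpy()
            X_dict[col] = codes if int_dtype is None else codes.astype(int_dtype)
        else:
            # cardinal column: collected for the shared "X" array
            X_cols.append(series.to_numpy())
    if X_cols:
        # stack column-wise so that "X" is always 2-D, even for a single column
        X = np.stack(X_cols, axis=1)
        X_dict['X'] = X if float_dtype is None else X.astype(float_dtype)
    return X_dict


df = pd.DataFrame({
    'col_floats': np.linspace(0, 1, 12),
    'col_ints': [11, 11, 10] * 4,
    'col_cats': pd.Categorical(['a', 'b', 'a'] * 4),
})
X_dict = to_skorch_dict(df)
print({key: (val.shape, val.dtype) for key, val in X_dict.items()})
```

With treat_int_as_categorical=True, col_ints moves out of "X" into its own key, and "X" keeps its 2-D shape with a single column.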
Examples
>>> import numpy as np
>>> import pandas as pd
>>> import torch
>>> from torch import nn
>>> from sklearn.pipeline import Pipeline
>>> from skorch import NeuralNetClassifier
>>> from skorch.helper import DataFrameTransformer
>>> df = pd.DataFrame({
...     'col_floats': np.linspace(0, 1, 12),
...     'col_ints': [11, 11, 10] * 4,
...     'col_cats': ['a', 'b', 'a'] * 4,
... })
>>> # cast to category dtype to later learn embeddings
>>> df['col_cats'] = df['col_cats'].astype('category')
>>> y = np.asarray([0, 1, 0] * 4)
>>> class MyModule(nn.Module):
...     def __init__(self):
...         super().__init__()
...         self.reset_params()
...
...     def reset_params(self):
...         self.embedding = nn.Embedding(2, 10)
...         self.linear = nn.Linear(2, 10)
...         self.out = nn.Linear(20, 2)
...         self.nonlin = nn.Softmax(dim=-1)
...
...     def forward(self, X, col_cats):
...         # "X" contains the values from col_floats and col_ints
...         # "col_cats" contains the values from "col_cats"
...         X_lin = self.linear(X)
...         X_cat = self.embedding(col_cats)
...         X_concat = torch.cat((X_lin, X_cat), dim=1)
...         return self.nonlin(self.out(X_concat))
>>> net = NeuralNetClassifier(MyModule)
>>> pipe = Pipeline([
...     ('transform', DataFrameTransformer()),
...     ('net', net),
... ])
>>> pipe.fit(df, y)
Methods
- describe_signature(df): Describe the signature required for the given data.
- fit(df[, y])
- fit_transform(X[, y]): Fit to data, then transform it.
- get_metadata_routing(): Get metadata routing of this object.
- get_params([deep]): Get parameters for this estimator.
- set_fit_request(*[, df]): Request metadata passed to the fit method.
- set_output(*[, transform]): Set output container.
- set_params(**params): Set the parameters of this estimator.
- set_transform_request(*[, df]): Request metadata passed to the transform method.
- transform(df): Transform DataFrame to become a dict that works well with skorch.
- describe_signature(df)
Describe the signature required for the given data.
Pass the DataFrame to receive a description of the signature required for the module’s forward method. The description consists of three parts:
1. The names of the arguments that the forward method needs.
2. The dtypes of the torch tensors passed to forward.
3. The number of input units that are required for the corresponding argument. For the float parameter, this is the number of columns of X, i.e. the size of its second dimension. For categorical parameters, it is the number of unique elements.
- Returns
- signature : dict
Returns a dict with each key corresponding to one key required for the forward method. The values are dictionaries of two elements. The key “dtype” describes the torch dtype of the resulting tensor, the key “input_units” describes the required number of input units.
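The three parts of the description can be derived directly from a transformed dict. The sketch below is hypothetical (describe_like_skorch is not a skorch function) and reports NumPy dtypes where skorch reports torch dtypes; it assumes the dict has the layout produced by transform.

```python
import numpy as np


def describe_like_skorch(X_dict):
    """Hypothetical sketch of the signature description for a transformed dict."""
    signature = {}
    for key, val in X_dict.items():
        if key == 'X':
            # cardinal values: one input unit per column of the 2-D array
            input_units = val.shape[1]
        else:
            # categorical values: one input unit per unique category
            input_units = len(np.unique(val))
        signature[key] = {'dtype': val.dtype, 'input_units': input_units}
    return signature


X_dict = {
    'X': np.random.rand(12, 2).astype(np.float32),  # e.g. col_floats and col_ints
    'col_cats': np.array([0, 1, 0] * 4, dtype=np.int64),
}
print(describe_like_skorch(X_dict))
```

For the data in the class example, these numbers correspond to nn.Linear(2, ...) for X and nn.Embedding(2, ...) for col_cats.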
- set_fit_request(*, df: Union[bool, None, str] = '$UNCHANGED$') → DataFrameTransformer
Request metadata passed to the fit method.
Note that this method is only relevant if enable_metadata_routing=True (see sklearn.set_config()). Please see the User Guide on how the routing mechanism works.
The options for each parameter are:
- True: metadata is requested, and passed to fit if provided. The request is ignored if metadata is not provided.
- False: metadata is not requested and the meta-estimator will not pass it to fit.
- None: metadata is not requested, and the meta-estimator will raise an error if the user provides it.
- str: metadata should be passed to the meta-estimator with this given alias instead of the original name.
The default (sklearn.utils.metadata_routing.UNCHANGED) retains the existing request. This allows you to change the request for some parameters and not others.
New in version 1.3.
Note
This method is only relevant if this estimator is used as a sub-estimator of a meta-estimator, e.g. used inside a Pipeline. Otherwise it has no effect.
- Parameters
- df : str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED
Metadata routing for df parameter in fit.
- Returns
- self : object
The updated object.
- set_transform_request(*, df: Union[bool, None, str] = '$UNCHANGED$') → DataFrameTransformer
Request metadata passed to the transform method.
Note that this method is only relevant if enable_metadata_routing=True (see sklearn.set_config()). Please see the User Guide on how the routing mechanism works.
The options for each parameter are:
- True: metadata is requested, and passed to transform if provided. The request is ignored if metadata is not provided.
- False: metadata is not requested and the meta-estimator will not pass it to transform.
- None: metadata is not requested, and the meta-estimator will raise an error if the user provides it.
- str: metadata should be passed to the meta-estimator with this given alias instead of the original name.
The default (sklearn.utils.metadata_routing.UNCHANGED) retains the existing request. This allows you to change the request for some parameters and not others.
New in version 1.3.
Note
This method is only relevant if this estimator is used as a sub-estimator of a meta-estimator, e.g. used inside a Pipeline. Otherwise it has no effect.
- Parameters
- df : str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED
Metadata routing for df parameter in transform.
- Returns
- self : object
The updated object.
- transform(df)
Transform DataFrame to become a dict that works well with skorch.
- Parameters
- df : pd.DataFrame
Incoming DataFrame.
- Returns
- X_dict : dict
Dictionary with all floats concatenated using the key “X” and all categorical values encoded as integers, using their respective column names as keys.
- class skorch.helper.SkorchDoctor(net, match_fn=None)
SkorchDoctor helps you better understand and debug neural net training.
By recording activations, gradients, and parameter updates, and by providing useful plotting functions, this class can help you understand your training process and how to improve it.
This is heavily inspired by Andrej Karpathy’s tips on how to better understand a neural net and diagnose potential errors.
To use this class, initialize your skorch net and load your data as you would typically do when training a net. Then, initialize a SkorchDoctor instance by passing said net, and fit the doctor with a small amount of data. Finally, use the records or plotting functions to help you better understand the training process.
What exactly you do with this information is up to you. Some examples that come to mind:
- Use the plot_loss figure to see if your model is powerful enough to completely overfit a small sample. If not, consider increasing its capacity, e.g. by stacking more layers or using more units per layer.
- Use the distribution of activations to see if some layers produce extreme values and may need a different non-linearity or some form of normalization, e.g. batch norm or layer norm. A different weight initialization scheme could also help.
- Check the relative magnitude of the parameter updates to see whether your learning rate is too low or too high. Maybe some layers should be frozen, or you might want to use different learning rates for different parameter groups, or an adaptive optimizer like Adam.
- If gradients are too big, consider using gradient clipping. If the magnitude of gradients shifts over time, you might want to use a learning rate scheduler.
At the end of the day, SkorchDoctor will not tell you what you need to do to improve the training process, but it can make diagnosing potential problems much easier.
- Parameters
- net : skorch.NeuralNet
The skorch net to be diagnosed.
- match_fn : callable or None (default=None)
If match_fn=None, all activations, gradients, and parameter updates are recorded.
If not None, this should be a callable/function that takes the name of a layer or parameter as input and returns a bool as output, where False indicates that this output should be excluded. As an example, if you have a module with a torch.nn.Linear layer called "fc" and you only want to keep records from that layer, and also not record any gradients on biases, the match_fn could be defined as: match_fn = lambda name: ("fc" in name) and ("bias" not in name)
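A quick way to check a match_fn before handing it to SkorchDoctor is to apply it to a few candidate names. The names below are hypothetical examples of the layer and parameter names a module with an "fc" layer and a "conv" layer might have.

```python
def match_fn(name):
    # keep only records from the "fc" layer, and skip anything related to biases
    return ("fc" in name) and ("bias" not in name)


candidate_names = ["fc", "fc.weight", "fc.bias", "conv.weight", "conv.bias"]
kept = [name for name in candidate_names if match_fn(name)]
print(kept)
```

Because this is a substring test, a layer named "fc2" would match as well; use a stricter comparison if your module has such names.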
Notes
Even if a train/valid split is used for the net, only training data is recorded.
Since SkorchDoctor will record a lot of values, you should expect an increase in memory usage and training time. However, it’s sufficient to train with a handful of samples and only a few epochs, which helps offset those disadvantages.
After you have finished the analysis, it is recommended to re-initialize the net or, even better, start a new process. This is because the net you passed is modified: hooks are added to it and it is trained. Although there is a clean-up step at the end, it’s better to start fresh when starting the real model training.
Examples
>>> from skorch import NeuralNet
>>> from skorch.helper import SkorchDoctor
>>> net = NeuralNet(..., max_epochs=5)  # a couple of epochs are enough
>>> doctor = SkorchDoctor(net)
>>> X_sample, y_sample = X[:100], y[:100]  # a few samples are enough
>>> doctor.fit(X_sample, y_sample)
>>> # now use the attributes and plotting functions to better
>>> # understand the training process
>>> doctor.activation_recs_  # the recorded activations
>>> doctor.gradient_recs_  # the recorded gradients
>>> doctor.param_update_recs_  # the recorded parameter updates
>>> # the next steps require matplotlib to be installed
>>> doctor.plot_loss()
>>> doctor.plot_activations()
>>> doctor.plot_gradients()
>>> doctor.plot_param_updates()
>>> doctor.plot_activations_over_time(<layer-name>)
>>> doctor.plot_gradient_over_time(<param-name>)
- Attributes
- num_steps_ : int
The total number of training steps. E.g. if trained for 10 epochs with 5 batches each, that would be 50.
- module_names_ : list of str
All modules used by the net, typically those are "module" and "criterion".
- activation_recs_ : dict of list of dict of np.ndarray
The activations of each layer for each module. The outer dict contains one entry for each top level module, e.g. module and criterion. The values are lists, one entry for each training step. The entries of those lists are again dictionaries, with keys corresponding to the layer name and values to the activations of that layer, stored as a numpy array.
This data structure may seem complicated at first, but its use is quite straightforward. E.g. to get the activations of the layer called “dense0” of the “module” in epoch 0 and batch 0, use doctor.activation_recs_['module'][0]['dense0'].
If an activation is not a simple array, it is disambiguated. E.g. if it’s a list, the name gets a suffix of [i], where i designates the index in the list. Similarly, when the output is a dict, a [key] suffix is added, where key is the key of the corresponding value in the dictionary.
- gradient_recs_ : dict of list of dict of np.ndarray
The gradients of each parameter for each module. The outer dict contains one entry for each top level module, e.g. module and criterion. The values are lists, one entry for each training step. The entries of those lists are dictionaries, with keys corresponding to the parameter name and values to the gradients of that parameter, stored as a numpy array. Only learnable parameters are recorded.
This data structure may seem complicated at first, but its use is quite straightforward. E.g. to get the gradient of the parameter called “dense0.weight” of the “module” from training step 7, use doctor.gradient_recs_['module'][7]['dense0.weight'].
- param_update_recs_ : dict of list of dict of float
The relative parameter update of each parameter for each module. The outer dict contains one entry for each top level module, e.g. module and criterion. The values are lists, one entry for each training step. The entries of those lists are dictionaries, with keys corresponding to the parameter name and values to the standard deviation of the update of that parameter, relative to the standard deviation of that parameter itself, stored as a float. Only learnable parameters are recorded.
This data structure may seem complicated at first, but its use is quite straightforward. E.g. to get the update of the parameter called “dense0.weight” of the “module” in the last training step, use doctor.param_update_recs_['module'][-1]['dense0.weight'].
- fitted_ : bool
Whether the instance has been fitted.
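To get a feeling for the nesting, the following sketch builds a toy structure with the same layout as activation_recs_ (outer dict per top-level module, list per training step, inner dict per layer). The layer names and random values are made up, and the "criterion" entry a real record would contain is omitted. The helper flatten_output is a hypothetical illustration of the naming scheme for list and dict outputs, not skorch code.

```python
import numpy as np

rng = np.random.default_rng(0)
num_steps = 3  # toy run with 3 training steps

# dict (top-level module) -> list (one entry per training step)
#   -> dict (layer name) -> np.ndarray (that layer's activations)
activation_recs = {
    'module': [
        {'dense0': rng.normal(size=(8, 10)), 'output': rng.normal(size=(8, 2))}
        for _ in range(num_steps)
    ],
}

first = activation_recs['module'][0]['dense0']  # training step 0, layer "dense0"
last = activation_recs['module'][-1]['dense0']  # last training step, same layer


def flatten_output(name, output):
    """Hypothetical helper mirroring the naming scheme for non-array outputs."""
    if isinstance(output, list):
        return {f'{name}[{i}]': val for i, val in enumerate(output)}
    if isinstance(output, dict):
        return {f'{name}[{key}]': val for key, val in output.items()}
    return {name: output}


print(first.shape, len(activation_recs['module']))
print(sorted(flatten_output('rnn', [np.zeros(2), np.ones(2)])))
```

gradient_recs_ and param_update_recs_ are indexed the same way, with parameter names such as "dense0.weight" as the innermost keys.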
Methods
- fit(X[, y]): Initialize and fit the SkorchDoctor.
- get_layer_names(): Return the names of all layers/modules.
- get_param_names(): Return all learnable parameters of the net.
- initialize(): Initialize the SkorchDoctor.
- plot_activations([step, match_fn, axes, ...]): Plot the distribution of activations produced by the layers.
- plot_activations_over_time(layer_name[, ...]): Plot the distribution of the activations of a specific layer over time.
- plot_gradient_over_time(param_name[, ...]): Plot the distribution of the gradients of a specific parameter over time.
- plot_gradients([step, match_fn, axes, ...]): Plot the distribution of gradients of each learnable parameter.
- plot_loss([ax, figsize]): Plot the loss over each epoch.
- plot_param_updates([match_fn, axes, figsize]): Plot the distribution of relative parameter updates.
- predict(X, **kwargs): Calls the predict method of the underlying net.
- predict_proba(X, **kwargs): Calls the predict_proba method of the underlying net.
- score(X[, y]): Calls the score method of the underlying net.
- check_is_fitted
- fit(X, y=None, **fit_params)
Initialize and fit the SkorchDoctor
It is advised to use a low number of epochs and a small amount of data only, since the collection of data results in time and memory overhead.
The parameters should be exactly the same as those passed when fitting the underlying net.
- get_layer_names()
Return the names of all layers/modules
- Returns
- names : dict of list of str
For each top level module, all layer names as a list of strings.
- get_param_names()
Return all learnable parameters of the net
- Returns
- names : dict of list of str
For each top level module, all parameter names as a list of strings.
- initialize()
Initialize the SkorchDoctor
This method typically does not need to be invoked explicitly, because it is called by fit.
- plot_activations(step=-1, match_fn=None, axes=None, histtype='step', lw=2, bins=None, density=True, figsize=None, **kwargs)
Plot the distribution of activations produced by the layers
- Parameters
- step : int (default=-1)
Which training step to plot. By default, the last step (-1) is chosen.
- match_fn : callable or None (default=None)
If not None, this should be a callable/function that takes the name of a layer as input and returns a bool as output, where False indicates that this layer should be excluded. Use this to plot only specific layers.
- axes : np.ndarray of AxesSubplot or None (default=None)
By default, a new matplotlib plot is created. If you instead want to plot onto an existing plot, pass it here. There should be one subplot for each top level module (typically 2).
- bins : np.ndarray or None (default=None)
Bins to use for the histogram. If left as None, they are inferred from the data.
- **kwargs
You can override remaining plotting arguments like lw (line width) or figsize (figure size).
- Returns
- axes : np.ndarray of AxesSubplot
The axes of the plot.
- plot_activations_over_time(layer_name, module_name='module', ax=None, lw=2, bins=None, figsize=None, color='k', **kwargs)
Plot the distribution of the activations of a specific layer over time
The histograms are plotted “on top of each other” with an offset. Therefore, the absolute magnitude on the y-axis has no meaning.
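The offset idea can be sketched with NumPy alone: compute one histogram per training step on shared bins, then shift each curve up by a constant. The recorded values here are random stand-ins, and the per-histogram normalization and the offset of 1 per step are arbitrary choices of this sketch, not necessarily what skorch does internally. The resulting curves could be drawn with matplotlib’s ax.plot.

```python
import numpy as np

rng = np.random.default_rng(0)
# stand-in for the activations of one layer at 5 training steps, drifting over time
steps = [rng.normal(loc=0.2 * t, scale=1.0, size=500) for t in range(5)]

# shared bins so that histograms from different steps are comparable
bins = np.histogram_bin_edges(np.concatenate(steps), bins=30)
centers = (bins[:-1] + bins[1:]) / 2

curves = []
for offset, values in enumerate(steps):
    counts, _ = np.histogram(values, bins=bins)
    heights = counts / counts.max()  # normalize each histogram to a peak of 1
    curves.append(heights + offset)  # stack vertically: y-position encodes the step
# e.g. for ys in curves: ax.plot(centers, ys, color='k')
```

Since every curve is rescaled and shifted, only the shape and horizontal position of each histogram carry information, which is why the absolute y-values are meaningless.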
- Parameters
- layer_name : str
The name of the specific layer whose activations should be plotted.
- module_name : str (default='module')
The name of the module that the layer belongs to. By default, it is called “module” in skorch, but it’s possible to define custom module names, in which case the corresponding name should be chosen.
- ax : AxesSubplot or None (default=None)
By default, a new matplotlib plot is created. If you instead want to plot onto an existing plot, pass it here. Only a single plot is created.
- bins : np.ndarray or None (default=None)
Bins to use for the histogram. If left as None, they are inferred from the data.
- **kwargs
You can override remaining plotting arguments like figsize (figure size).
- Returns
- ax : AxesSubplot
The ax of the plot.
- plot_gradient_over_time(param_name, module_name='module', ax=None, lw=2, bins=None, figsize=None, color='k', **kwargs)
Plot the distribution of the gradients of a specific parameter over time
The histograms are plotted “on top of each other” with an offset. Therefore, the absolute magnitude on the y-axis has no meaning.
- Parameters
- param_name : str
The name of the specific parameter that should be plotted.
- module_name : str (default='module')
The name of the module that the parameter belongs to. By default, it is called “module” in skorch, but it’s possible to define custom module names, in which case the corresponding name should be chosen.
- ax : AxesSubplot or None (default=None)
By default, a new matplotlib plot is created. If you instead want to plot onto an existing plot, pass it here. Only a single plot is created.
- bins : np.ndarray or None (default=None)
Bins to use for the histogram. If left as None, they are inferred from the data.
- **kwargs
You can override remaining plotting arguments like figsize (figure size).
- Returns
- ax : AxesSubplot
The ax of the plot.
- plot_gradients(step=-1, match_fn=None, axes=None, histtype='step', lw=2, bins=None, density=True, figsize=None, **kwargs)
Plot the distribution of gradients of each learnable parameter
- Parameters
- step : int (default=-1)
Which training step to plot. By default, the last step (-1) is chosen.
- match_fn : callable or None (default=None)
If not None, this should be a callable/function that takes the name of a parameter as input and returns a bool as output, where False indicates that this parameter should be excluded. Use this to plot only specific parameters.
- axes : np.ndarray of AxesSubplot or None (default=None)
By default, a new matplotlib plot is created. If you instead want to plot onto an existing plot, pass it here. There should be one subplot for each top level module (typically 2).
- bins : np.ndarray or None (default=None)
Bins to use for the histogram. If left as None, they are inferred from the data.
- **kwargs
You can override remaining plotting arguments like lw (line width) or figsize (figure size).
- Returns
- axes : np.ndarray of AxesSubplot
The axes of the plot.
- plot_loss(ax=None, figsize=None, **kwargs)
Plot the loss over each epoch.
Plots the training loss and, if present, the validation loss over time.
- plot_param_updates(match_fn=None, axes=None, figsize=None, **kwargs)
Plot the distribution of relative parameter updates.
Plots the log10 of the standard deviation of the parameter update relative to the parameter itself, over time. Higher values mean that the parameter changes quite a lot with each training step, lower values mean that the parameter changes little.
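The plotted quantity can be written out in a few lines. This sketch assumes that the update is the difference between a parameter’s values after and before one optimizer step, and that both standard deviations are taken over all elements of the parameter; the plain SGD step with made-up values only serves to produce an update.

```python
import numpy as np

rng = np.random.default_rng(0)
param_before = rng.normal(scale=0.1, size=(10, 20))  # stand-in weight matrix
grad = rng.normal(size=(10, 20))                     # stand-in gradient
lr = 1e-3

param_after = param_before - lr * grad  # one plain SGD step
update = param_after - param_before

# relative update: std of the update divided by std of the parameter itself
rel_update = update.std() / param_before.std()
print(np.log10(rel_update))  # the kind of value shown on the plot's y-axis
```

Here the ratio is roughly lr * std(grad) / std(param), about 1e-2, so the log10 is close to -2; a larger learning rate shifts the curve up.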
- Parameters
- match_fn : callable or None (default=None)
If not None, this should be a callable/function that takes the name of a parameter as input and returns a bool as output, where False indicates that this parameter should be excluded. Use this to plot only specific parameters.
- axes : np.ndarray of AxesSubplot or None (default=None)
By default, a new matplotlib plot is created. If you instead want to plot onto an existing plot, pass it here. There should be one subplot for each top level module (typically 2).
- bins : np.ndarray or None (default=None)
Bins to use for the histogram. If left as None, they are inferred from the data.
- **kwargs
You can override remaining plotting arguments like figsize (figure size).
- Returns
- axes : np.ndarray of AxesSubplot
The axes of the plot.
- class skorch.helper.SliceDataset(dataset, idx=0, indices=None)
Helper class that wraps a torch dataset to make it work with sklearn.
Sometimes, sklearn will touch the input data, e.g. when splitting the data for a grid search. This will fail when the input data is a torch dataset. To prevent this, use this wrapper class for your dataset.
Note: This class will only return the X value by default (i.e. the first value returned by indexing the original dataset). Sklearn, and hence skorch, always require 2 values, X and y. Therefore, you still need to provide the y data separately.
Note: This class behaves similarly to a PyTorch Subset when it is indexed by a slice or numpy array: It will return another SliceDataset that references the subset instead of the actual values. Only when it is indexed by an int does it return the actual values. The reason for this is to avoid loading all data into memory when sklearn, for instance, creates a train/validation split on the dataset. Data will only be loaded in batches during the fit loop.
- Parameters
- dataset : torch.utils.data.Dataset
A valid torch dataset.
- idx : int (default=0)
Indicates which element of the dataset should be returned. Typically, the dataset returns both X and y values. SliceDataset can only return 1 value. If you want to get X, choose idx=0 (default); if you want y, choose idx=1.
- indices : list, np.ndarray, or None (default=None)
If you only want to return a subset of the dataset, indicate which subset that is by passing this argument. Typically, this can be left to be None, which returns all the data. See also Subset.
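The slicing behavior described in the notes can be illustrated with a stripped-down stand-in. MiniSliceDataset is hypothetical and much simpler than skorch’s class (no shape attribute, no torch dependency), but it follows the same rules: an int returns the selected element of one sample, passed through transform, while a slice, list, or array returns a new lazy view that only stores indices.

```python
import numpy as np


class MiniSliceDataset:
    """Hypothetical, simplified stand-in for skorch.helper.SliceDataset."""

    def __init__(self, dataset, idx=0, indices=None):
        self.dataset = dataset
        self.idx = idx
        self.indices = np.arange(len(dataset)) if indices is None else np.asarray(indices)

    def __len__(self):
        return len(self.indices)

    def transform(self, data):
        # hook for additional per-item transformations; identity by default
        return data

    def __getitem__(self, i):
        if isinstance(i, (int, np.integer)):
            # only integer indexing touches the underlying data
            sample = self.dataset[int(self.indices[i])]
            return self.transform(sample[self.idx])
        # slices and index arrays return another lazy view, like torch's Subset
        return type(self)(self.dataset, idx=self.idx, indices=self.indices[i])


# a "dataset" of (X, y) pairs; a torch Dataset returning tuples would be indexed the same way
data = [(np.full(3, k, dtype=np.float32), k % 2) for k in range(10)]
X_view = MiniSliceDataset(data)         # idx=0 -> X values
y_view = MiniSliceDataset(data, idx=1)  # idx=1 -> y values

subset = X_view[2:5]  # no data loaded yet, just new indices
print(type(subset).__name__, len(subset), subset[0])
```

Because a view only stores indices, sklearn can split it cheaply, and samples are fetched one at a time when batches are built.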
Examples
>>> X = MyCustomDataset()
>>> search = GridSearchCV(net, params, ...)
>>> search.fit(X, y)  # raises error
>>> ds = SliceDataset(X)
>>> search.fit(ds, y)  # works
- Attributes
- shape
Methods
- count(value)
- index(value, [start, [stop]]): Raises ValueError if the value is not present.
- transform(data): Additional transformations on data.
- transform(data)
Additional transformations on data.
Note: If you use this in conjunction with PyTorch DataLoader, the latter will call the dataset for each row separately, which means that the incoming data is a single row.
- class skorch.helper.SliceDict(**kwargs)
Wrapper for Python dict that makes it sliceable across values.
Use this if your input data is a dictionary and you have problems with sklearn not being able to slice it. Wrap your dict with SliceDict and it should usually work.
Note:
SliceDict cannot be indexed by integers; if you want one row, say row 3, use [3:4].
SliceDict accepts numpy arrays and torch tensors as values.
Examples
>>> X = {'key0': val0, 'key1': val1}
>>> search = GridSearchCV(net, params, ...)
>>> search.fit(X, y)  # raises error
>>> Xs = SliceDict(key0=val0, key1=val1)  # or Xs = SliceDict(**X)
>>> search.fit(Xs, y)  # works
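The idea behind SliceDict can be sketched as a dict subclass whose indexing forwards non-string keys to every value. MiniSliceDict is hypothetical and omits much of what the real class does (for instance the shape attribute and the unsupported methods such as fromkeys), but it mirrors the rules stated above: string keys return a value, slices and index arrays return a new dict, and integers are rejected.

```python
import numpy as np


class MiniSliceDict(dict):
    """Hypothetical, simplified stand-in for skorch.helper.SliceDict."""

    def __init__(self, **kwargs):
        lengths = {len(val) for val in kwargs.values()}
        if len(lengths) > 1:
            raise ValueError("All values must have the same length.")
        super().__init__(**kwargs)
        self._len = lengths.pop() if lengths else 0

    def __len__(self):
        # number of rows (what sklearn needs), not the number of keys
        return self._len

    def __getitem__(self, key):
        if isinstance(key, str):
            return super().__getitem__(key)  # plain dict lookup
        if isinstance(key, (int, np.integer)):
            # would drop a dimension from every value, so it is not allowed
            raise ValueError("Cannot index with an integer; use a slice like [3:4].")
        # slices and index arrays are applied to every value
        return type(self)(**{k: v[key] for k, v in self.items()})


X = MiniSliceDict(
    key0=np.arange(10).reshape(5, 2),
    key1=np.linspace(0, 1, 5),
)
rows = X[1:3]  # both values are sliced consistently
print(len(X), rows['key0'].shape, rows['key1'].shape)
```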
- Attributes
- shape
Methods
- clear()
- copy()
- fromkeys(*args, **kwargs): The fromkeys method makes no sense with SliceDict and is thus not supported.
- get(key[, default]): Return the value for key if key is in the dictionary, else default.
- items()
- keys()
- pop(key[, default]): If key is not found, default is returned if given, otherwise KeyError is raised.
- popitem(): Remove and return a (key, value) pair as a 2-tuple.
- setdefault(key[, default]): Insert key with a value of default if key is not in the dictionary.
- update([E, ]**F): If E is present and has a .keys() method, then does: for k in E: D[k] = E[k]. If E is present and lacks a .keys() method, then does: for k, v in E: D[k] = v. In either case, this is followed by: for k in F: D[k] = F[k].
- values()
- skorch.helper.parse_args(kwargs, defaults=None)
Apply command line arguments or show help.
Use this in conjunction with the fire library to quickly build command line interfaces for your scripts.
This function returns another function that must be called with the estimator (e.g. NeuralNet) to apply the parsed command line arguments. If the --help option is found, show the estimator-specific help instead.
- Parameters
- kwargs : dict
The arguments as parsed by fire.
- defaults : dict or None (default=None)
Optionally, change the default values to use custom defaults. Command line arguments take precedence over defaults.
- Returns
- print_help_and_exit : callable
If --help is in the arguments, print help and exit.
- set_params : callable
If --help is not in the options, apply command line arguments to the estimator and return it.
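The precedence rule between custom defaults and command line arguments amounts to a dict merge in which command line values win. The helper merge_cli_args below is a hypothetical illustration of that rule only; parse_args itself additionally handles --help and applies the result to the estimator.

```python
def merge_cli_args(kwargs, defaults=None):
    """Hypothetical helper: command line arguments override custom defaults."""
    merged = dict(defaults or {})  # copy, so the defaults dict is not modified
    merged.update(kwargs)          # command line values win over defaults
    return merged


defaults = {'max_epochs': 10, 'lr': 0.01}
cli_kwargs = {'lr': 0.1}  # roughly what fire would pass for "--lr 0.1"
print(merge_cli_args(cli_kwargs, defaults))
```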
Examples
Content of my_script.py:
>>> import fire
>>> from skorch.helper import parse_args
>>> def main(**kwargs):
...     X, y = get_data()
...     my_model = get_model()
...     parsed = parse_args(kwargs)
...     my_model = parsed(my_model)
...     my_model.fit(X, y)
>>> if __name__ == '__main__':
...     fire.Fire(main)