skorch.llm

class skorch.llm.FewShotClassifier(model_name=None, *, model=None, tokenizer=None, prompt=None, probas_sum_to_1=True, max_samples=5, device='cpu', error_low_prob='ignore', threshold_low_prob=0.0, use_caching=True, random_state=None)[source]

Few-shot classification using a Large Language Model (LLM).

This class allows you to use an LLM from Hugging Face transformers for few-shot classification. There is no training during the fit call; instead, the LLM will be prompted to predict the labels for each sample.
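
A minimal usage sketch (the model name and the toy data below are placeholders, not part of the API):

    from skorch.llm import FewShotClassifier

    X_train = ["The movie was fantastic", "Terrible plot, bad acting", "Loved every minute"]
    y_train = ["positive", "negative", "positive"]

    clf = FewShotClassifier('gpt2')  # any text generation model from the Hub
    clf.fit(X_train, y_train)        # no training, only prompt and example preparation
    y_pred = clf.predict(["A delightful experience"])
    y_proba = clf.predict_proba(["A delightful experience"])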

Parameters:
model_name : str or None (default=None)

The name of the model to use. This is the same name as used on Hugging Face Hub. For example, to use GPT2, pass 'gpt2'; to use the small flan-t5 model, pass 'google/flan-t5-small'. If the model_name parameter is passed, don’t pass the model or tokenizer parameters.

model : torch.nn.Module or None (default=None)

The model to use. This should be a PyTorch text generation model from Hugging Face Hub or a model with the same API. Most notably, the model should have a generate method. If you pass the model, you should also pass the tokenizer, but you should not pass the model_name.

Passing the model explicitly instead of the model_name can have a few advantages. Most notably, this allows you to modify the model, e.g. changing its config or how the model is loaded. For instance, some models can only be loaded with the option trust_remote_code=True. If using the model_name argument, the default settings will be used instead. Passing the model explicitly also allows you to use custom models that are not uploaded to Hugging Face Hub.

tokenizer (default=None)

A tokenizer that is compatible with the model. Typically, this is loaded using the AutoTokenizer.from_pretrained method provided by Hugging Face transformers. If you pass the tokenizer, you should also pass the model, but you should not pass the model_name.
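
A sketch of passing the model and tokenizer explicitly instead of model_name (the 'gpt2' checkpoint is only an example):

    from transformers import AutoModelForCausalLM, AutoTokenizer
    from skorch.llm import FewShotClassifier

    # loading the model yourself gives control over how it is configured and loaded
    model = AutoModelForCausalLM.from_pretrained('gpt2')
    tokenizer = AutoTokenizer.from_pretrained('gpt2')

    # pass model and tokenizer together; do not also pass model_name
    clf = FewShotClassifier(model=model, tokenizer=tokenizer)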

prompt : str or None (default=None)

The prompt to use. This is the text that will be passed to the model to generate the prediction. If no prompt is passed, a default prompt will be used. The prompt should be a Python string with three placeholders, one called text, one called labels, and one called examples. The text placeholder will be replaced by the contents of X passed during inference, the labels placeholder will be replaced by the unique labels taken from y, and the examples placeholder will be filled with examples taken from the X and y seen during fit.

An example prompt could be something like this:

“Classify this text: {text}. Possible labels are: {labels}. Here are some examples: {examples}. Your response: ”

All general tips for good prompt crafting apply here as well. Be aware that if the prompt is too long, it will exceed the context size of the model.
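
A sketch of passing a custom prompt; the string must contain the {text}, {labels}, and {examples} placeholders ('gpt2' is only an example model):

    from skorch.llm import FewShotClassifier

    my_prompt = (
        "Classify this text: {text}. "
        "Possible labels are: {labels}. "
        "Here are some examples: {examples}. "
        "Your response: "
    )
    clf = FewShotClassifier('gpt2', prompt=my_prompt)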

probas_sum_to_1 : bool (default=True)

If True, then the probabilities for each sample will be normalized to sum to 1. If False, the probabilities will not be normalized.

In general, without normalization, the probabilities will not sum to 1 because the LLM can generate any token, not just the labels. Since, for prediction, the model is restricted to generating only the available labels, there will be some probability mass that is unaccounted for. You could consider the missing probability mass to be an implicit ‘other’ class.

In general, you should set this parameter to True because the default assumption is that probabilities sum to 1. However, setting this to False can be useful for debugging purposes, as it allows you to see how much probability the LLM assigns to different tokens. If the total probabilities are very low, it could be a sign that the LLM is not powerful enough or that the prompt is not well crafted.
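
A sketch of using probas_sum_to_1=False for debugging (model name and data are placeholders):

    from skorch.llm import FewShotClassifier

    X = ["great product", "broke after one day", "works as advertised"]
    y = ["positive", "negative", "positive"]

    clf = FewShotClassifier('gpt2', probas_sum_to_1=False)
    clf.fit(X, y)
    proba = clf.predict_proba(["surprisingly sturdy"])
    # without normalization, rows need not sum to 1; very low totals hint at a
    # weak model or a poorly crafted prompt
    print(proba.sum(axis=1))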

max_samples : int (default=5)

The number of samples to use for few-shot learning. The few-shot samples are taken from the X and y passed to fit.

This number should be large enough for the LLM to generalize, but not too large so as to exceed the context window size. More samples will also lower prediction speed.

device : str or torch.device (default=’cpu’)

The device to use. In general, using a GPU or other accelerated hardware is advised if runtime performance is critical.

Note that if the model parameter is passed explicitly, the device of that model takes precedence over the value of device.

error_low_prob : {‘ignore’, ‘warn’, ‘raise’, ‘return_none’} (default=’ignore’)

Controls what should happen if the sum of the probabilities for a sample is below a given threshold. When encountering low probabilities, the options are to do one of the following:

  • 'ignore': do nothing
  • 'warn': issue a warning
  • 'raise': raise an error
  • 'return_none': return None as the prediction when calling .predict

The threshold is controlled by the threshold_low_prob parameter.

threshold_low_prob : float (default=0.0)

The threshold for the sum of the probabilities below which they are considered to be too low. The consequences of low probabilities are controlled by the error_low_prob parameter.
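
A sketch combining error_low_prob and threshold_low_prob (the threshold value, model name, and data are placeholders):

    from skorch.llm import FewShotClassifier

    X = ["great product", "broke after one day"]
    y = ["positive", "negative"]

    # return None instead of a label when the summed probability is below 0.1
    clf = FewShotClassifier(
        'gpt2',
        error_low_prob='return_none',
        threshold_low_prob=0.1,
    )
    clf.fit(X, y)
    y_pred = clf.predict(["no strong opinion"])  # entries may be None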

use_caching : bool (default=True)

If True, the predictions for each sample will be cached, as well as the intermediate results for each generated token. This can speed up predictions when some samples are duplicated or when labels share a long common prefix. For example, if one label is called “intent.support.email” and another “intent.support.phone”, the tokens for the common prefix “intent.support.” are reused for both labels, as their probabilities are identical.

Note that caching is currently not supported for encoder-decoder architectures such as flan-t5. If you want to use such an architecture, turn caching off.

If you see any issues you might suspect are caused by caching, turn this option off, see if it helps, and report the issue on the skorch GitHub page.
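
For an encoder-decoder model such as flan-t5, caching therefore has to be disabled, e.g.:

    from skorch.llm import FewShotClassifier

    # flan-t5 is an encoder-decoder architecture, so turn caching off
    clf = FewShotClassifier('google/flan-t5-small', use_caching=False)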

random_state : int, RandomState instance or None (default=None)

The choice of examples that are picked for few-shot learning is random. To fix the random seed, use this argument.

Attributes:
classes_ : ndarray of shape (n_classes, )

A list of class labels known to the classifier. This attribute can be used to identify which column in the probabilities returned by predict_proba corresponds to which class.
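
A sketch of mapping predict_proba columns to labels via classes_ (model name and data are placeholders):

    from skorch.llm import FewShotClassifier

    clf = FewShotClassifier('gpt2')
    clf.fit(["good", "bad", "fine"], ["positive", "negative", "positive"])
    proba = clf.predict_proba(["not bad at all"])
    # column i of proba corresponds to clf.classes_[i]
    for label, p in zip(clf.classes_, proba[0]):
        print(label, p)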

Methods

check_X_y(X, y, **fit_params) Check that input data is well-behaved.
check_prompt(prompt) Check if the prompt is well formed.
clear_model_cache() Clear the cache of the model.
fit(X, y, **fit_params) Prepare everything to enable predictions.
get_examples(X, y, n_samples) Given input data X and y, return a subset of n_samples for few-shot learning.
get_metadata_routing() Get metadata routing of this object.
get_params([deep]) Get parameters for this estimator.
get_prompt(text) Return the prompt for the given sample.
predict(X) Return the classes predicted by the LLM.
predict_proba(X) Return the probabilities predicted by the LLM.
score(X, y[, sample_weight]) Return the mean accuracy on the given test data and labels.
set_params(**params) Set the parameters of this estimator.
set_score_request(*[, sample_weight]) Request metadata passed to the score method.
check_args  
check_classes  
check_is_fitted  
check_X_y(X, y, **fit_params)[source]

Check that input data is well-behaved.

check_prompt(prompt)[source]

Check if the prompt is well formed.

If no prompt is provided, return the default prompt.

Raises:
ValueError

When the prompt is not well formed.

fit(X, y, **fit_params)[source]

Prepare everything to enable predictions.

There is no actual fitting going on here, as the LLM is used as is. The examples used for few-shot learning will be derived from the provided input data. The selection mechanism for this is that, for each possible label, at least one example is taken from the data (if max_samples is large enough).

To change the way that examples are selected, override the get_examples method.

Parameters:
X : array-like of shape (n_samples,)

The input data. For zero-shot classification, this can be None.

y : array-like of shape (n_samples,)

The target classes. Ensure that each class that the LLM should be able to predict is present at least once. Classes that are not present during the fit call will never be predicted.

**fit_params : dict

Additional fitting parameters. This is mostly a placeholder for sklearn-compatibility, as there is no actual fitting process.

Returns:
self

The fitted estimator.

get_examples(X, y, n_samples)[source]

Given input data X and y, return a subset of n_samples for few-shot learning.

This method aims at providing at least one example for each existing class.
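
A sketch of overriding example selection by subclassing; the return format shown here (a subset of X and a subset of y) is an assumption, so check the parent implementation before relying on it:

    from skorch.llm import FewShotClassifier

    class FirstNFewShotClassifier(FewShotClassifier):
        # hypothetical subclass: take the first n_samples entries verbatim
        # instead of the default per-class selection; the assumed return
        # format is (X subset, y subset)
        def get_examples(self, X, y, n_samples):
            return X[:n_samples], y[:n_samples]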

get_prompt(text)[source]

Return the prompt for the given sample.

set_score_request(*, sample_weight: Union[bool, None, str] = '$UNCHANGED$') → skorch.llm.classifier.FewShotClassifier[source]

Request metadata passed to the score method.

Note that this method is only relevant if enable_metadata_routing=True (see sklearn.set_config()). Please see the User Guide on how the routing mechanism works.

The options for each parameter are:

  • True: metadata is requested, and passed to score if provided. The request is ignored if metadata is not provided.
  • False: metadata is not requested and the meta-estimator will not pass it to score.
  • None: metadata is not requested, and the meta-estimator will raise an error if the user provides it.
  • str: metadata should be passed to the meta-estimator with this given alias instead of the original name.

The default (sklearn.utils.metadata_routing.UNCHANGED) retains the existing request. This allows you to change the request for some parameters and not others.

New in version 1.3.

Note

This method is only relevant if this estimator is used as a sub-estimator of a meta-estimator, e.g. used inside a pipeline.Pipeline. Otherwise it has no effect.

Parameters:
sample_weight : str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED

Metadata routing for sample_weight parameter in score.

Returns:
self : object

The updated object.

class skorch.llm.ZeroShotClassifier(model_name=None, *, model=None, tokenizer=None, prompt=None, probas_sum_to_1=True, device='cpu', error_low_prob='ignore', threshold_low_prob=0.0, use_caching=True)[source]

Zero-shot classification using a Large Language Model (LLM).

This class allows you to use an LLM from Hugging Face transformers for zero-shot classification. There is no training during the fit call; instead, the LLM will be prompted to predict the labels for each sample.
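
A minimal usage sketch (the model name and the labels below are placeholders, not part of the API):

    from skorch.llm import ZeroShotClassifier

    clf = ZeroShotClassifier('gpt2')
    # for zero-shot classification, X may be None; only the candidate labels matter
    clf.fit(None, ['positive', 'negative'])
    y_pred = clf.predict(["This movie was a waste of time"])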

Parameters:
model_name : str or None (default=None)

The name of the model to use. This is the same name as used on Hugging Face Hub. For example, to use GPT2, pass 'gpt2'; to use the small flan-t5 model, pass 'google/flan-t5-small'. If the model_name parameter is passed, don’t pass the model or tokenizer parameters.

model : torch.nn.Module or None (default=None)

The model to use. This should be a PyTorch text generation model from Hugging Face Hub or a model with the same API. Most notably, the model should have a generate method. If you pass the model, you should also pass the tokenizer, but you should not pass the model_name.

Passing the model explicitly instead of the model_name can have a few advantages. Most notably, this allows you to modify the model, e.g. changing its config or how the model is loaded. For instance, some models can only be loaded with the option trust_remote_code=True. If using the model_name argument, the default settings will be used instead. Passing the model explicitly also allows you to use custom models that are not uploaded to Hugging Face Hub.

tokenizer (default=None)

A tokenizer that is compatible with the model. Typically, this is loaded using the AutoTokenizer.from_pretrained method provided by Hugging Face transformers. If you pass the tokenizer, you should also pass the model, but you should not pass the model_name.

prompt : str or None (default=None)

The prompt to use. This is the text that will be passed to the model to generate the prediction. If no prompt is passed, a default prompt will be used. The prompt should be a Python string with two placeholders, one called text and one called labels. The text placeholder will be replaced by the contents of X and the labels placeholder will be replaced by the unique labels taken from y.

An example prompt could be something like this:

“Classify this text: {text}. Possible labels are {labels}. Your response: ”

All general tips for good prompt crafting apply here as well. Be aware that if the prompt is too long, it will exceed the context size of the model.
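
A sketch of passing a custom prompt; the string must contain the {text} and {labels} placeholders ('gpt2' is only an example model):

    from skorch.llm import ZeroShotClassifier

    my_prompt = (
        "Classify this text: {text}. "
        "Possible labels are: {labels}. "
        "Your response: "
    )
    clf = ZeroShotClassifier('gpt2', prompt=my_prompt)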

probas_sum_to_1 : bool (default=True)

If True, then the probabilities for each sample will be normalized to sum to 1. If False, the probabilities will not be normalized.

In general, without normalization, the probabilities will not sum to 1 because the LLM can generate any token, not just the labels. Since, for prediction, the model is restricted to generating only the available labels, there will be some probability mass that is unaccounted for. You could consider the missing probability mass to be an implicit ‘other’ class.

In general, you should set this parameter to True because the default assumption is that probabilities sum to 1. However, setting this to False can be useful for debugging purposes, as it allows you to see how much probability the LLM assigns to different tokens. If the total probabilities are very low, it could be a sign that the LLM is not powerful enough or that the prompt is not well crafted.

device : str or torch.device (default=’cpu’)

The device to use. In general, using a GPU or other accelerated hardware is advised if runtime performance is critical.

Note that if the model parameter is passed explicitly, the device of that model takes precedence over the value of device.

error_low_prob : {‘ignore’, ‘warn’, ‘raise’, ‘return_none’} (default=’ignore’)

Controls what should happen if the sum of the probabilities for a sample is below a given threshold. When encountering low probabilities, the options are to do one of the following:

  • 'ignore': do nothing
  • 'warn': issue a warning
  • 'raise': raise an error
  • 'return_none': return None as the prediction when calling .predict

The threshold is controlled by the threshold_low_prob parameter.

threshold_low_prob : float (default=0.0)

The threshold for the sum of the probabilities below which they are considered to be too low. The consequences of low probabilities are controlled by the error_low_prob parameter.

use_caching : bool (default=True)

If True, the predictions for each sample will be cached, as well as the intermediate results for each generated token. This can speed up predictions when some samples are duplicated or when labels share a long common prefix. For example, if one label is called “intent.support.email” and another “intent.support.phone”, the tokens for the common prefix “intent.support.” are reused for both labels, as their probabilities are identical.

Note that caching is currently not supported for encoder-decoder architectures such as flan-t5. If you want to use such an architecture, turn caching off.

If you see any issues you might suspect are caused by caching, turn this option off, see if it helps, and report the issue on the skorch GitHub page.

Attributes:
classes_ : ndarray of shape (n_classes, )

A list of class labels known to the classifier. This attribute can be used to identify which column in the probabilities returned by predict_proba corresponds to which class.

Methods

check_X_y(X, y, **fit_params) Check that input data is well-behaved.
check_prompt(prompt) Check if the prompt is well formed.
clear_model_cache() Clear the cache of the model.
fit(X, y, **fit_params) Prepare everything to enable predictions.
get_metadata_routing() Get metadata routing of this object.
get_params([deep]) Get parameters for this estimator.
get_prompt(text) Return the prompt for the given sample.
predict(X) Return the classes predicted by the LLM.
predict_proba(X) Return the probabilities predicted by the LLM.
score(X, y[, sample_weight]) Return the mean accuracy on the given test data and labels.
set_params(**params) Set the parameters of this estimator.
set_score_request(*[, sample_weight]) Request metadata passed to the score method.
check_args  
check_classes  
check_is_fitted  
check_X_y(X, y, **fit_params)[source]

Check that input data is well-behaved.

check_prompt(prompt)[source]

Check if the prompt is well formed.

If no prompt is provided, return the default prompt.

Raises:
ValueError

When the prompt is not well formed.

fit(X, y, **fit_params)[source]

Prepare everything to enable predictions.

There is no actual fitting going on here, as the LLM is used as is.

Parameters:
X : array-like of shape (n_samples,)

The input data. For zero-shot classification, this can be None.

y : array-like of shape (n_samples,)

The target classes. Ensure that each class that the LLM should be able to predict is present at least once. Classes that are not present during the fit call will never be predicted.

**fit_params : dict

Additional fitting parameters. This is mostly a placeholder for sklearn-compatibility, as there is no actual fitting process.

Returns:
self

The fitted estimator.

get_prompt(text)[source]

Return the prompt for the given sample.

set_score_request(*, sample_weight: Union[bool, None, str] = '$UNCHANGED$') → skorch.llm.classifier.ZeroShotClassifier[source]

Request metadata passed to the score method.

Note that this method is only relevant if enable_metadata_routing=True (see sklearn.set_config()). Please see the User Guide on how the routing mechanism works.

The options for each parameter are:

  • True: metadata is requested, and passed to score if provided. The request is ignored if metadata is not provided.
  • False: metadata is not requested and the meta-estimator will not pass it to score.
  • None: metadata is not requested, and the meta-estimator will raise an error if the user provides it.
  • str: metadata should be passed to the meta-estimator with this given alias instead of the original name.

The default (sklearn.utils.metadata_routing.UNCHANGED) retains the existing request. This allows you to change the request for some parameters and not others.

New in version 1.3.

Note

This method is only relevant if this estimator is used as a sub-estimator of a meta-estimator, e.g. used inside a pipeline.Pipeline. Otherwise it has no effect.

Parameters:
sample_weight : str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED

Metadata routing for sample_weight parameter in score.

Returns:
self : object

The updated object.