skorch.llm
class skorch.llm.FewShotClassifier(model_name=None, *, model=None, tokenizer=None, prompt=None, probas_sum_to_1=True, max_samples=5, device='cpu', error_low_prob='ignore', threshold_low_prob=0.0, use_caching=True, random_state=None)
Few-shot classification using a Large Language Model (LLM).
This class allows you to use an LLM from Hugging Face transformers for few-shot classification. There is no training during the fit call; instead, the LLM is prompted to predict the label of each sample.
Parameters:
- model_name : str or None (default=None)
The name of the model to use. This is the same name as used on Hugging Face Hub. For example, to use GPT2, pass 'gpt2'; to use the small flan-t5 model, pass 'google/flan-t5-small'. If the model_name parameter is passed, don't pass the model or tokenizer parameters.
- model : torch.nn.Module or None (default=None)
The model to use. This should be a PyTorch text generation model from Hugging Face Hub or a model with the same API. Most notably, the model should have a generate method. If you pass model, you should also pass tokenizer, but you should not pass model_name.
Passing the model explicitly instead of model_name has a few advantages. Most notably, it allows you to modify the model, e.g. by changing its config or how it is loaded. For instance, some models can only be loaded with the option trust_remote_code=True; when using the model_name argument, the default settings are used instead. Passing the model explicitly also allows you to use custom models that are not uploaded to Hugging Face Hub.
- tokenizer (default=None)
A tokenizer that is compatible with the model. Typically, it is loaded using the AutoTokenizer.from_pretrained method provided by Hugging Face transformers. If you pass tokenizer, you should also pass model, but you should not pass model_name.
- prompt : str or None (default=None)
The prompt to use. This is the text that will be passed to the model to generate the prediction. If no prompt is passed, a default prompt will be used. The prompt should be a Python string with three placeholders: one called text, one called labels, and one called examples. The text placeholder will be replaced by the contents of X passed during inference, and the labels placeholder will be replaced by the unique labels taken from y. The examples will be taken from the X and y seen during fit.
An example prompt could look like this:
"Classify this text: {text}. Possible labels are: {labels}. Here are some examples: {examples}. Your response: "
All general tips for good prompt crafting apply here as well. Be aware that if the prompt is too long, it will exceed the context size of the model.
- probas_sum_to_1 : bool (default=True)
If True, the probabilities for each sample will be normalized to sum to 1. If False, the probabilities will not be normalized.
Without normalization, the probabilities generally do not sum to 1 because the LLM can generate any token, not just the labels. Since only the probabilities of the available labels are computed, the probability mass the model assigns to other continuations is unaccounted for. You could consider this missing probability mass to be an implicit 'other' class.
In general, you should set this parameter to True because the default assumption is that probabilities sum to 1. However, setting it to False can be useful for debugging, as it shows how much probability the LLM assigns to the labels. If the total probability is very low, it could be a sign that the LLM is not powerful enough or that the prompt is not well crafted.
- max_samples : int (default=5)
The number of samples to use for few-shot learning. The few-shot samples are taken from the X and y passed to fit.
This number should be large enough for the LLM to generalize, but not so large that the prompt exceeds the model's context window. More samples also slow down prediction.
- device : str or torch.device (default='cpu')
The device to use. In general, using a GPU or other accelerated hardware is advised if runtime performance is critical.
Note that if the model parameter is passed explicitly, the device of that model takes precedence over the value of device.
- error_low_prob : {'ignore', 'warn', 'raise', 'return_none'} (default='ignore')
Controls what should happen if the sum of the probabilities for a sample is below a given threshold. When encountering low probabilities, the options are:
  - 'ignore': do nothing
  - 'warn': issue a warning
  - 'raise': raise an error
  - 'return_none': return None as the prediction when calling .predict
The threshold is controlled by the threshold_low_prob parameter.
- threshold_low_prob : float (default=0.0)
The threshold for the sum of the probabilities below which they are considered too low. The consequences of low probabilities are controlled by the error_low_prob parameter.
- use_caching : bool (default=True)
If True, the predictions for each sample will be cached, as well as the intermediate result for each generated token. This can speed up predictions when some samples are duplicated, or when labels share a long common prefix. For example, if one label is called "intent.support.email" and another is called "intent.support.phone", the results for the tokens of the common prefix "intent.support." are reused for both labels, since they are identical.
Note that caching is currently not supported for encoder-decoder architectures such as flan-t5. If you want to use such an architecture, turn caching off.
If you see any issues you suspect are caused by caching, turn this option off, see if it helps, and report the issue on the skorch GitHub page.
- random_state : int, RandomState instance or None (default=None)
The examples used for few-shot learning are picked at random. Use this argument to fix the random seed and make the choice reproducible.
Attributes:
- classes_ : ndarray of shape (n_classes,)
The class labels known to the classifier. This attribute can be used to identify which column in the probabilities returned by predict_proba corresponds to which class.
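The way probas_sum_to_1, error_low_prob, threshold_low_prob and classes_ fit together can be illustrated with a small sketch. It is not skorch's implementation: postprocess_probas and predict_from_probas are hypothetical helpers, the raw per-label probabilities are assumed to have been computed already, and the low-probability check is assumed to use the unnormalized totals (after normalization every row sums to 1, so a threshold below 1 could never trigger).

```python
import warnings

OPTIONS = ("ignore", "warn", "raise", "return_none")


def postprocess_probas(probas, probas_sum_to_1=True,
                       error_low_prob="ignore", threshold_low_prob=0.0):
    """Hypothetical helper, not part of skorch.

    probas: one list per sample with the raw probability of each label,
    columns in the same order as classes_. Returns the processed rows and
    a list of flags marking samples whose total probability is too low.
    """
    if error_low_prob not in OPTIONS:
        raise ValueError(f"error_low_prob must be one of {OPTIONS}")

    rows, is_low = [], []
    for row in probas:
        total = sum(row)
        # Assumption: the threshold is compared against the raw total.
        low = total < threshold_low_prob
        if low and error_low_prob == "raise":
            raise ValueError(f"Total probability {total:.3f} is too low")
        if low and error_low_prob == "warn":
            warnings.warn(f"Total probability {total:.3f} is too low")
        if probas_sum_to_1 and total > 0:
            row = [p / total for p in row]  # renormalize over the labels only
        rows.append(list(row))
        is_low.append(low)
    return rows, is_low


def predict_from_probas(rows, classes, is_low, error_low_prob="ignore"):
    """Pick the most likely class per sample; column j belongs to classes[j]."""
    preds = []
    for row, low in zip(rows, is_low):
        if low and error_low_prob == "return_none":
            preds.append(None)
        else:
            best = max(range(len(row)), key=row.__getitem__)
            preds.append(classes[best])
    return preds
```

With classes ["negative", "positive"], raw probabilities [[0.3, 0.1], [0.01, 0.02]], threshold_low_prob=0.1 and error_low_prob='return_none', the first row is normalized to roughly [0.75, 0.25] and predicted as "negative", while the second row (total 0.03) is flagged and predicted as None.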
Methods
- check_X_y(X, y, **fit_params): Check that input data is well-behaved.
- check_prompt(prompt): Check if the prompt is well formed.
- clear_model_cache(): Clear the cache of the model.
- fit(X, y, **fit_params): Prepare everything to enable predictions.
- get_examples(X, y, n_samples): Given input data X and y, return a subset of n_samples for few-shot learning.
- get_metadata_routing(): Get metadata routing of this object.
- get_params([deep]): Get parameters for this estimator.
- get_prompt(text): Return the prompt for the given sample.
- predict(X): Return the classes predicted by the LLM.
- predict_proba(X): Return the probabilities predicted by the LLM.
- score(X, y[, sample_weight]): Return the mean accuracy on the given test data and labels.
- set_params(**params): Set the parameters of this estimator.
- set_score_request(*[, sample_weight]): Request metadata passed to the score method.
- check_args, check_classes, check_is_fitted
check_prompt(prompt)
Check if the prompt is well formed.
If no prompt is provided, return the default prompt.
Raises:
- ValueError
When the prompt is not well formed.
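A check of this kind could collect the placeholder names with the standard library's string.Formatter().parse and compare them with the required set. The sketch below is an assumption about how such a check might work, not the skorch source; check_prompt_sketch and REQUIRED_FEW_SHOT are hypothetical names, and the default prompt is passed in as an argument because skorch's actual default text is not reproduced in this documentation.

```python
from string import Formatter

# Placeholders a few-shot prompt must contain, per the prompt parameter docs.
REQUIRED_FEW_SHOT = frozenset({"text", "labels", "examples"})


def check_prompt_sketch(prompt, required=REQUIRED_FEW_SHOT, default=None):
    """Hypothetical check: return the prompt (or default) if well formed."""
    if prompt is None:
        return default
    # Formatter().parse yields (literal, field_name, format_spec, conversion);
    # field_name is None for trailing literal text without a placeholder.
    found = {field for _, field, _, _ in Formatter().parse(prompt)
             if field is not None}
    missing, unexpected = required - found, found - required
    if missing or unexpected:
        raise ValueError(
            f"Prompt is not well formed: missing {sorted(missing)}, "
            f"unexpected {sorted(unexpected)}"
        )
    return prompt
```

Rejecting unexpected placeholders as well as missing ones is a design choice of the sketch: str.format would raise a KeyError for an unknown field at prediction time anyway, so failing early gives a clearer message. Unbalanced braces such as "{text" already make Formatter().parse raise a ValueError. For ZeroShotClassifier, the same check applies with required={"text", "labels"}.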
fit(X, y, **fit_params)
Prepare everything to enable predictions.
There is no actual fitting going on here, as the LLM is used as is. The examples used for few-shot learning are drawn from the provided input data, and the selection ensures that at least one example is taken for each possible label (if max_samples is large enough).
To change the way examples are selected, override the get_examples method.
Parameters:
- X : array-like of shape (n_samples,)
The input data. For zero-shot classification, this can be None.
- y : array-like of shape (n_samples,)
The target classes. Ensure that each class that the LLM should be able to predict is present at least once. Classes that are not present during the fit call will never be predicted.
- **fit_params : dict
Additional fitting parameters. This is mostly a placeholder for sklearn compatibility, as there is no actual fitting process.
Returns:
- self
The fitted estimator.
get_examples(X, y, n_samples)
Given input data X and y, return a subset of n_samples for few-shot learning.
This method aims to provide at least one example for each existing class.
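A selection strategy with this property might look like the sketch below, which could serve as a starting point when overriding get_examples. It is a hypothetical stand-in (get_examples_sketch is not part of skorch): it accepts only an int or None as random_state, uses Python's random module rather than a NumPy RandomState, and returns (text, label) pairs, which may differ from what skorch's get_examples is expected to return.

```python
import random
from collections import defaultdict


def get_examples_sketch(X, y, n_samples, random_state=None):
    """Hypothetical: pick n_samples (text, label) pairs, one per class first."""
    rng = random.Random(random_state)
    indices_by_class = defaultdict(list)
    for i, label in enumerate(y):
        indices_by_class[label].append(i)

    # First pass: one random example per class, as far as n_samples allows.
    labels = list(indices_by_class)
    rng.shuffle(labels)
    chosen = [rng.choice(indices_by_class[label]) for label in labels[:n_samples]]

    # Second pass: fill the remaining slots with unused samples of any class.
    used = set(chosen)
    rest = [i for i in range(len(y)) if i not in used]
    rng.shuffle(rest)
    chosen += rest[:n_samples - len(chosen)]

    # Shuffle so the prompt does not list the per-class picks in a fixed order.
    rng.shuffle(chosen)
    return [(X[i], y[i]) for i in chosen]
```

For example, with three classes and n_samples=5, the first pass picks one sample of each class and the second pass adds two more unused samples; with n_samples=2, only two of the three classes can be represented.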
set_score_request(*, sample_weight: Union[bool, None, str] = '$UNCHANGED$') → skorch.llm.classifier.FewShotClassifier
Request metadata passed to the score method.
Note that this method is only relevant if enable_metadata_routing=True (see sklearn.set_config()). Please see the User Guide on how the routing mechanism works.
The options for each parameter are:
  - True: metadata is requested, and passed to score if provided. The request is ignored if metadata is not provided.
  - False: metadata is not requested and the meta-estimator will not pass it to score.
  - None: metadata is not requested, and the meta-estimator will raise an error if the user provides it.
  - str: metadata should be passed to the meta-estimator with this given alias instead of the original name.
The default (sklearn.utils.metadata_routing.UNCHANGED) retains the existing request. This allows you to change the request for some parameters and not others.
New in version 1.3.
Note: This method is only relevant if this estimator is used as a sub-estimator of a meta-estimator, e.g. used inside a pipeline.Pipeline. Otherwise it has no effect.
Parameters:
- sample_weight : str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED
Metadata routing for the sample_weight parameter in score.
Returns:
- self : object
The updated object.
class skorch.llm.ZeroShotClassifier(model_name=None, *, model=None, tokenizer=None, prompt=None, probas_sum_to_1=True, device='cpu', error_low_prob='ignore', threshold_low_prob=0.0, use_caching=True)
Zero-shot classification using a Large Language Model (LLM).
This class allows you to use an LLM from Hugging Face transformers for zero-shot classification. There is no training during the fit call; instead, the LLM is prompted to predict the label of each sample.
Parameters:
- model_name : str or None (default=None)
The name of the model to use. This is the same name as used on Hugging Face Hub. For example, to use GPT2, pass 'gpt2'; to use the small flan-t5 model, pass 'google/flan-t5-small'. If the model_name parameter is passed, don't pass the model or tokenizer parameters.
- model : torch.nn.Module or None (default=None)
The model to use. This should be a PyTorch text generation model from Hugging Face Hub or a model with the same API. Most notably, the model should have a generate method. If you pass model, you should also pass tokenizer, but you should not pass model_name.
Passing the model explicitly instead of model_name has a few advantages. Most notably, it allows you to modify the model, e.g. by changing its config or how it is loaded. For instance, some models can only be loaded with the option trust_remote_code=True; when using the model_name argument, the default settings are used instead. Passing the model explicitly also allows you to use custom models that are not uploaded to Hugging Face Hub.
- tokenizer (default=None)
A tokenizer that is compatible with the model. Typically, it is loaded using the AutoTokenizer.from_pretrained method provided by Hugging Face transformers. If you pass tokenizer, you should also pass model, but you should not pass model_name.
- prompt : str or None (default=None)
The prompt to use. This is the text that will be passed to the model to generate the prediction. If no prompt is passed, a default prompt will be used. The prompt should be a Python string with two placeholders: one called text and one called labels. The text placeholder will be replaced by the contents of X, and the labels placeholder will be replaced by the unique labels taken from y.
An example prompt could look like this:
"Classify this text: {text}. Possible labels are {labels}. Your response: "
All general tips for good prompt crafting apply here as well. Be aware that if the prompt is too long, it will exceed the context size of the model.
- probas_sum_to_1 : bool (default=True)
If True, the probabilities for each sample will be normalized to sum to 1. If False, the probabilities will not be normalized.
Without normalization, the probabilities generally do not sum to 1 because the LLM can generate any token, not just the labels. Since only the probabilities of the available labels are computed, the probability mass the model assigns to other continuations is unaccounted for. You could consider this missing probability mass to be an implicit 'other' class.
In general, you should set this parameter to True because the default assumption is that probabilities sum to 1. However, setting it to False can be useful for debugging, as it shows how much probability the LLM assigns to the labels. If the total probability is very low, it could be a sign that the LLM is not powerful enough or that the prompt is not well crafted.
- device : str or torch.device (default='cpu')
The device to use. In general, using a GPU or other accelerated hardware is advised if runtime performance is critical.
Note that if the model parameter is passed explicitly, the device of that model takes precedence over the value of device.
- error_low_prob : {'ignore', 'warn', 'raise', 'return_none'} (default='ignore')
Controls what should happen if the sum of the probabilities for a sample is below a given threshold. When encountering low probabilities, the options are:
  - 'ignore': do nothing
  - 'warn': issue a warning
  - 'raise': raise an error
  - 'return_none': return None as the prediction when calling .predict
The threshold is controlled by the threshold_low_prob parameter.
- threshold_low_prob : float (default=0.0)
The threshold for the sum of the probabilities below which they are considered too low. The consequences of low probabilities are controlled by the error_low_prob parameter.
- use_caching : bool (default=True)
If True, the predictions for each sample will be cached, as well as the intermediate result for each generated token. This can speed up predictions when some samples are duplicated, or when labels share a long common prefix. For example, if one label is called "intent.support.email" and another is called "intent.support.phone", the results for the tokens of the common prefix "intent.support." are reused for both labels, since they are identical.
Note that caching is currently not supported for encoder-decoder architectures such as flan-t5. If you want to use such an architecture, turn caching off.
If you see any issues you suspect are caused by caching, turn this option off, see if it helps, and report the issue on the skorch GitHub page.
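Reusing a shared prefix works because of the chain rule: the probability of a multi-token label is the product of the conditional probabilities of its tokens, and for "intent.support.email" and "intent.support.phone" the factors for the common prefix are identical. The sketch below shows the idea with a hard-coded toy next-token table standing in for the LLM; the tokenization, the numbers, and all names are made up for illustration, and a real cache would also have to be keyed on the prompt (i.e. on the sample).

```python
from functools import lru_cache

# Toy stand-in for an LLM: maps the label tokens generated so far (after a
# fixed prompt) to a next-token distribution. Tokens and numbers are made up.
TOY_NEXT_TOKEN = {
    (): {"intent": 0.9, "other": 0.1},
    ("intent",): {".support": 1.0},
    ("intent", ".support"): {".email": 0.6, ".phone": 0.3, ".chat": 0.1},
}

forward_passes = []  # records every prefix that needed a "model call"


@lru_cache(maxsize=None)
def next_token_probs(prefix):
    """Stands in for one expensive forward pass; memoized per prefix."""
    forward_passes.append(prefix)
    return TOY_NEXT_TOKEN.get(prefix, {})


def label_probability(label_tokens):
    """Chain rule: P(label) = prod_i P(token_i | prompt, tokens before i)."""
    prob = 1.0
    for i, token in enumerate(label_tokens):
        prob *= next_token_probs(tuple(label_tokens[:i])).get(token, 0.0)
    return prob
```

Scoring ("intent", ".support", ".email") and then ("intent", ".support", ".phone") gives 0.9 * 1.0 * 0.6 = 0.54 and 0.9 * 1.0 * 0.3 = 0.27, yet forward_passes ends up with only three entries, because the second label is served entirely from the cache. Since lru_cache hands back the same dict on every hit, the returned distributions should be treated as read-only.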
Attributes:
- classes_ : ndarray of shape (n_classes,)
The class labels known to the classifier. This attribute can be used to identify which column in the probabilities returned by predict_proba corresponds to which class.
Methods
- check_X_y(X, y, **fit_params): Check that input data is well-behaved.
- check_prompt(prompt): Check if the prompt is well formed.
- clear_model_cache(): Clear the cache of the model.
- fit(X, y, **fit_params): Prepare everything to enable predictions.
- get_metadata_routing(): Get metadata routing of this object.
- get_params([deep]): Get parameters for this estimator.
- get_prompt(text): Return the prompt for the given sample.
- predict(X): Return the classes predicted by the LLM.
- predict_proba(X): Return the probabilities predicted by the LLM.
- score(X, y[, sample_weight]): Return the mean accuracy on the given test data and labels.
- set_params(**params): Set the parameters of this estimator.
- set_score_request(*[, sample_weight]): Request metadata passed to the score method.
- check_args, check_classes, check_is_fitted
check_prompt(prompt)
Check if the prompt is well formed.
If no prompt is provided, return the default prompt.
Raises:
- ValueError
When the prompt is not well formed.
fit(X, y, **fit_params)
Prepare everything to enable predictions.
There is no actual fitting going on here, as the LLM is used as is.
Parameters:
- X : array-like of shape (n_samples,)
The input data. For zero-shot classification, this can be None.
- y : array-like of shape (n_samples,)
The target classes. Ensure that each class that the LLM should be able to predict is present at least once. Classes that are not present during the fit call will never be predicted.
- **fit_params : dict
Additional fitting parameters. This is mostly a placeholder for sklearn compatibility, as there is no actual fitting process.
Returns:
- self
The fitted estimator.
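Since fit only has to record the classes, a zero-shot fit can be pictured as computing the unique labels from y without looking at X; at prediction time, each sample is inserted into the prompt together with those labels. The sketch below assumes sorted classes (as with NumPy's unique) and renders the labels with Python's default list formatting; build_zero_shot_prompts is hypothetical, and skorch's own get_prompt may format the labels differently.

```python
# Template modeled on the example prompt in the prompt parameter description.
ZERO_SHOT_PROMPT = (
    "Classify this text: {text}. Possible labels are {labels}. Your response: "
)


def build_zero_shot_prompts(X, y_fit, prompt=ZERO_SHOT_PROMPT):
    """Hypothetical: derive classes from y at fit time, then fill the prompt."""
    # "fit": only the labels matter; X is not needed for zero-shot.
    classes = sorted(set(y_fit))
    # "predict": one prompt per sample, listing every known class.
    prompts = [prompt.format(text=text, labels=classes) for text in X]
    return classes, prompts
```

A label that never appears in y_fit is absent from classes and from every prompt, which is why such a class can never be predicted.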
set_score_request(*, sample_weight: Union[bool, None, str] = '$UNCHANGED$') → skorch.llm.classifier.ZeroShotClassifier
Request metadata passed to the score method.
Note that this method is only relevant if enable_metadata_routing=True (see sklearn.set_config()). Please see the User Guide on how the routing mechanism works.
The options for each parameter are:
  - True: metadata is requested, and passed to score if provided. The request is ignored if metadata is not provided.
  - False: metadata is not requested and the meta-estimator will not pass it to score.
  - None: metadata is not requested, and the meta-estimator will raise an error if the user provides it.
  - str: metadata should be passed to the meta-estimator with this given alias instead of the original name.
The default (sklearn.utils.metadata_routing.UNCHANGED) retains the existing request. This allows you to change the request for some parameters and not others.
New in version 1.3.
Note: This method is only relevant if this estimator is used as a sub-estimator of a meta-estimator, e.g. used inside a pipeline.Pipeline. Otherwise it has no effect.
Parameters:
- sample_weight : str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED
Metadata routing for the sample_weight parameter in score.
Returns:
- self : object
The updated object.