Module imodelsx.treeprompt.treeprompt
Expand source code
import os
import pandas as pd
import numpy as np
import transformers
import sys
from os.path import join
import datasets
from typing import Dict, List
from dict_hash import sha256
import joblib
import random
import numpy as np
import sklearn.tree
from sklearn.metrics import accuracy_score
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.preprocessing import OneHotEncoder
import imodelsx.treeprompt.stump
import imodelsx.llm
class TreePromptClassifier(BaseEstimator, ClassifierMixin):
def __init__(
self,
checkpoint: str,
prompts: List[str],
verbalizer: Dict[int, str] = {0: " Negative.", 1: " Positive."},
tree_kwargs: Dict = {"max_leaf_nodes": 3},
batch_size: int = 1,
prompt_template: str = "{example}{prompt}",
cache_prompt_features_dir=join("cache_prompt_features"),
cache_key_values: bool = False,
device=None,
verbose: bool = True,
random_state: int = 42,
):
'''
Params
------
prompt_template: str
template for the prompt, for different prompt styles (e.g. few-shot), may want to place {prompt} before {example}
or you may want to add some text before the verbalizer, e.g. {example}{prompt} Output:
cache_key_values: bool
Whether to cache key values (only possible when prompt does not start with {example})
'''
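# Illustrative prompt_template values (examples only, not an exhaustive list):
#   "{example}{prompt}"          - default: input text, then the prompt, then the verbalizer
#   "{prompt}{example}"          - few-shot style: demonstrations in the prompt come first
#   "{example}{prompt} Output:"  - extra text inserted before the verbalizer is scored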
self.checkpoint = checkpoint
self.prompts = prompts
self.verbalizer = verbalizer
self.tree_kwargs = tree_kwargs
self.batch_size = batch_size
self.prompt_template = prompt_template
self.cache_prompt_features_dir = cache_prompt_features_dir
self.cache_key_values = cache_key_values
self.device = device
self.verbose = verbose
self.random_state = random_state
def fit(self, X, y):
transformers.set_seed(self.random_state)
self.classes_ = np.unique(y)
self.n_classes_ = len(self.classes_)
assert len(self.classes_) == len(self.verbalizer)
# calculate prompt features
prompt_features = self._calc_prompt_features(X, self.prompts)
self.prompt_accs_ = [
accuracy_score(y, prompt_features[:, i]) for i in range(len(self.prompts))
]
# apply one-hot encoding to features
if len(np.unique(y)) > 3:
print("Converting to one-hot")
self.enc_ = OneHotEncoder(handle_unknown="ignore")
prompt_features = self.enc_.fit_transform(prompt_features)
self.feature_names_ = self.enc_.get_feature_names_out(self.prompts)
else:
self.feature_names_ = self.prompts
# train decision tree
self.clf_ = sklearn.tree.DecisionTreeClassifier(
**self.tree_kwargs,
random_state=self.random_state,
)
self.clf_.fit(prompt_features, y)
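# note: sklearn stores -2 (TREE_UNDEFINED) in tree_.feature for leaf nodes,
# so the first entry of the sorted unique array below is dropped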
self.prompts_idxs_kept = np.unique(self.clf_.tree_.feature)[
1:
] # remove first element which is -2
return self
def _calc_prompt_features(self, X, prompts):
prompt_features = np.zeros((len(X), len(prompts)))
llm = imodelsx.llm.get_llm(self.checkpoint).model_
if self.device is not None:
llm = llm.to(self.device)
stump = None
for i, prompt in enumerate(prompts):
print(f"Prompt {i}: {prompt}")
loaded_from_cache = False
if self.cache_prompt_features_dir is not None:
os.makedirs(self.cache_prompt_features_dir, exist_ok=True)
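# cache key: hash of the prompt plus the dataset size and first example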
args_dict_cache = {"prompt": prompt,
"X_len": len(X), "ex0": X[0]}
save_dir_unique_hash = sha256(args_dict_cache)
cache_file = join(
self.cache_prompt_features_dir, f"{save_dir_unique_hash}.pkl"
)
# load from cache if possible
if os.path.exists(cache_file):
print("loading from cache!")
try:
prompt_features_i = joblib.load(cache_file)
loaded_from_cache = True
except Exception:
pass  # fall back to recomputing features if the cached file is corrupt
if not loaded_from_cache:
if stump is None:
stump = imodelsx.treeprompt.stump.PromptStump(
model=llm,
checkpoint=self.checkpoint,
verbalizer=self.verbalizer,
batch_size=self.batch_size,
prompt_template=self.prompt_template,
cache_key_values=self.cache_key_values,
)
# calculate prompt_features
def _calc_features_single_prompt(
X, stump, prompt, past_key_values=None
):
"""Calculate features with a single prompt (results get cached)
preds: np.ndarray[int] of shape (X.shape[0],)
If multiclass, each int takes value 0, 1, ..., n_classes - 1 based on the verbalizer
"""
stump.prompt = prompt
if past_key_values is not None:
preds = stump.predict_with_cache(X, past_key_values)
else:
preds = stump.predict(X)
return preds
past_key_values = None
if self.cache_key_values:
stump.prompt = prompt
past_key_values = stump.calc_key_values(X)
prompt_features_i = _calc_features_single_prompt(
X, stump, prompt, past_key_values=past_key_values
)
if self.cache_prompt_features_dir is not None:
joblib.dump(prompt_features_i, cache_file)
# save prompt features
prompt_features[:, i] = prompt_features_i
return prompt_features
def predict_proba(self, X):
# extract prompt features
prompt_features = np.zeros((len(X), len(self.prompts)))
prompt_features_relevant = self._calc_prompt_features(
X, np.array(self.prompts)[self.prompts_idxs_kept]
)
prompt_features[:, self.prompts_idxs_kept] = prompt_features_relevant
# apply one-hot encoding to features
if hasattr(self, "enc_"):
prompt_features = self.enc_.transform(prompt_features)  # use the one-hot features the tree was fit on
# predict
return self.clf_.predict_proba(prompt_features)
def predict(self, X):
return self.predict_proba(X).argmax(axis=1)
Classes
class TreePromptClassifier (checkpoint: str, prompts: List[str], verbalizer: Dict[int, str] = {0: ' Negative.', 1: ' Positive.'}, tree_kwargs: Dict = {'max_leaf_nodes': 3}, batch_size: int = 1, prompt_template: str = '{example}{prompt}', cache_prompt_features_dir='cache_prompt_features', cache_key_values: bool = False, device=None, verbose: bool = True, random_state: int = 42)
Base class for all estimators in scikit-learn.

Inheriting from this class provides default implementations of:

- setting and getting parameters used by GridSearchCV and friends;
- textual and HTML representation displayed in terminals and IDEs;
- estimator serialization;
- parameters validation;
- data validation;
- feature names validation.

Read more in the scikit-learn User Guide (rolling_your_own_estimator).

Notes
All estimators should specify all the parameters that can be set at the class level in their __init__ as explicit keyword arguments (no *args or **kwargs).

Examples
>>> import numpy as np
>>> from sklearn.base import BaseEstimator
>>> class MyEstimator(BaseEstimator):
...     def __init__(self, *, param=1):
...         self.param = param
...     def fit(self, X, y=None):
...         self.is_fitted_ = True
...         return self
...     def predict(self, X):
...         return np.full(shape=X.shape[0], fill_value=self.param)
>>> estimator = MyEstimator(param=2)
>>> estimator.get_params()
{'param': 2}
>>> X = np.array([[1, 2], [2, 3], [3, 4]])
>>> y = np.array([1, 0, 1])
>>> estimator.fit(X, y).predict(X)
array([2, 2, 2])
>>> estimator.set_params(param=3).fit(X, y).predict(X)
array([3, 3, 3])
Params
prompt_template: str
    template for the prompt; for different prompt styles (e.g. few-shot), you may want to place {prompt} before {example}, or you may want to add some text before the verbalizer, e.g. "{example}{prompt} Output:"
cache_key_values: bool
    Whether to cache key values (only possible when the prompt does not start with {example})
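A minimal usage sketch (illustrative only; the checkpoint name, prompts, verbalizer, and data below are assumptions, not defaults of the library):

import numpy as np
from imodelsx.treeprompt.treeprompt import TreePromptClassifier

X = ["this movie was great", "utterly boring", "loved every minute", "a waste of time"]
y = np.array([1, 0, 1, 0])

clf = TreePromptClassifier(
    checkpoint="gpt2",  # any causal LM checkpoint accepted by imodelsx.llm.get_llm
    prompts=["This movie is", "The reviewer thought the film was"],
    verbalizer={0: " bad.", 1: " great."},
    tree_kwargs={"max_leaf_nodes": 3},
)
clf.fit(X, y)

print(clf.prompt_accs_)    # accuracy of each individual prompt on the training set
print(clf.feature_names_)  # prompt features used by the decision tree
preds = clf.predict(X)     # argmax over clf.predict_proba(X)

Each prompt becomes a single feature column (the verbalizer class it predicts for each example), and the decision tree then routes examples through a small subset of these prompts.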
Ancestors
- sklearn.base.BaseEstimator
- sklearn.utils._estimator_html_repr._HTMLDocumentationLinkMixin
- sklearn.utils._metadata_requests._MetadataRequester
- sklearn.base.ClassifierMixin
Methods
def fit(self, X, y)
Expand source code
def fit(self, X, y):
    transformers.set_seed(self.random_state)
    self.classes_ = np.unique(y)
    self.n_classes_ = len(self.classes_)
    assert len(self.classes_) == len(self.verbalizer)

    # calculate prompt features
    prompt_features = self._calc_prompt_features(X, self.prompts)
    self.prompt_accs_ = [
        accuracy_score(y, prompt_features[:, i]) for i in range(len(self.prompts))
    ]

    # apply one-hot encoding to features
    if len(np.unique(y)) > 3:
        print("Converting to one-hot")
        self.enc_ = OneHotEncoder(handle_unknown="ignore")
        prompt_features = self.enc_.fit_transform(prompt_features)
        self.feature_names_ = self.enc_.get_feature_names_out(self.prompts)
    else:
        self.feature_names_ = self.prompts

    # train decision tree
    self.clf_ = sklearn.tree.DecisionTreeClassifier(
        **self.tree_kwargs,
        random_state=self.random_state,
    )
    self.clf_.fit(prompt_features, y)
    self.prompts_idxs_kept = np.unique(self.clf_.tree_.feature)[
        1:
    ]  # remove first element which is -2
    return self
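After fitting, the learned prompt-routing tree can be inspected with standard scikit-learn tooling (a sketch, assuming clf is a fitted TreePromptClassifier as in the usage example above):

from sklearn.tree import export_text

# clf.clf_ is an ordinary DecisionTreeClassifier over the prompt features,
# so its splits can be printed with the prompts as feature names
print(export_text(clf.clf_, feature_names=list(clf.feature_names_)))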
def predict(self, X)
Expand source code
def predict(self, X):
    return self.predict_proba(X).argmax(axis=1)
def predict_proba(self, X)
Expand source code
def predict_proba(self, X):
    # extract prompt features (only for the prompts actually used by the tree)
    prompt_features = np.zeros((len(X), len(self.prompts)))
    prompt_features_relevant = self._calc_prompt_features(
        X, np.array(self.prompts)[self.prompts_idxs_kept]
    )
    prompt_features[:, self.prompts_idxs_kept] = prompt_features_relevant

    # apply one-hot encoding to features
    if hasattr(self, "enc_"):
        prompt_features = self.enc_.transform(prompt_features)  # use the one-hot features the tree was fit on

    # predict
    return self.clf_.predict_proba(prompt_features)
def set_score_request(self: TreePromptClassifier, *, sample_weight: Union[bool, None, str] = '$UNCHANGED$') -> TreePromptClassifier
Request metadata passed to the score method.

Note that this method is only relevant if enable_metadata_routing=True (see sklearn.set_config). Please see the User Guide <metadata_routing> on how the routing mechanism works.

The options for each parameter are:

- True: metadata is requested, and passed to score if provided. The request is ignored if metadata is not provided.
- False: metadata is not requested and the meta-estimator will not pass it to score.
- None: metadata is not requested, and the meta-estimator will raise an error if the user provides it.
- str: metadata should be passed to the meta-estimator with this given alias instead of the original name.

The default (sklearn.utils.metadata_routing.UNCHANGED) retains the existing request. This allows you to change the request for some parameters and not others.

Added in version: 1.3

Note
This method is only relevant if this estimator is used as a sub-estimator of a meta-estimator, e.g. used inside a sklearn.pipeline.Pipeline. Otherwise it has no effect.

Parameters
sample_weight : str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED
    Metadata routing for sample_weight parameter in score.

Returns
self : object
    The updated object.
Expand source code
def func(**kw):
    """Updates the request for provided parameters

    This docstring is overwritten below.
    See REQUESTER_DOC for expected functionality
    """
    if not _routing_enabled():
        raise RuntimeError(
            "This method is only available when metadata routing is enabled."
            " You can enable it using"
            " sklearn.set_config(enable_metadata_routing=True)."
        )

    if self.validate_keys and (set(kw) - set(self.keys)):
        raise TypeError(
            f"Unexpected args: {set(kw) - set(self.keys)}. Accepted arguments"
            f" are: {set(self.keys)}"
        )

    requests = instance._get_metadata_request()
    method_metadata_request = getattr(requests, self.name)

    for prop, alias in kw.items():
        if alias is not UNCHANGED:
            method_metadata_request.add_request(param=prop, alias=alias)
    instance._metadata_request = requests

    return instance
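A short sketch of when this matters (hypothetical setup; requires scikit-learn >= 1.3, and the checkpoint/prompt values are illustrative):

import sklearn
from imodelsx.treeprompt.treeprompt import TreePromptClassifier

# metadata routing is off by default; set_score_request raises a RuntimeError unless it is enabled
sklearn.set_config(enable_metadata_routing=True)

clf = TreePromptClassifier(checkpoint="gpt2", prompts=["This movie is"])
# ask meta-estimators (e.g. cross_validate) to route sample_weight to clf.score()
clf = clf.set_score_request(sample_weight=True)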