Module imodelsx.treeprompt.treeprompt
Classes
class TreePromptClassifier (checkpoint: str,
prompts: List[str],
verbalizer: Dict[int, str] = {0: ' Negative.', 1: ' Positive.'},
tree_kwargs: Dict = {'max_leaf_nodes': 3},
batch_size: int = 1,
prompt_template: str = '{example}{prompt}',
cache_prompt_features_dir='cache_prompt_features',
cache_key_values: bool = False,
device=None,
verbose: bool = True,
random_state: int = 42)
```python
class TreePromptClassifier(BaseEstimator, ClassifierMixin):
    def __init__(
        self,
        checkpoint: str,
        prompts: List[str],
        verbalizer: Dict[int, str] = {0: " Negative.", 1: " Positive."},
        tree_kwargs: Dict = {"max_leaf_nodes": 3},
        batch_size: int = 1,
        prompt_template: str = "{example}{prompt}",
        cache_prompt_features_dir="cache_prompt_features",
        cache_key_values: bool = False,
        device=None,
        verbose: bool = True,
        random_state: int = 42,
    ):
        '''
        Params
        ------
        prompt_template: str
            Template for the prompt. For different prompt styles (e.g. few-shot),
            you may want to place {prompt} before {example}, or add some text
            before the verbalizer, e.g. "{example}{prompt} Output:"
        cache_key_values: bool
            Whether to cache key values (only possible when the prompt does not
            start with {example})
        '''
        self.checkpoint = checkpoint
        self.prompts = prompts
        self.verbalizer = verbalizer
        self.tree_kwargs = tree_kwargs
        self.batch_size = batch_size
        self.prompt_template = prompt_template
        self.cache_prompt_features_dir = cache_prompt_features_dir
        self.cache_key_values = cache_key_values
        self.device = device
        self.verbose = verbose
        self.random_state = random_state

    def fit(self, X, y):
        transformers.set_seed(self.random_state)
        self.classes_ = np.unique(y)
        self.n_classes_ = len(self.classes_)
        assert len(self.classes_) == len(self.verbalizer)

        # calculate prompt features
        prompt_features = self._calc_prompt_features(X, self.prompts)
        self.prompt_accs_ = [
            accuracy_score(y, prompt_features[:, i]) for i in range(len(self.prompts))
        ]

        # apply one-hot encoding to features
        if len(np.unique(y)) > 3:
            print("Converting to one-hot")
            self.enc_ = OneHotEncoder(handle_unknown="ignore")
            prompt_features = self.enc_.fit_transform(prompt_features)
            self.feature_names_ = self.enc_.get_feature_names_out(self.prompts)
        else:
            self.feature_names_ = self.prompts

        # train decision tree
        self.clf_ = sklearn.tree.DecisionTreeClassifier(
            **self.tree_kwargs,
            random_state=self.random_state,
        )
        self.clf_.fit(prompt_features, y)
        # tree_.feature marks leaf nodes with -2; drop that first element after np.unique
        self.prompts_idxs_kept = np.unique(self.clf_.tree_.feature)[1:]
        return self

    def _calc_prompt_features(self, X, prompts):
        prompt_features = np.zeros((len(X), len(prompts)))
        llm = imodelsx.llm.get_llm(self.checkpoint).model_
        if self.device is not None:
            llm = llm.to(self.device)
        stump = None
        for i, prompt in enumerate(prompts):
            print(f"Prompt {i}: {prompt}")
            loaded_from_cache = False
            if self.cache_prompt_features_dir is not None:
                os.makedirs(self.cache_prompt_features_dir, exist_ok=True)
                args_dict_cache = {"prompt": prompt, "X_len": len(X), "ex0": X[0]}
                save_dir_unique_hash = sha256(args_dict_cache)
                cache_file = join(
                    self.cache_prompt_features_dir, f"{save_dir_unique_hash}.pkl"
                )

                # load from cache if possible
                if os.path.exists(cache_file):
                    print("loading from cache!")
                    try:
                        prompt_features_i = joblib.load(cache_file)
                        loaded_from_cache = True
                    except Exception:
                        pass

            if not loaded_from_cache:
                if stump is None:
                    stump = imodelsx.treeprompt.stump.PromptStump(
                        model=llm,
                        checkpoint=self.checkpoint,
                        verbalizer=self.verbalizer,
                        batch_size=self.batch_size,
                        prompt_template=self.prompt_template,
                        cache_key_values=self.cache_key_values,
                    )

                # calculate prompt_features
                def _calc_features_single_prompt(
                    X, stump, prompt, past_key_values=None
                ):
                    """Calculate features with a single prompt (results get cached)

                    preds: np.ndarray[int] of shape (X.shape[0],)
                        If multiclass, each int takes value 0, 1, ..., n_classes - 1
                        based on the verbalizer
                    """
                    stump.prompt = prompt
                    if past_key_values is not None:
                        preds = stump.predict_with_cache(X, past_key_values)
                    else:
                        preds = stump.predict(X)
                    return preds

                past_key_values = None
                if self.cache_key_values:
                    stump.prompt = prompt
                    past_key_values = stump.calc_key_values(X)
                prompt_features_i = _calc_features_single_prompt(
                    X, stump, prompt, past_key_values=past_key_values
                )
                if self.cache_prompt_features_dir is not None:
                    joblib.dump(prompt_features_i, cache_file)

            # save prompt features
            prompt_features[:, i] = prompt_features_i

        return prompt_features

    def predict_proba(self, X):
        # extract prompt features (only for prompts the tree actually uses)
        prompt_features = np.zeros((len(X), len(self.prompts)))
        prompt_features_relevant = self._calc_prompt_features(
            X, np.array(self.prompts)[self.prompts_idxs_kept]
        )
        prompt_features[:, self.prompts_idxs_kept] = prompt_features_relevant

        # apply one-hot encoding to features
        if hasattr(self, "enc_"):
            prompt_features = self.enc_.transform(prompt_features)

        # predict
        return self.clf_.predict_proba(prompt_features)

    def predict(self, X):
        return self.predict_proba(X).argmax(axis=1)
```
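The core mechanism, decoupled from the LLM for illustration: each prompt contributes one column of predicted class ids, and a shallow decision tree routes examples through those columns. A self-contained sketch with made-up prompt outputs (no model calls):

```python
import numpy as np
import sklearn.tree

# hypothetical prompt outputs: rows = examples, columns = prompts,
# entries = the class id each prompt's verbalizer produced
prompt_features = np.array([
    [1, 1, 0],
    [0, 0, 0],
    [1, 0, 1],
    [0, 1, 0],
])
y = np.array([1, 0, 1, 0])

tree = sklearn.tree.DecisionTreeClassifier(max_leaf_nodes=3, random_state=42)
tree.fit(prompt_features, y)

# tree_.feature marks leaf nodes with -2, which is why fit() drops the first
# element of np.unique(tree_.feature) when collecting the prompts it kept
print(np.unique(tree.tree_.feature))  # -> [-2  0]: only prompt 0 is used
```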
Base class for all estimators in scikit-learn.

Inheriting from this class provides default implementations of:

- setting and getting parameters used by `GridSearchCV` and friends;
- textual and HTML representation displayed in terminals and IDEs;
- estimator serialization;
- parameters validation;
- data validation;
- feature names validation.

Read more in the scikit-learn User Guide (rolling_your_own_estimator).

Notes

All estimators should specify all the parameters that can be set at the class level in their `__init__` as explicit keyword arguments (no `*args` or `**kwargs`).

Examples
```python
>>> import numpy as np
>>> from sklearn.base import BaseEstimator
>>> class MyEstimator(BaseEstimator):
...     def __init__(self, *, param=1):
...         self.param = param
...     def fit(self, X, y=None):
...         self.is_fitted_ = True
...         return self
...     def predict(self, X):
...         return np.full(shape=X.shape[0], fill_value=self.param)
>>> estimator = MyEstimator(param=2)
>>> estimator.get_params()
{'param': 2}
>>> X = np.array([[1, 2], [2, 3], [3, 4]])
>>> y = np.array([1, 0, 1])
>>> estimator.fit(X, y).predict(X)
array([2, 2, 2])
>>> estimator.set_params(param=3).fit(X, y).predict(X)
array([3, 3, 3])
```
Params
------
prompt_template: str
    Template for the prompt. For different prompt styles (e.g. few-shot),
    you may want to place {prompt} before {example}, or add some text
    before the verbalizer, e.g. "{example}{prompt} Output:"
cache_key_values: bool
    Whether to cache key values (only possible when the prompt does not
    start with {example})
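A minimal usage sketch (the toy data and prompts here are made up; `gpt2` stands in for any causal LM checkpoint that `imodelsx.llm.get_llm` can load):

```python
from imodelsx import TreePromptClassifier

# toy sentiment data (illustrative only)
X = ["this movie was great", "terrible acting", "loved every minute", "dull and slow"]
y = [1, 0, 1, 0]

# candidate prompts; the decision tree decides which ones to keep
prompts = [" The movie is", " The acting was", " Overall, the film felt"]

clf = TreePromptClassifier(
    checkpoint="gpt2",  # assumed: any causal LM checkpoint loadable by imodelsx.llm
    prompts=prompts,
    verbalizer={0: " Negative.", 1: " Positive."},
    tree_kwargs={"max_leaf_nodes": 3},
    cache_prompt_features_dir=None,  # disable on-disk caching for this demo
)
clf.fit(X, y)
print(clf.predict(X))
```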
Ancestors
- sklearn.base.BaseEstimator
- sklearn.utils._repr_html.base.ReprHTMLMixin
- sklearn.utils._repr_html.base._HTMLDocumentationLinkMixin
- sklearn.utils._metadata_requests._MetadataRequester
- sklearn.base.ClassifierMixin
Methods
def fit(self, X, y)
```python
def fit(self, X, y):
    transformers.set_seed(self.random_state)
    self.classes_ = np.unique(y)
    self.n_classes_ = len(self.classes_)
    assert len(self.classes_) == len(self.verbalizer)

    # calculate prompt features
    prompt_features = self._calc_prompt_features(X, self.prompts)
    self.prompt_accs_ = [
        accuracy_score(y, prompt_features[:, i]) for i in range(len(self.prompts))
    ]

    # apply one-hot encoding to features
    if len(np.unique(y)) > 3:
        print("Converting to one-hot")
        self.enc_ = OneHotEncoder(handle_unknown="ignore")
        prompt_features = self.enc_.fit_transform(prompt_features)
        self.feature_names_ = self.enc_.get_feature_names_out(self.prompts)
    else:
        self.feature_names_ = self.prompts

    # train decision tree
    self.clf_ = sklearn.tree.DecisionTreeClassifier(
        **self.tree_kwargs,
        random_state=self.random_state,
    )
    self.clf_.fit(prompt_features, y)
    # tree_.feature marks leaf nodes with -2; drop that first element after np.unique
    self.prompts_idxs_kept = np.unique(self.clf_.tree_.feature)[1:]
    return self
```
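Once `fit` returns, the learned structure can be inspected through the fitted attributes set above (a sketch; `clf` is a fitted `TreePromptClassifier` from the usage example, and `sklearn.tree.export_text` is the standard scikit-learn helper):

```python
import sklearn.tree

# accuracy of each individual prompt on the training data
for prompt, acc in zip(clf.prompts, clf.prompt_accs_):
    print(f"{acc:.2f}  {prompt}")

# indices of the prompts the decision tree actually splits on
print(clf.prompts_idxs_kept)

# render the underlying decision tree as text
print(sklearn.tree.export_text(clf.clf_, feature_names=list(clf.feature_names_)))
```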
def predict(self, X)
```python
def predict(self, X):
    return self.predict_proba(X).argmax(axis=1)
```
def predict_proba(self, X)
```python
def predict_proba(self, X):
    # extract prompt features (only for prompts the tree actually uses)
    prompt_features = np.zeros((len(X), len(self.prompts)))
    prompt_features_relevant = self._calc_prompt_features(
        X, np.array(self.prompts)[self.prompts_idxs_kept]
    )
    prompt_features[:, self.prompts_idxs_kept] = prompt_features_relevant

    # apply one-hot encoding to features
    if hasattr(self, "enc_"):
        prompt_features = self.enc_.transform(prompt_features)

    # predict
    return self.clf_.predict_proba(prompt_features)
```
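Design note: at predict time only the prompts the tree actually splits on (`prompts_idxs_kept`) are re-evaluated by the LLM; columns for discarded prompts stay zero, saving one full pass over the data per dropped prompt. A quick sketch of the output contract, reusing the fitted `clf` from the earlier example:

```python
proba = clf.predict_proba(["a surprisingly moving film"])
print(proba.shape)           # (1, n_classes_)
print(proba.argmax(axis=1))  # identical to clf.predict(...)
```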
def set_score_request(self: TreePromptClassifier,
*,
sample_weight: bool | str | None = '$UNCHANGED$') -> TreePromptClassifier
```python
def func(*args, **kw):
    """Updates the `_metadata_request` attribute of the consumer (`instance`)
    for the parameters provided as `**kw`.

    This docstring is overwritten below.
    See REQUESTER_DOC for expected functionality.
    """
    if not _routing_enabled():
        raise RuntimeError(
            "This method is only available when metadata routing is enabled."
            " You can enable it using"
            " sklearn.set_config(enable_metadata_routing=True)."
        )

    if self.validate_keys and (set(kw) - set(self.keys)):
        raise TypeError(
            f"Unexpected args: {set(kw) - set(self.keys)} in {self.name}. "
            f"Accepted arguments are: {set(self.keys)}"
        )

    # This makes it possible to use the decorated method as an unbound method,
    # for instance when monkeypatching.
    # https://github.com/scikit-learn/scikit-learn/issues/28632
    if instance is None:
        _instance = args[0]
        args = args[1:]
    else:
        _instance = instance

    # Replicating python's behavior when positional args are given other than
    # `self`, and `self` is only allowed if this method is unbound.
    if args:
        raise TypeError(
            f"set_{self.name}_request() takes 0 positional argument but"
            f" {len(args)} were given"
        )

    requests = _instance._get_metadata_request()
    method_metadata_request = getattr(requests, self.name)

    for prop, alias in kw.items():
        if alias is not UNCHANGED:
            method_metadata_request.add_request(param=prop, alias=alias)
    _instance._metadata_request = requests

    return _instance
```
Configure whether metadata should be requested to be passed to the `score` method.

Note that this method is only relevant when this estimator is used as a sub-estimator within a meta-estimator and metadata routing is enabled with `enable_metadata_routing=True` (see `sklearn.set_config`). Please check the scikit-learn User Guide on metadata routing for how the routing mechanism works.

The options for each parameter are:

- `True`: metadata is requested, and passed to `score` if provided. The request is ignored if metadata is not provided.
- `False`: metadata is not requested and the meta-estimator will not pass it to `score`.
- `None`: metadata is not requested, and the meta-estimator will raise an error if the user provides it.
- `str`: metadata should be passed to the meta-estimator with this given alias instead of the original name.

The default (`sklearn.utils.metadata_routing.UNCHANGED`) retains the existing request. This allows you to change the request for some parameters and not others.

Added in version 1.3.

Parameters
----------
sample_weight : str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED
    Metadata routing for `sample_weight` parameter in `score`.

Returns
-------
self : object
    The updated object.
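This method comes from scikit-learn's metadata-routing machinery rather than from imodelsx itself, and only matters when the classifier is wrapped in a meta-estimator. A hedged sketch of how it would typically be used (assumes scikit-learn >= 1.4 for the `params` argument of `cross_validate`; `clf`, `X`, `y` are from the usage example above):

```python
import sklearn
from sklearn.model_selection import cross_validate

sklearn.set_config(enable_metadata_routing=True)

# ask that sample_weight be routed to score() during cross-validation
clf.set_score_request(sample_weight=True)
results = cross_validate(clf, X, y, cv=2, params={"sample_weight": [1.0] * len(X)})
```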