Module imodelsx.treeprompt.treeprompt

Classes

class TreePromptClassifier (checkpoint: str,
prompts: List[str],
verbalizer: Dict[int, str] = {0: ' Negative.', 1: ' Positive.'},
tree_kwargs: Dict = {'max_leaf_nodes': 3},
batch_size: int = 1,
prompt_template: str = '{example}{prompt}',
cache_prompt_features_dir='cache_prompt_features',
cache_key_values: bool = False,
device=None,
verbose: bool = True,
random_state: int = 42)
class TreePromptClassifier(BaseEstimator, ClassifierMixin):
    def __init__(
        self,
        checkpoint: str,
        prompts: List[str],
        verbalizer: Dict[int, str] = {0: " Negative.", 1: " Positive."},
        tree_kwargs: Dict = {"max_leaf_nodes": 3},
        batch_size: int = 1,
        prompt_template: str = "{example}{prompt}",
        cache_prompt_features_dir=join("cache_prompt_features"),
        cache_key_values: bool = False,
        device=None,
        verbose: bool = True,
        random_state: int = 42,
    ):
        '''
        Params
        ------
        prompt_template: str
            Template for the prompt. For different prompt styles (e.g. few-shot),
            you may want to place {prompt} before {example}, or add some text
            before the verbalizer, e.g. "{example}{prompt} Output:".
        cache_key_values: bool
            Whether to cache past key values (only possible when the prompt
            does not start with {example}).
        '''
        self.checkpoint = checkpoint
        self.prompts = prompts
        self.verbalizer = verbalizer
        self.tree_kwargs = tree_kwargs
        self.batch_size = batch_size
        self.prompt_template = prompt_template
        self.cache_prompt_features_dir = cache_prompt_features_dir
        self.cache_key_values = cache_key_values
        self.device = device
        self.verbose = verbose
        self.random_state = random_state

    def fit(self, X, y):
        transformers.set_seed(self.random_state)
        self.classes_ = np.unique(y)
        self.n_classes_ = len(self.classes_)
        assert len(self.classes_) == len(self.verbalizer)

        # calculate prompt features
        prompt_features = self._calc_prompt_features(X, self.prompts)
        self.prompt_accs_ = [
            accuracy_score(y, prompt_features[:, i]) for i in range(len(self.prompts))
        ]

        # apply one-hot encoding to features
        if len(np.unique(y)) > 3:
            print("Converting to one-hot")
            self.enc_ = OneHotEncoder(handle_unknown="ignore")
            prompt_features = self.enc_.fit_transform(prompt_features)
            self.feature_names_ = self.enc_.get_feature_names_out(self.prompts)
        else:
            self.feature_names_ = self.prompts

        # train decision tree
        self.clf_ = sklearn.tree.DecisionTreeClassifier(
            **self.tree_kwargs,
            random_state=self.random_state,
        )
        self.clf_.fit(prompt_features, y)
        self.prompts_idxs_kept = np.unique(self.clf_.tree_.feature)[
            1:
        ]  # remove first element which is -2

        return self

    def _calc_prompt_features(self, X, prompts):
        prompt_features = np.zeros((len(X), len(prompts)))
        llm = imodelsx.llm.get_llm(self.checkpoint).model_
        if self.device is not None:
            llm = llm.to(self.device)
        stump = None

        for i, prompt in enumerate(prompts):
            print(f"Prompt {i}: {prompt}")

            loaded_from_cache = False
            if self.cache_prompt_features_dir is not None:
                os.makedirs(self.cache_prompt_features_dir, exist_ok=True)
                args_dict_cache = {"prompt": prompt,
                                   "X_len": len(X), "ex0": X[0]}
                save_dir_unique_hash = sha256(args_dict_cache)
                cache_file = join(
                    self.cache_prompt_features_dir, f"{save_dir_unique_hash}.pkl"
                )

                # load from cache if possible
                if os.path.exists(cache_file):
                    print("loading from cache!")
                    try:
                        prompt_features_i = joblib.load(cache_file)
                        loaded_from_cache = True
                    except:
                        pass

            if not loaded_from_cache:
                if stump is None:
                    stump = imodelsx.treeprompt.stump.PromptStump(
                        model=llm,
                        checkpoint=self.checkpoint,
                        verbalizer=self.verbalizer,
                        batch_size=self.batch_size,
                        prompt_template=self.prompt_template,
                        cache_key_values=self.cache_key_values,
                    )

                # calculate prompt_features
                def _calc_features_single_prompt(
                    X, stump, prompt, past_key_values=None
                ):
                    """Calculate features with a single prompt (results get cached)
                    preds: np.ndarray[int] of shape (X.shape[0],)
                        If multiclass, each int takes value 0, 1, ..., n_classes - 1 based on the verbalizer
                    """
                    stump.prompt = prompt
                    if past_key_values is not None:
                        preds = stump.predict_with_cache(X, past_key_values)
                    else:
                        preds = stump.predict(X)
                    return preds

                past_key_values = None
                if self.cache_key_values:
                    stump.prompt = prompt
                    past_key_values = stump.calc_key_values(X)
                prompt_features_i = _calc_features_single_prompt(
                    X, stump, prompt, past_key_values=past_key_values
                )
                if self.cache_prompt_features_dir is not None:
                    joblib.dump(prompt_features_i, cache_file)

            # save prompt features
            prompt_features[:, i] = prompt_features_i

        return prompt_features

    def predict_proba(self, X):
        # extract prompt features
        prompt_features = np.zeros((len(X), len(self.prompts)))
        prompt_features_relevant = self._calc_prompt_features(
            X, np.array(self.prompts)[self.prompts_idxs_kept]
        )
        prompt_features[:, self.prompts_idxs_kept] = prompt_features_relevant

        # apply one-hot encoding to features (must match the encoding used in fit)
        if hasattr(self, "enc_"):
            prompt_features = self.enc_.transform(prompt_features)

        # predict
        return self.clf_.predict_proba(prompt_features)

    def predict(self, X):
        return self.predict_proba(X).argmax(axis=1)

Base class for all estimators in scikit-learn.

Inheriting from this class provides default implementations of:

  • setting and getting parameters used by GridSearchCV and friends;
  • textual and HTML representation displayed in terminals and IDEs;
  • estimator serialization;
  • parameters validation;
  • data validation;
  • feature names validation.

Read more in the scikit-learn User Guide on rolling your own estimator.

Notes

All estimators should specify all the parameters that can be set at the class level in their __init__ as explicit keyword arguments (no *args or **kwargs).

Examples

>>> import numpy as np
>>> from sklearn.base import BaseEstimator
>>> class MyEstimator(BaseEstimator):
...     def __init__(self, *, param=1):
...         self.param = param
...     def fit(self, X, y=None):
...         self.is_fitted_ = True
...         return self
...     def predict(self, X):
...         return np.full(shape=X.shape[0], fill_value=self.param)
>>> estimator = MyEstimator(param=2)
>>> estimator.get_params()
{'param': 2}
>>> X = np.array([[1, 2], [2, 3], [3, 4]])
>>> y = np.array([1, 0, 1])
>>> estimator.fit(X, y).predict(X)
array([2, 2, 2])
>>> estimator.set_params(param=3).fit(X, y).predict(X)
array([3, 3, 3])

Params

prompt_template: str
    Template for the prompt. For different prompt styles (e.g. few-shot), you may want to place {prompt} before {example}, or add some text before the verbalizer, e.g. "{example}{prompt} Output:".
cache_key_values: bool
    Whether to cache past key values (only possible when the prompt does not start with {example}).
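A minimal usage sketch (the checkpoint, prompts, and toy data below are illustrative assumptions, not part of the API):

from imodelsx.treeprompt.treeprompt import TreePromptClassifier

# toy sentiment data (hypothetical); any list of strings plus integer labels works
X_train = ["I loved this movie!", "Terrible plot and worse acting.", "Great fun for the whole family."]
y_train = [1, 0, 1]

# candidate prompts to ensemble; the decision tree selects which ones to keep
prompts = [
    " Is the sentiment positive or negative? The sentiment is",
    " Overall, the reviewer felt the movie was",
]

clf = TreePromptClassifier(
    checkpoint="gpt2",  # any HuggingFace causal-LM checkpoint (illustrative choice)
    prompts=prompts,
    verbalizer={0: " Negative.", 1: " Positive."},
    tree_kwargs={"max_leaf_nodes": 3},
)
clf.fit(X_train, y_train)
preds = clf.predict(["An instant classic."])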

Ancestors

  • sklearn.base.BaseEstimator
  • sklearn.utils._repr_html.base.ReprHTMLMixin
  • sklearn.utils._repr_html.base._HTMLDocumentationLinkMixin
  • sklearn.utils._metadata_requests._MetadataRequester
  • sklearn.base.ClassifierMixin

Methods

def fit(self, X, y)
def fit(self, X, y):
    transformers.set_seed(self.random_state)
    self.classes_ = np.unique(y)
    self.n_classes_ = len(self.classes_)
    assert len(self.classes_) == len(self.verbalizer)

    # calculate prompt features
    prompt_features = self._calc_prompt_features(X, self.prompts)
    self.prompt_accs_ = [
        accuracy_score(y, prompt_features[:, i]) for i in range(len(self.prompts))
    ]

    # apply one-hot encoding to features
    if len(np.unique(y)) > 3:
        print("Converting to one-hot")
        self.enc_ = OneHotEncoder(handle_unknown="ignore")
        prompt_features = self.enc_.fit_transform(prompt_features)
        self.feature_names_ = self.enc_.get_feature_names_out(self.prompts)
    else:
        self.feature_names_ = self.prompts

    # train decision tree
    self.clf_ = sklearn.tree.DecisionTreeClassifier(
        **self.tree_kwargs,
        random_state=self.random_state,
    )
    self.clf_.fit(prompt_features, y)
    self.prompts_idxs_kept = np.unique(self.clf_.tree_.feature)[
        1:
    ]  # remove first element which is -2

    return self
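After fit, the learned attributes set above (prompt_accs_, feature_names_, clf_, prompts_idxs_kept) can be inspected. A minimal sketch, assuming clf was fitted as in the usage example earlier:

import sklearn.tree

# per-prompt training accuracy, one value per prompt
print(dict(zip(clf.prompts, clf.prompt_accs_)))

# indices of the prompts the tree actually split on
print(clf.prompts_idxs_kept)

# the fitted decision tree over prompt predictions
print(sklearn.tree.export_text(clf.clf_, feature_names=list(clf.feature_names_)))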
def predict(self, X)
def predict(self, X):
    return self.predict_proba(X).argmax(axis=1)
def predict_proba(self, X)
def predict_proba(self, X):
    # extract prompt features
    prompt_features = np.zeros((len(X), len(self.prompts)))
    prompt_features_relevant = self._calc_prompt_features(
        X, np.array(self.prompts)[self.prompts_idxs_kept]
    )
    prompt_features[:, self.prompts_idxs_kept] = prompt_features_relevant

    # apply one-hot encoding to features (must match the encoding used in fit)
    if hasattr(self, "enc_"):
        prompt_features = self.enc_.transform(prompt_features)

    # predict
    return self.clf_.predict_proba(prompt_features)
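A short, hedged sketch of consuming the probabilities (assumes a fitted clf; the input text is made up):

proba = clf.predict_proba(["An instant classic, I was hooked."])
# proba has shape (n_samples, n_classes); column order follows clf.clf_.classes_
print(proba)
print(proba.argmax(axis=1))  # equivalent to clf.predict(...)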
def set_score_request(self: TreePromptClassifier,
*,
sample_weight: bool | str | None = '$UNCHANGED$') -> TreePromptClassifier
def func(*args, **kw):
    """Updates the `_metadata_request` attribute of the consumer (`instance`)
    for the parameters provided as `**kw`.

    This docstring is overwritten below.
    See REQUESTER_DOC for expected functionality.
    """
    if not _routing_enabled():
        raise RuntimeError(
            "This method is only available when metadata routing is enabled."
            " You can enable it using"
            " sklearn.set_config(enable_metadata_routing=True)."
        )

    if self.validate_keys and (set(kw) - set(self.keys)):
        raise TypeError(
            f"Unexpected args: {set(kw) - set(self.keys)} in {self.name}. "
            f"Accepted arguments are: {set(self.keys)}"
        )

    # This makes it possible to use the decorated method as an unbound method,
    # for instance when monkeypatching.
    # https://github.com/scikit-learn/scikit-learn/issues/28632
    if instance is None:
        _instance = args[0]
        args = args[1:]
    else:
        _instance = instance

    # Replicating python's behavior when positional args are given other than
    # `self`, and `self` is only allowed if this method is unbound.
    if args:
        raise TypeError(
            f"set_{self.name}_request() takes 0 positional argument but"
            f" {len(args)} were given"
        )

    requests = _instance._get_metadata_request()
    method_metadata_request = getattr(requests, self.name)

    for prop, alias in kw.items():
        if alias is not UNCHANGED:
            method_metadata_request.add_request(param=prop, alias=alias)
    _instance._metadata_request = requests

    return _instance

Configure whether metadata should be requested to be passed to the score method.

    Note that this method is only relevant when this estimator is used as a
    sub-estimator within a meta-estimator and metadata routing is enabled with
    enable_metadata_routing=True (see sklearn.set_config). Please check the
    scikit-learn User Guide on metadata routing for how the routing mechanism works.

    The options for each parameter are:

    - True: metadata is requested, and passed to score if provided. The request is ignored if metadata is not provided.

    - False: metadata is not requested and the meta-estimator will not pass it to score.

    - None: metadata is not requested, and the meta-estimator will raise an error if the user provides it.

    - str: metadata should be passed to the meta-estimator with this given alias instead of the original name.

    The default (sklearn.utils.metadata_routing.UNCHANGED) retains the
    existing request. This allows you to change the request for some
    parameters and not others.

    Added in version 1.3.

    Parameters
    ----------
    sample_weight : str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED
        Metadata routing for the sample_weight parameter in score.

    Returns
    -------
    self : object
        The updated object.
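A hedged sketch of enabling metadata routing before calling this method (the checkpoint and prompt are illustrative):

import sklearn
from imodelsx.treeprompt.treeprompt import TreePromptClassifier

# metadata routing is opt-in; without this call, set_score_request raises a RuntimeError
sklearn.set_config(enable_metadata_routing=True)

clf = TreePromptClassifier(
    checkpoint="gpt2",  # illustrative checkpoint
    prompts=[" The sentiment of this review is"],
)
# request that meta-estimators (e.g. cross_validate) route sample_weight to clf.score
clf = clf.set_score_request(sample_weight=True)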