Module imodelsx.treeprompt.treeprompt

import os
import pandas as pd
import numpy as np
import transformers
import sys
from os.path import join
import datasets
from typing import Dict, List
from dict_hash import sha256
import joblib
import random
import sklearn.tree
from sklearn.metrics import accuracy_score
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.preprocessing import OneHotEncoder

import imodelsx.treeprompt.stump
import imodelsx.llm


class TreePromptClassifier(BaseEstimator, ClassifierMixin):
    def __init__(
        self,
        checkpoint: str,
        prompts: List[str],
        verbalizer: Dict[int, str] = {0: " Negative.", 1: " Positive."},
        tree_kwargs: Dict = {"max_leaf_nodes": 3},
        batch_size: int = 1,
        prompt_template: str = "{example}{prompt}",
        cache_prompt_features_dir=join("cache_prompt_features"),
        cache_key_values: bool = False,
        device=None,
        verbose: bool = True,
        random_state: int = 42,
    ):
        '''
        Params
        ------
        prompt_template: str
            Template for the prompt. For different prompt styles (e.g. few-shot), you may want
            to place {prompt} before {example}, or to add some text before the verbalizer,
            e.g. "{example}{prompt} Output:".
        cache_key_values: bool
            Whether to cache the prompt's past key/value states (only possible when the prompt
            template does not start with {example}).
        '''
        self.checkpoint = checkpoint
        self.prompts = prompts
        self.verbalizer = verbalizer
        self.tree_kwargs = tree_kwargs
        self.batch_size = batch_size
        self.prompt_template = prompt_template
        self.cache_prompt_features_dir = cache_prompt_features_dir
        self.cache_key_values = cache_key_values
        self.device = device
        self.verbose = verbose
        self.random_state = random_state

    def fit(self, X, y):
        transformers.set_seed(self.random_state)
        self.classes_ = np.unique(y)
        self.n_classes_ = len(self.classes_)
        assert len(self.classes_) == len(self.verbalizer)

        # calculate prompt features
        prompt_features = self._calc_prompt_features(X, self.prompts)
        self.prompt_accs_ = [
            accuracy_score(y, prompt_features[:, i]) for i in range(len(self.prompts))
        ]

        # apply one-hot encoding to features
        if len(np.unique(y)) > 3:
            print("Converting to one-hot")
            self.enc_ = OneHotEncoder(handle_unknown="ignore")
            prompt_features = self.enc_.fit_transform(prompt_features)
            self.feature_names_ = self.enc_.get_feature_names_out(self.prompts)
        else:
            self.feature_names_ = self.prompts

        # train decision tree
        self.clf_ = sklearn.tree.DecisionTreeClassifier(
            **self.tree_kwargs,
            random_state=self.random_state,
        )
        self.clf_.fit(prompt_features, y)
        # drop the leading -2, which sklearn uses to mark leaf nodes in tree_.feature
        self.prompts_idxs_kept = np.unique(self.clf_.tree_.feature)[1:]

        return self

    def _calc_prompt_features(self, X, prompts):
        prompt_features = np.zeros((len(X), len(prompts)))
        llm = imodelsx.llm.get_llm(self.checkpoint).model_
        if self.device is not None:
            llm = llm.to(self.device)
        stump = None

        for i, prompt in enumerate(prompts):
            print(f"Prompt {i}: {prompt}")

            loaded_from_cache = False
            if self.cache_prompt_features_dir is not None:
                os.makedirs(self.cache_prompt_features_dir, exist_ok=True)
                args_dict_cache = {"prompt": prompt,
                                   "X_len": len(X), "ex0": X[0]}
                save_dir_unique_hash = sha256(args_dict_cache)
                cache_file = join(
                    self.cache_prompt_features_dir, f"{save_dir_unique_hash}.pkl"
                )

                # load from cache if possible
                if os.path.exists(cache_file):
                    print("loading from cache!")
                    try:
                        prompt_features_i = joblib.load(cache_file)
                        loaded_from_cache = True
                    except Exception:
                        # cache file unreadable; fall through and recompute below
                        pass

            if not loaded_from_cache:
                if stump is None:
                    stump = imodelsx.treeprompt.stump.PromptStump(
                        model=llm,
                        checkpoint=self.checkpoint,
                        verbalizer=self.verbalizer,
                        batch_size=self.batch_size,
                        prompt_template=self.prompt_template,
                        cache_key_values=self.cache_key_values,
                    )

                # calculate prompt_features
                def _calc_features_single_prompt(
                    X, stump, prompt, past_key_values=None
                ):
                    """Calculate features with a single prompt (results get cached)
                    preds: np.ndarray[int] of shape (X.shape[0],)
                        If multiclass, each int takes value 0, 1, ..., n_classes - 1 based on the verbalizer
                    """
                    stump.prompt = prompt
                    if past_key_values is not None:
                        preds = stump.predict_with_cache(X, past_key_values)
                    else:
                        preds = stump.predict(X)
                    return preds

                past_key_values = None
                if self.cache_key_values:
                    stump.prompt = prompt
                    past_key_values = stump.calc_key_values(X)
                prompt_features_i = _calc_features_single_prompt(
                    X, stump, prompt, past_key_values=past_key_values
                )
                if self.cache_prompt_features_dir is not None:
                    joblib.dump(prompt_features_i, cache_file)

            # save prompt features
            prompt_features[:, i] = prompt_features_i

        return prompt_features

    def predict_proba(self, X):
        # extract prompt features
        prompt_features = np.zeros((len(X), len(self.prompts)))
        prompt_features_relevant = self._calc_prompt_features(
            X, np.array(self.prompts)[self.prompts_idxs_kept]
        )
        prompt_features[:, self.prompts_idxs_kept] = prompt_features_relevant

        # apply one-hot encoding to features
        if hasattr(self, "enc_"):
            prompt_features = self.enc_.transform(prompt_features)

        # predict
        return self.clf_.predict_proba(prompt_features)

    def predict(self, X):
        return self.predict_proba(X).argmax(axis=1)
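
The per-prompt predictions computed in _calc_prompt_features above are memoized to disk under cache_prompt_features_dir. The short sketch below (not part of the source; the helper name is illustrative) shows how a cache file is located, mirroring the key used in the code: the prompt text, the number of examples, and the first example.

from os.path import join
from dict_hash import sha256

def prompt_cache_path(prompt: str, X: list, cache_dir: str = "cache_prompt_features") -> str:
    # same cache key as in _calc_prompt_features
    args_dict_cache = {"prompt": prompt, "X_len": len(X), "ex0": X[0]}
    return join(cache_dir, f"{sha256(args_dict_cache)}.pkl")

# deleting the returned file forces the features for that prompt to be recomputed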

Classes

class TreePromptClassifier (checkpoint: str, prompts: List[str], verbalizer: Dict[int, str] = {0: ' Negative.', 1: ' Positive.'}, tree_kwargs: Dict = {'max_leaf_nodes': 3}, batch_size: int = 1, prompt_template: str = '{example}{prompt}', cache_prompt_features_dir='cache_prompt_features', cache_key_values: bool = False, device=None, verbose: bool = True, random_state: int = 42)

Base class for all estimators in scikit-learn.

Inheriting from this class provides default implementations of:

  • setting and getting parameters used by GridSearchCV and friends;
  • textual and HTML representation displayed in terminals and IDEs;
  • estimator serialization;
  • parameters validation;
  • data validation;
  • feature names validation.

Read more in the scikit-learn User Guide (rolling your own estimator).

Notes

All estimators should specify all the parameters that can be set at the class level in their __init__ as explicit keyword arguments (no *args or **kwargs).

Examples

>>> import numpy as np
>>> from sklearn.base import BaseEstimator
>>> class MyEstimator(BaseEstimator):
...     def __init__(self, *, param=1):
...         self.param = param
...     def fit(self, X, y=None):
...         self.is_fitted_ = True
...         return self
...     def predict(self, X):
...         return np.full(shape=X.shape[0], fill_value=self.param)
>>> estimator = MyEstimator(param=2)
>>> estimator.get_params()
{'param': 2}
>>> X = np.array([[1, 2], [2, 3], [3, 4]])
>>> y = np.array([1, 0, 1])
>>> estimator.fit(X, y).predict(X)
array([2, 2, 2])
>>> estimator.set_params(param=3).fit(X, y).predict(X)
array([3, 3, 3])

Params

prompt_template: str
    Template for the prompt. For different prompt styles (e.g. few-shot), you may want to place {prompt} before {example}, or to add some text before the verbalizer, e.g. "{example}{prompt} Output:".
cache_key_values: bool
    Whether to cache the prompt's past key/value states (only possible when the prompt template does not start with {example}).
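
A minimal usage sketch (not part of the generated docs): it assumes the class is re-exported at the package level and that a small Hugging Face checkpoint such as gpt2 is available; the prompts, verbalizer, and toy data are illustrative.

import numpy as np
from imodelsx import TreePromptClassifier  # or: from imodelsx.treeprompt.treeprompt import TreePromptClassifier

X_train = [
    "This movie was fantastic!",
    "Absolutely dreadful, a waste of two hours.",
    "A beautiful, moving film.",
    "Boring and far too long.",
]
y_train = np.array([1, 0, 1, 0])

clf = TreePromptClassifier(
    checkpoint="gpt2",                  # any causal LM checkpoint name
    prompts=[
        " The sentiment of this review is",
        " Overall, the reviewer thought the movie was",
    ],
    verbalizer={0: " Negative.", 1: " Positive."},
    tree_kwargs={"max_leaf_nodes": 3},
    cache_prompt_features_dir=None,     # disable on-disk caching for this demo
)
clf.fit(X_train, y_train)
print(clf.predict(X_train))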

class TreePromptClassifier(BaseEstimator, ClassifierMixin):
    def __init__(
        self,
        checkpoint: str,
        prompts: List[str],
        verbalizer: Dict[int, str] = {0: " Negative.", 1: " Positive."},
        tree_kwargs: Dict = {"max_leaf_nodes": 3},
        batch_size: int = 1,
        prompt_template: str = "{example}{prompt}",
        cache_prompt_features_dir=join("cache_prompt_features"),
        cache_key_values: bool = False,
        device=None,
        verbose: bool = True,
        random_state: int = 42,
    ):
        '''
        Params
        ------
        prompt_template: str
            Template for the prompt. For different prompt styles (e.g. few-shot), you may want
            to place {prompt} before {example}, or to add some text before the verbalizer,
            e.g. "{example}{prompt} Output:".
        cache_key_values: bool
            Whether to cache the prompt's past key/value states (only possible when the prompt
            template does not start with {example}).
        '''
        self.checkpoint = checkpoint
        self.prompts = prompts
        self.verbalizer = verbalizer
        self.tree_kwargs = tree_kwargs
        self.batch_size = batch_size
        self.prompt_template = prompt_template
        self.cache_prompt_features_dir = cache_prompt_features_dir
        self.cache_key_values = cache_key_values
        self.device = device
        self.verbose = verbose
        self.random_state = random_state

    def fit(self, X, y):
        transformers.set_seed(self.random_state)
        self.classes_ = np.unique(y)
        self.n_classes_ = len(self.classes_)
        assert len(self.classes_) == len(self.verbalizer)

        # calculate prompt features
        prompt_features = self._calc_prompt_features(X, self.prompts)
        self.prompt_accs_ = [
            accuracy_score(y, prompt_features[:, i]) for i in range(len(self.prompts))
        ]

        # apply one-hot encoding to features
        if len(np.unique(y)) > 3:
            print("Converting to one-hot")
            self.enc_ = OneHotEncoder(handle_unknown="ignore")
            prompt_features = self.enc_.fit_transform(prompt_features)
            self.feature_names_ = self.enc_.get_feature_names_out(self.prompts)
        else:
            self.feature_names_ = self.prompts

        # train decision tree
        self.clf_ = sklearn.tree.DecisionTreeClassifier(
            **self.tree_kwargs,
            random_state=self.random_state,
        )
        self.clf_.fit(prompt_features, y)
        # drop the leading -2, which sklearn uses to mark leaf nodes in tree_.feature
        self.prompts_idxs_kept = np.unique(self.clf_.tree_.feature)[1:]

        return self

    def _calc_prompt_features(self, X, prompts):
        prompt_features = np.zeros((len(X), len(prompts)))
        llm = imodelsx.llm.get_llm(self.checkpoint).model_
        if self.device is not None:
            llm = llm.to(self.device)
        stump = None

        for i, prompt in enumerate(prompts):
            print(f"Prompt {i}: {prompt}")

            loaded_from_cache = False
            if self.cache_prompt_features_dir is not None:
                os.makedirs(self.cache_prompt_features_dir, exist_ok=True)
                args_dict_cache = {"prompt": prompt,
                                   "X_len": len(X), "ex0": X[0]}
                save_dir_unique_hash = sha256(args_dict_cache)
                cache_file = join(
                    self.cache_prompt_features_dir, f"{save_dir_unique_hash}.pkl"
                )

                # load from cache if possible
                if os.path.exists(cache_file):
                    print("loading from cache!")
                    try:
                        prompt_features_i = joblib.load(cache_file)
                        loaded_from_cache = True
                    except Exception:
                        # cache file unreadable; fall through and recompute below
                        pass

            if not loaded_from_cache:
                if stump is None:
                    stump = imodelsx.treeprompt.stump.PromptStump(
                        model=llm,
                        checkpoint=self.checkpoint,
                        verbalizer=self.verbalizer,
                        batch_size=self.batch_size,
                        prompt_template=self.prompt_template,
                        cache_key_values=self.cache_key_values,
                    )

                # calculate prompt_features
                def _calc_features_single_prompt(
                    X, stump, prompt, past_key_values=None
                ):
                    """Calculate features with a single prompt (results get cached)
                    preds: np.ndarray[int] of shape (X.shape[0],)
                        If multiclass, each int takes value 0, 1, ..., n_classes - 1 based on the verbalizer
                    """
                    stump.prompt = prompt
                    if past_key_values is not None:
                        preds = stump.predict_with_cache(X, past_key_values)
                    else:
                        preds = stump.predict(X)
                    return preds

                past_key_values = None
                if self.cache_key_values:
                    stump.prompt = prompt
                    past_key_values = stump.calc_key_values(X)
                prompt_features_i = _calc_features_single_prompt(
                    X, stump, prompt, past_key_values=past_key_values
                )
                if self.cache_prompt_features_dir is not None:
                    joblib.dump(prompt_features_i, cache_file)

            # save prompt features
            prompt_features[:, i] = prompt_features_i

        return prompt_features

    def predict_proba(self, X):
        # extract prompt features
        prompt_features = np.zeros((len(X), len(self.prompts)))
        prompt_features_relevant = self._calc_prompt_features(
            X, np.array(self.prompts)[self.prompts_idxs_kept]
        )
        prompt_features[:, self.prompts_idxs_kept] = prompt_features_relevant

        # apply one-hot encoding to features
        if hasattr(self, "enc_"):
            prompt_features = self.enc_.transform(prompt_features)

        # predict
        return self.clf_.predict_proba(prompt_features)

    def predict(self, X):
        return self.predict_proba(X).argmax(axis=1)

Ancestors

  • sklearn.base.BaseEstimator
  • sklearn.utils._estimator_html_repr._HTMLDocumentationLinkMixin
  • sklearn.utils._metadata_requests._MetadataRequester
  • sklearn.base.ClassifierMixin

Methods

def fit(self, X, y)
def fit(self, X, y):
    transformers.set_seed(self.random_state)
    self.classes_ = np.unique(y)
    self.n_classes_ = len(self.classes_)
    assert len(self.classes_) == len(self.verbalizer)

    # calculate prompt features
    prompt_features = self._calc_prompt_features(X, self.prompts)
    self.prompt_accs_ = [
        accuracy_score(y, prompt_features[:, i]) for i in range(len(self.prompts))
    ]

    # apply one-hot encoding to features
    if len(np.unique(y)) > 3:
        print("Converting to one-hot")
        self.enc_ = OneHotEncoder(handle_unknown="ignore")
        prompt_features = self.enc_.fit_transform(prompt_features)
        self.feature_names_ = self.enc_.get_feature_names_out(self.prompts)
    else:
        self.feature_names_ = self.prompts

    # train decision tree
    self.clf_ = sklearn.tree.DecisionTreeClassifier(
        **self.tree_kwargs,
        random_state=self.random_state,
    )
    self.clf_.fit(prompt_features, y)
    # drop the leading -2, which sklearn uses to mark leaf nodes in tree_.feature
    self.prompts_idxs_kept = np.unique(self.clf_.tree_.feature)[1:]

    return self
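
After fitting, the attributes set in the source above can be inspected directly. A short sketch, assuming clf has been fit as in the usage example near the top of this page:

# training accuracy of each individual prompt, in the same order as `prompts`
print(dict(zip(clf.prompts, clf.prompt_accs_)))

# indices of the prompts the decision tree actually splits on
print(clf.prompts_idxs_kept)

# the fitted sklearn tree, rendered as text
import sklearn.tree
print(sklearn.tree.export_text(clf.clf_, feature_names=list(clf.feature_names_)))
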
def predict(self, X)
def predict(self, X):
    return self.predict_proba(X).argmax(axis=1)
def predict_proba(self, X)
def predict_proba(self, X):
    # extract prompt features
    prompt_features = np.zeros((len(X), len(self.prompts)))
    prompt_features_relevant = self._calc_prompt_features(
        X, np.array(self.prompts)[self.prompts_idxs_kept]
    )
    prompt_features[:, self.prompts_idxs_kept] = prompt_features_relevant

    # apply one-hot encoding to features
    if hasattr(self, "enc_"):
        prompt_features = self.enc_.transform(prompt_features)

    # predict
    return self.clf_.predict_proba(prompt_features)
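
A quick sketch of the prediction interface, continuing the fitted clf from the earlier example (the test sentences are illustrative):

X_test = ["One of the best films of the year.", "I fell asleep halfway through."]

proba = clf.predict_proba(X_test)   # shape (len(X_test), n_classes_)
labels = clf.predict(X_test)        # argmax over the class axis
print(proba.shape, labels)
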
def set_score_request(self: TreePromptClassifier, *, sample_weight: Union[bool, None, str] = '$UNCHANGED$') -> TreePromptClassifier

Request metadata passed to the score method.

Note that this method is only relevant if enable_metadata_routing=True (see sklearn.set_config). Please see the User Guide on metadata routing for how the routing mechanism works.

The options for each parameter are:

  • True: metadata is requested, and passed to score if provided. The request is ignored if metadata is not provided.

  • False: metadata is not requested and the meta-estimator will not pass it to score.

  • None: metadata is not requested, and the meta-estimator will raise an error if the user provides it.

  • str: metadata should be passed to the meta-estimator with this given alias instead of the original name.

The default (sklearn.utils.metadata_routing.UNCHANGED) retains the existing request. This allows you to change the request for some parameters and not others.

Added in version: 1.3

Note

This method is only relevant if this estimator is used as a sub-estimator of a meta-estimator, e.g. inside a sklearn.pipeline.Pipeline. Otherwise it has no effect.

Parameters

sample_weight : str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED
Metadata routing for sample_weight parameter in score.

Returns

self : object
The updated object.
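
A hedged sketch of how this could be used with a meta-estimator, assuming scikit-learn >= 1.3 with metadata routing enabled; the checkpoint, prompt, and weights are placeholders.

import numpy as np
import sklearn
from sklearn.model_selection import GridSearchCV
from imodelsx import TreePromptClassifier  # assumed package-level re-export

sklearn.set_config(enable_metadata_routing=True)

clf = TreePromptClassifier(checkpoint="gpt2", prompts=[" The sentiment of this review is"])
clf = clf.set_score_request(sample_weight=True)   # request that sample_weight be routed to score

search = GridSearchCV(clf, param_grid={"tree_kwargs": [{"max_leaf_nodes": 3}, {"max_leaf_nodes": 5}]})
# with routing enabled, sample_weight passed to fit is forwarded to clf.score during cross-validation
# search.fit(X_train, y_train, sample_weight=np.ones(len(y_train)))
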
def func(**kw):
    """Updates the request for provided parameters

    This docstring is overwritten below.
    See REQUESTER_DOC for expected functionality
    """
    if not _routing_enabled():
        raise RuntimeError(
            "This method is only available when metadata routing is enabled."
            " You can enable it using"
            " sklearn.set_config(enable_metadata_routing=True)."
        )

    if self.validate_keys and (set(kw) - set(self.keys)):
        raise TypeError(
            f"Unexpected args: {set(kw) - set(self.keys)}. Accepted arguments"
            f" are: {set(self.keys)}"
        )

    requests = instance._get_metadata_request()
    method_metadata_request = getattr(requests, self.name)

    for prop, alias in kw.items():
        if alias is not UNCHANGED:
            method_metadata_request.add_request(param=prop, alias=alias)
    instance._metadata_request = requests

    return instance