Module imodelsx.treeprompt.stump

Source code
from typing import Dict, List

from abc import ABC, abstractmethod
import logging
import math
import random
import imodels
import numpy as np
from scipy.special import softmax
from sklearn.metrics import accuracy_score
from sklearn.linear_model import LogisticRegression
import torch.cuda
import tqdm
import imodelsx.llm
from transformers import AutoTokenizer, AutoModelForCausalLM


class PromptStump:
    def __init__(
        self,
        args=None,
        prompt: str = None,
        tokenizer=None,
        prompt_template: str = "{example}{prompt}",
        cache_key_values: bool = False,
        verbose: bool = True,
        model: AutoModelForCausalLM = None,
        checkpoint: str = "EleutherAI/gpt-j-6B",
        verbalizer: Dict[int, str] = {0: " Negative.", 1: " Positive."},
        batch_size: int = 1,
    ):
        """Given a prompt, extract its outputs

        Params
        ------
        args: contains some parameters passed through a namespace (these can be ignored)
        prompt: str
            the prompt to use (optional)
        prompt_template: str
            template for the prompt; for different prompt styles (e.g. few-shot) you may want to place {prompt} before {example},
            or add some text before the verbalizer, e.g. "{example}{prompt} Output:"
        cache_key_values: bool
            Whether to cache key values (only possible when prompt does not start with {example})
        checkpoint: str
            the underlying model used for prediction
        model: AutoModelForCausalLM
            if this is passed, will override checkpoint
        """
        if args is None:

            class placeholder:
                prompt_source = None
                template_data_demonstrations = None
                dataset_name = ""

            self.args = placeholder()
        else:
            self.args = args
        self.prompt = prompt
        self.prompt_template = prompt_template
        self.cache_key_values = cache_key_values
        self.verbose = verbose
        self.checkpoint = checkpoint
        self.model = model
        if tokenizer is None:
            self.tokenizer = imodelsx.llm.load_tokenizer(checkpoint)
        else:
            self.tokenizer = tokenizer
        self.batch_size = batch_size
        self.verbalizer = verbalizer

        if self.verbose:
            logging.info(f"Loading model {self.checkpoint}")

    def predict(self, X_text: List[str]) -> np.ndarray[int]:
        preds_proba = self.predict_proba(X_text)
        return np.argmax(preds_proba, axis=1)

    def predict_with_cache(self, X_text: List[str], past_key_values) -> np.ndarray[int]:
        preds_proba = self.predict_proba_with_cache(X_text, past_key_values)
        return np.argmax(preds_proba, axis=1)

    def predict_proba(self, X_text: List[str]) -> np.ndarray[float]:
        target_strs = list(self.verbalizer.values())

        # only predict based on first token of output string
        target_token_ids = list(map(self._get_first_token_id, target_strs))
        assert len(set(target_token_ids)) == len(
            set(target_strs)
        ), f"error: target_token_ids {set(target_token_ids)} not unique to target strings {set(target_strs)}"
        text_inputs = [self.prompt_template.format(
            example=x, prompt=self.prompt) for x in X_text]

        preds = self._get_logit_for_target_tokens_batched(
            text_inputs,
            target_token_ids,
            batch_size=self.batch_size,
        )
        assert preds.shape == (len(X_text), len(target_token_ids)), (
            "preds shape was"
            + str(preds.shape)
            + " but should have been "
            + str((len(X_text), len(target_token_ids)))
        )

        # return softmax-normalized probabilities over the class logits
        return softmax(preds, axis=1)

    def predict_proba_with_cache(
        self, X_text: List[str], past_key_values
    ) -> np.ndarray[float]:
        target_strs = list(self.verbalizer.values())

        # only predict based on first token of output string
        target_token_ids = list(map(self._get_first_token_id, target_strs))
        assert len(set(target_token_ids)) == len(
            set(target_strs)
        ), f"error: target_token_ids {set(target_token_ids)} not unique to target strings {set(target_strs)}"

        text_inputs = [self.prompt_template.format(
            example=x, prompt=self.prompt) for x in X_text]

        preds = self._get_logit_for_target_tokens_batched_with_cache(
            past_key_values,
            text_inputs,
            target_token_ids,
            batch_size=self.batch_size,
        )

        assert preds.shape == (len(X_text), len(target_token_ids)), (
            "preds shape was"
            + str(preds.shape)
            + " but should have been "
            + str((len(X_text), len(target_token_ids)))
        )

        # return softmax-normalized probabilities over the class logits
        return softmax(preds, axis=1)

    def calc_key_values(self, X_text: List[str]):
        # configure the tokenizer: left-truncate long inputs and pad with the EOS token
        self.tokenizer.truncation_side = "left"
        self.tokenizer.padding = True

        self.tokenizer.pad_token = self.tokenizer.eos_token

        p = self.prompt
        template = self.args.template_data_demonstrations
        if self.args.dataset_name.startswith("knnp__"):
            max_len_verb = max(
                len(self.tokenizer.encode(v)) for v in self.verbalizer.values()
            )
            max_len_input = (
                max_len_verb
                + max(len(self.tokenizer.encode(s)) for s in X_text)
                + 1
            )
        else:
            max_len_input = -1
            for v in self.verbalizer.values():
                max_len_input = max(
                    max_len_input,
                    max(
                        [
                            len(self.tokenizer.encode(template % (s, v)))
                            for s in X_text[:1000]
                        ]
                    ),
                )
        try:
            max_total_len = self.model.config.n_positions
        except AttributeError:
            max_total_len = self.model.config.max_position_embeddings
        max_len_prompt = max_total_len - max_len_input
        # NOTE: the original guard ('dbpedia' in self.args.dataset_name or max_len_prompt < 0)
        # is disabled, so the 95th-percentile left-truncation below is always applied
        if True:
            print("truncating inputs to the 95th-percentile length (left-truncated)")
            max_len_input = -1
            for v in self.verbalizer.values():
                a = [
                    len(self.tokenizer.encode(template % (s, v)))
                    for s in X_text[:1000]
                ]
                max_len_input = max(max_len_input, np.percentile(a, 95))
        max_len_input = int(math.ceil(max_len_input))
        max_len_prompt = max_total_len - max_len_input
        self.max_len_input = max_len_input
        print(
            f"max_len_prompt: {max_len_prompt}, max_len_input: {max_len_input}")
        assert max_len_prompt > 0
        inputs = self.tokenizer(
            [
                p,
            ],
            return_tensors="pt",
            padding=False,
            truncation=True,
            max_length=max_len_prompt,
            return_attention_mask=True,
        ).to(self.model.device)

        # forward pass over the prompt alone to obtain its cached key/values
        with torch.no_grad():
            outputs = self.model(**inputs)
        return outputs["past_key_values"]

    def _get_logit_for_target_tokens_batched(
        self, prompts: List[str], target_token_ids: List[int], batch_size: int = 64
    ) -> np.ndarray[float]:
        """Get logits for each target token
        This can fail when token_output_ids represents multiple tokens
        So things get mapped to the same id representing "unknown"
        """
        logit_targets_list = []
        batch_num = 0

        try:
            max_total_len = self.model.config.n_positions
        except AttributeError:
            max_total_len = self.model.config.max_position_embeddings

        pbar = tqdm.tqdm(
            total=len(prompts),
            leave=False,
            desc="getting dataset predictions for top prompt",
            colour="red",
        )
        # iterate over the prompts in batches; return the accumulated logits once exhausted
        while True:
            batch_start = batch_num * batch_size
            batch_end = (batch_num + 1) * batch_size
            batch_num += 1
            pbar.update(batch_size)
            if batch_start >= len(prompts):
                return np.array(logit_targets_list)

            prompts_batch = prompts[batch_start:batch_end]
            self.tokenizer.padding = True
            self.tokenizer.truncation_side = "left"
            self.tokenizer.pad_token = self.tokenizer.eos_token
            inputs = self.tokenizer(
                prompts_batch,
                return_tensors="pt",
                padding=True,
                truncation=True,
                return_attention_mask=True,
                max_length=max_total_len,
            ).to(self.model.device)

            # shape is (batch_size, seq_len, vocab_size)
            with torch.no_grad():
                logits = self.model(**inputs)["logits"]

            # position of the last non-padding token in each sequence
            token_output_positions = inputs["attention_mask"].sum(axis=1)
            for i in range(len(prompts_batch)):
                token_output_position = token_output_positions[i].item() - 1
                logit_targets_list.append(
                    [
                        logits[i, token_output_position,
                               token_output_id].item()
                        for token_output_id in target_token_ids
                    ]
                )

    def _get_logit_for_target_tokens_batched_with_cache(
        self,
        past_key_values,
        prompts: List[str],
        target_token_ids: List[int],
        batch_size: int = 64,
    ) -> np.ndarray[float]:
        """Get logits for each target token
        This can fail when token_output_ids represents multiple tokens
        So things get mapped to the same id representing "unknown"
        """
        logit_targets_list = []
        batch_num = 0

        pbar = tqdm.tqdm(
            total=len(prompts), leave=False, desc="getting predictions", colour="red"
        )

        # broadcast the cached prompt key/values along the batch dimension
        past_key_values_new = []
        for i in range(len(past_key_values)):
            past_key_values_new.append(
                [
                    past_key_values[i][0].expand(batch_size, -1, -1, -1),
                    past_key_values[i][1].expand(batch_size, -1, -1, -1),
                ]
            )
        while True:
            batch_start = batch_num * batch_size
            batch_end = (batch_num + 1) * batch_size
            batch_num += 1
            pbar.update(batch_size)
            if batch_start >= len(prompts):
                return np.array(logit_targets_list)

            prompts_batch = prompts[batch_start:batch_end]
            if len(prompts_batch) != past_key_values_new[0][0].shape[0]:
                for i in range(len(past_key_values)):
                    past_key_values_new[i] = [
                        past_key_values[i][0].expand(
                            len(prompts_batch), -1, -1, -1),
                        past_key_values[i][1].expand(
                            len(prompts_batch), -1, -1, -1),
                    ]
            self.tokenizer.padding = True
            self.tokenizer.pad_token = self.tokenizer.eos_token
            inputs = self.tokenizer(
                prompts_batch,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=self.max_len_input,
                return_attention_mask=True,
            ).to(self.model.device)

            attention_mask = inputs["attention_mask"]
            attention_mask = torch.cat(
                (
                    attention_mask.new_zeros(
                        len(prompts_batch), past_key_values[0][0].shape[-2]
                    ).fill_(1),
                    attention_mask,
                ),
                dim=-1,
            )
            inputs["attention_mask"] = attention_mask

            # shape is (batch_size, seq_len, vocab_size)
            with torch.no_grad():
                outputs = self.model(
                    **inputs, past_key_values=past_key_values_new)
            logits = outputs["logits"]
            token_output_positions = (
                inputs["attention_mask"].sum(
                    axis=1) - past_key_values[0][0].shape[-2]
            )
            for i in range(len(prompts_batch)):
                token_output_position = token_output_positions[i].item() - 1
                logit_targets_list.append(
                    [
                        logits[i, token_output_position,
                               token_output_id].item()
                        for token_output_id in target_token_ids
                    ]
                )

    def _get_first_token_id(self, prompt: str) -> int:
        """Get first token id in prompt (after special tokens).

        Need to strip special tokens for LLAMA so we don't get a special space token at the beginning.
        """
        if "llama" in self.checkpoint.lower():
            prompt = prompt.lstrip()

        tokens = self.tokenizer(prompt)["input_ids"]
        tokens = [t for t in tokens if t not in self.tokenizer.all_special_ids]
        return tokens[0]

Classes

class PromptStump (args=None, prompt: str = None, tokenizer=None, prompt_template: str = '{example}{prompt}', cache_key_values: bool = False, verbose: bool = True, model: transformers.models.auto.modeling_auto.AutoModelForCausalLM = None, checkpoint: str = 'EleutherAI/gpt-j-6B', verbalizer: Dict[int, str] = {0: ' Negative.', 1: ' Positive.'}, batch_size: int = 1)

Given a prompt, extract its outputs

Params

args: contains some parameters passed through a namespace (these can be ignored)
prompt: str
    the prompt to use (optional)
prompt_template: str
    template for the prompt; for different prompt styles (e.g. few-shot) you may want to place {prompt} before {example}, or add some text before the verbalizer, e.g. "{example}{prompt} Output:"
cache_key_values: bool
    whether to cache key values (only possible when the prompt does not start with {example})
checkpoint: str
    the underlying model used for prediction
model: AutoModelForCausalLM
    if this is passed, it overrides checkpoint
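
A minimal usage sketch (not from the source; the checkpoint, prompt, and texts are illustrative placeholders, and any small causal LM with its tokenizer should work the same way):

    from transformers import AutoModelForCausalLM, AutoTokenizer
    from imodelsx.treeprompt.stump import PromptStump

    checkpoint = "gpt2"  # placeholder checkpoint for illustration
    model = AutoModelForCausalLM.from_pretrained(checkpoint)
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)

    stump = PromptStump(
        prompt=" Is the sentiment of this review Positive or Negative?",
        prompt_template="{example}{prompt}",
        model=model,
        tokenizer=tokenizer,
        checkpoint=checkpoint,
        verbalizer={0: " Negative.", 1: " Positive."},
        batch_size=4,
    )

    X_text = ["A wonderful, heartfelt film.", "Dull and far too long."]
    probs = stump.predict_proba(X_text)  # shape (n_examples, n_classes); softmax over class logits
    preds = stump.predict(X_text)        # argmax over classes

Note that predict_proba scores each class using only the logit of the first token of its verbalizer string, so the verbalizer strings must differ in their first token (the assertion in predict_proba enforces this).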


Methods

def calc_key_values(self, X_text: List[str])
    Run the model once over the (possibly truncated) prompt and return its past_key_values for reuse across examples; also sets self.max_len_input.

def predict(self, X_text: List[str]) -> numpy.ndarray[int]
    Predict the class index for each input text (argmax of predict_proba).

def predict_proba(self, X_text: List[str]) -> numpy.ndarray[float]
    Return softmax-normalized probabilities over the verbalizer classes, scored from the logit of each class's first verbalizer token.

def predict_proba_with_cache(self, X_text: List[str], past_key_values) -> numpy.ndarray[float]
    Same as predict_proba, but reuses cached past_key_values for the prompt prefix (see calc_key_values).

def predict_with_cache(self, X_text: List[str], past_key_values) -> numpy.ndarray[int]
    Predict the class index for each input text, reusing cached past_key_values for the prompt prefix.
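
Continuing the sketch above, a rough outline of the cached-prefix workflow (not from the source): the args fields and the example-only template below are assumptions chosen so that the shared prompt is encoded exactly once by calc_key_values and each prediction then processes only the short example text; the real treeprompt pipeline may wire these up differently.

    from types import SimpleNamespace

    # Assumed namespace fields; the real pipeline supplies these through its own args object.
    args = SimpleNamespace(
        prompt_source=None,
        template_data_demonstrations="Review: %s Output:%s",  # used here only to budget max_len_input
        dataset_name="sst2",
    )

    stump_cached = PromptStump(
        args=args,
        # a long, shared few-shot prompt that is worth caching
        prompt="Review: I loved it. Output: Positive.\nReview: Terrible. Output: Negative.\n",
        prompt_template="Review: {example} Output:",  # assumption: {prompt} lives only in the cached prefix
        model=model,
        tokenizer=tokenizer,
        checkpoint=checkpoint,
        verbalizer={0: " Negative.", 1: " Positive."},
        batch_size=4,
    )

    # Encode the shared prompt once; this also sets stump_cached.max_len_input.
    past_key_values = stump_cached.calc_key_values(X_text)
    # Reuse the cached prefix for every example.
    preds_cached = stump_cached.predict_with_cache(X_text, past_key_values)

Depending on the installed transformers version, the returned past_key_values may be a legacy tuple of tensors or a Cache object.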