Module imodelsx.treeprompt.stump

Classes

class PromptStump (args=None,
prompt: str = None,
tokenizer=None,
prompt_template: str = '{example}{prompt}',
cache_key_values: bool = False,
verbose: bool = True,
model: transformers.models.auto.modeling_auto.AutoModelForCausalLM = None,
checkpoint: str = 'EleutherAI/gpt-j-6B',
verbalizer: Dict[int, str] = {0: ' Negative.', 1: ' Positive.'},
batch_size: int = 1)
Expand source code
class PromptStump:
    def __init__(
        self,
        args=None,
        prompt: str = None,
        tokenizer=None,
        prompt_template: str = "{example}{prompt}",
        cache_key_values: bool = False,
        verbose: bool = True,
        model: "AutoModelForCausalLM" = None,
        checkpoint: str = "EleutherAI/gpt-j-6B",
        verbalizer: Dict[int, str] = None,
        batch_size: int = 1,
    ):
        """Given a prompt, extract its outputs

        Params
        ------
        args: contains some parameters passed through namespace (can ignore these)
            When omitted, a placeholder object is used whose ``prompt_source`` and
            ``template_data_demonstrations`` are None and whose ``dataset_name`` is "".
        prompt: str
            the prompt to use (optional)
        tokenizer
            tokenizer to use; when None, one is loaded for ``checkpoint`` through
            ``imodelsx.llm.load_tokenizer``
        prompt_template: str
            template for the prompt, for different prompt styles (e.g. few-shot), may want to place {prompt} before {example}
            or you may want to add some text before the verbalizer, e.g. {example}{prompt} Output:
        cache_key_values: bool
            Whether to cache key values (only possible when prompt does not start with {example})
        verbose: bool
            whether to emit the informational log line at the end of construction
        checkpoint: str
            the underlying model used for prediction
        model: AutoModelForCausalLM
            if this is passed, will override checkpoint
            NOTE(review): this constructor only stores ``model``; it never loads one
            from ``checkpoint``, so every method that runs the model needs a non-None
            ``model`` -- confirm that callers always pass it.
        verbalizer: Dict[int, str]
            maps each class to the string whose first token is scored for it;
            defaults to ``{0: " Negative.", 1: " Positive."}``
        batch_size: int
            number of examples sent through the model per forward pass
        """
        if args is None:

            class placeholder:
                prompt_source = None
                template_data_demonstrations = None
                dataset_name = ""

            args = placeholder()
        self.args = args
        self.prompt = prompt
        self.prompt_template = prompt_template
        self.cache_key_values = cache_key_values
        self.verbose = verbose
        self.checkpoint = checkpoint
        self.model = model
        if tokenizer is None:
            self.tokenizer = imodelsx.llm.load_tokenizer(checkpoint)
        else:
            self.tokenizer = tokenizer
        self.batch_size = batch_size
        # The default is built per instance: a dict literal in the signature would be
        # one shared object, so mutating one stump's verbalizer would leak into others.
        if verbalizer is None:
            verbalizer = {0: " Negative.", 1: " Positive."}
        self.verbalizer = verbalizer

        if self.verbose:
            logging.info(f"Loading model {self.checkpoint}")

    def predict(self, X_text: List[str]) -> np.ndarray[int]:
        """Return, for each text, the column index with the highest probability.

        Columns follow the iteration order of ``self.verbalizer.values()``.
        """
        preds_proba = self.predict_proba(X_text)
        return np.argmax(preds_proba, axis=1)

    def predict_with_cache(self, X_text: List[str], past_key_values) -> np.ndarray[int]:
        """Same as ``predict`` but scores through ``predict_proba_with_cache``."""
        preds_proba = self.predict_proba_with_cache(X_text, past_key_values)
        return np.argmax(preds_proba, axis=1)

    def predict_proba(self, X_text: List[str]) -> np.ndarray[float]:
        """Return softmax-normalised scores of the verbalizer tokens for each text.

        Each text is inserted into ``prompt_template`` together with ``self.prompt``;
        the result has one row per text and one column per verbalizer entry.
        """
        target_token_ids = self._get_target_token_ids()
        text_inputs = self._format_inputs(X_text)
        preds = self._get_logit_for_target_tokens_batched(
            text_inputs,
            target_token_ids,
            batch_size=self.batch_size,
        )
        return self._logits_to_proba(preds, len(X_text), len(target_token_ids))

    def predict_proba_with_cache(
        self, X_text: List[str], past_key_values
    ) -> np.ndarray[float]:
        """Like ``predict_proba`` but forwards ``past_key_values`` to the model.

        Relies on ``self.max_len_input``, which is only assigned by
        ``calc_key_values``; that method has to run first.
        """
        target_token_ids = self._get_target_token_ids()
        text_inputs = self._format_inputs(X_text)
        preds = self._get_logit_for_target_tokens_batched_with_cache(
            past_key_values,
            text_inputs,
            target_token_ids,
            batch_size=self.batch_size,
        )
        return self._logits_to_proba(preds, len(X_text), len(target_token_ids))

    def _get_target_token_ids(self) -> List[int]:
        """Return the first token id of every verbalizer string, checked for uniqueness."""
        target_strs = list(self.verbalizer.values())

        # only predict based on first token of output string
        target_token_ids = [self._get_first_token_id(s) for s in target_strs]
        assert len(set(target_token_ids)) == len(
            set(target_strs)
        ), f"error: target_token_ids {set(target_token_ids)} not unique to target strings {set(target_strs)}"
        return target_token_ids

    def _format_inputs(self, X_text: List[str]) -> List[str]:
        """Fill ``prompt_template`` with each example and the stored prompt."""
        return [
            self.prompt_template.format(example=x, prompt=self.prompt) for x in X_text
        ]

    @staticmethod
    def _logits_to_proba(preds, n_examples: int, n_targets: int):
        """Check that ``preds`` is (n_examples, n_targets) and softmax each row."""
        assert preds.shape == (n_examples, n_targets), (
            "preds shape was"
            + str(preds.shape)
            + " but should have been "
            + str((n_examples, n_targets))
        )

        # normalise the target-token logits of each example into probabilities
        return softmax(preds, axis=1)

    def _get_max_total_len(self) -> int:
        """Return the model's position limit from its config.

        Reads ``n_positions`` and falls back to ``max_position_embeddings`` when the
        config has no such attribute.
        """
        config = self.model.config
        try:
            return config.n_positions
        except AttributeError:
            return config.max_position_embeddings

    def calc_key_values(self, X_text: List[str]):
        """Run the model on ``self.prompt`` alone and return its ``past_key_values``.

        Side effects: sets ``truncation_side``, ``padding`` and ``pad_token`` on the
        tokenizer, prints the computed length budget, and stores
        ``self.max_len_input`` (read later by the cached prediction path).

        ``max_len_input`` is the 95th percentile, rounded up, of the tokenized length
        of ``args.template_data_demonstrations % (example, verbalizer_string)`` over
        the first 1000 examples, maximised over verbalizer strings. The prompt is
        truncated to the model's position limit minus that value.

        Raises
        ------
        ValueError
            if ``X_text`` is empty
        AssertionError
            if no room is left for the prompt
        """
        if len(X_text) == 0:
            raise ValueError("X_text must contain at least one example")

        self.tokenizer.truncation_side = "left"
        self.tokenizer.padding = True
        self.tokenizer.pad_token = self.tokenizer.eos_token

        template = self.args.template_data_demonstrations
        max_total_len = self._get_max_total_len()

        # The percentile budget has been applied unconditionally (the original guard
        # was hard-wired to True), so the exact-maximum pass that used to precede it
        # only tokenized everything a second time and is no longer performed.
        print("max len prompt less than 0, truncating to the left")
        max_len_input = -1
        for v in self.verbalizer.values():
            lengths = [
                len(self.tokenizer.encode(template % (s, v))) for s in X_text[:1000]
            ]
            max_len_input = max(max_len_input, np.percentile(lengths, 95))
        max_len_input = int(math.ceil(max_len_input))
        max_len_prompt = max_total_len - max_len_input
        self.max_len_input = max_len_input
        print(
            f"max_len_prompt: {max_len_prompt}, max_len_input: {max_len_input}")
        assert max_len_prompt > 0
        inputs = self.tokenizer(
            [self.prompt],
            return_tensors="pt",
            padding=False,
            truncation=True,
            max_length=max_len_prompt,
            return_attention_mask=True,
        ).to(self.model.device)

        with torch.no_grad():
            outputs = self.model(**inputs)
        return outputs["past_key_values"]

    def _get_logit_for_target_tokens_batched(
        self, prompts: List[str], target_token_ids: List[int], batch_size: int = 64
    ) -> np.ndarray[float]:
        """Get logits for each target token

        This can fail when token_output_ids represents multiple tokens
        So things get mapped to the same id representing "unknown"

        Returns an array of shape (len(prompts), len(target_token_ids)), also when
        ``prompts`` is empty.
        """
        max_total_len = self._get_max_total_len()

        # Tokenizer settings are the same for every batch, so set them once.
        self.tokenizer.padding = True
        self.tokenizer.truncation_side = "left"
        self.tokenizer.pad_token = self.tokenizer.eos_token

        logit_targets_list = []
        # The context manager closes the progress bar even if a batch raises.
        with tqdm.tqdm(
            total=len(prompts),
            leave=False,
            desc="getting dataset predictions for top prompt",
            colour="red",
        ) as pbar:
            for batch_start in range(0, len(prompts), batch_size):
                prompts_batch = prompts[batch_start:batch_start + batch_size]
                inputs = self.tokenizer(
                    prompts_batch,
                    return_tensors="pt",
                    padding=True,
                    truncation=True,
                    return_attention_mask=True,
                    max_length=max_total_len,
                ).to(self.model.device)

                # shape is (batch_size, seq_len, vocab_size)
                with torch.no_grad():
                    logits = self.model(**inputs)["logits"]

                # NOTE(review): (number of attended tokens - 1) is the index of the
                # last real token only when padding is appended on the right --
                # confirm the tokenizer's padding side.
                token_output_positions = inputs["attention_mask"].sum(axis=1)
                for i in range(len(prompts_batch)):
                    token_output_position = token_output_positions[i].item() - 1
                    logit_targets_list.append(
                        [
                            logits[i, token_output_position, token_output_id].item()
                            for token_output_id in target_token_ids
                        ]
                    )
                # Advance by the real batch size so the bar never overshoots.
                pbar.update(len(prompts_batch))

        # Explicit reshape keeps the result two-dimensional for empty input.
        return np.array(logit_targets_list).reshape(
            len(prompts), len(target_token_ids)
        )

    def _get_logit_for_target_tokens_batched_with_cache(
        self,
        past_key_values,
        prompts: List[str],
        target_token_ids: List[int],
        batch_size: int = 64,
    ) -> np.ndarray[float]:
        """Get logits for each target token

        This can fail when token_output_ids represents multiple tokens
        So things get mapped to the same id representing "unknown"

        ``past_key_values`` is indexed as ``past_key_values[layer][0 or 1]`` and each
        tensor is expanded along its first dimension to the batch size. Reads
        ``self.max_len_input``, which is assigned by ``calc_key_values``.

        Returns an array of shape (len(prompts), len(target_token_ids)), also when
        ``prompts`` is empty.
        """

        def expand_cache(n_rows):
            # Broadcast every cached tensor to n_rows rows.
            return [
                [
                    past_key_values[i][0].expand(n_rows, -1, -1, -1),
                    past_key_values[i][1].expand(n_rows, -1, -1, -1),
                ]
                for i in range(len(past_key_values))
            ]

        past_key_values_new = expand_cache(batch_size)
        # Length of the cached prefix, read from the second-to-last dimension.
        prefix_len = past_key_values[0][0].shape[-2]

        # Tokenizer settings are the same for every batch, so set them once.
        self.tokenizer.padding = True
        self.tokenizer.pad_token = self.tokenizer.eos_token

        logit_targets_list = []
        # The context manager closes the progress bar even if a batch raises.
        with tqdm.tqdm(
            total=len(prompts), leave=False, desc="getting predictions", colour="red"
        ) as pbar:
            for batch_start in range(0, len(prompts), batch_size):
                prompts_batch = prompts[batch_start:batch_start + batch_size]
                n_rows = len(prompts_batch)
                # A short (typically final) batch needs a cache with matching rows.
                if n_rows != past_key_values_new[0][0].shape[0]:
                    past_key_values_new = expand_cache(n_rows)

                inputs = self.tokenizer(
                    prompts_batch,
                    return_tensors="pt",
                    padding=True,
                    truncation=True,
                    max_length=self.max_len_input,
                    return_attention_mask=True,
                ).to(self.model.device)

                # Prepend an all-ones mask covering the cached prefix positions.
                example_mask = inputs["attention_mask"]
                prefix_mask = example_mask.new_zeros(n_rows, prefix_len).fill_(1)
                inputs["attention_mask"] = torch.cat(
                    (prefix_mask, example_mask), dim=-1
                )

                # shape is (batch_size, seq_len, vocab_size)
                with torch.no_grad():
                    outputs = self.model(
                        **inputs, past_key_values=past_key_values_new)
                logits = outputs["logits"]

                # Positions are counted within the new tokens only, so the count
                # comes from the mask without the prefix part.
                # NOTE(review): as in the uncached path, this is the last real token
                # only with right-side padding -- confirm the tokenizer's padding side.
                token_output_positions = example_mask.sum(axis=1)
                for i in range(n_rows):
                    token_output_position = token_output_positions[i].item() - 1
                    logit_targets_list.append(
                        [
                            logits[i, token_output_position, token_output_id].item()
                            for token_output_id in target_token_ids
                        ]
                    )
                # Advance by the real batch size so the bar never overshoots.
                pbar.update(n_rows)

        # Explicit reshape keeps the result two-dimensional for empty input.
        return np.array(logit_targets_list).reshape(
            len(prompts), len(target_token_ids)
        )

    def _get_first_token_id(self, prompt: str) -> int:
        """Get first token id in prompt (after special tokens).

        Need to strip leading whitespace for LLAMA so we don't get a special space token at the beginning.

        Raises IndexError when the tokenized prompt contains nothing but special tokens.
        """
        if "llama" in self.checkpoint.lower():
            prompt = prompt.lstrip()

        # Read the special ids once instead of once per token.
        special_ids = set(self.tokenizer.all_special_ids)
        tokens = self.tokenizer(prompt)["input_ids"]
        tokens = [t for t in tokens if t not in special_ids]
        return tokens[0]

Given a prompt, extract its outputs

Params

args: contains some parameters passed through a namespace (these can be ignored).
prompt (str): the prompt to use (optional).
prompt_template (str): template for the prompt; for different prompt styles (e.g. few-shot) you may want to place {prompt} before {example}, or you may want to add some text before the verbalizer, e.g. "{example}{prompt} Output:".
cache_key_values (bool): whether to cache key values (only possible when the prompt does not start with {example}).
checkpoint (str): the underlying model used for prediction.
model (AutoModelForCausalLM): if this is passed, it will override checkpoint.

Methods

def calc_key_values(self, X_text: List[str])
Expand source code
def calc_key_values(self, X_text: List[str]):
    """Run the model on ``self.prompt`` alone and return its ``past_key_values``.

    Side effects: sets ``truncation_side``, ``padding`` and ``pad_token`` on the
    tokenizer, prints the computed length budget, and stores ``self.max_len_input``.

    ``max_len_input`` is the 95th percentile, rounded up, of the tokenized length of
    ``args.template_data_demonstrations % (example, verbalizer_string)`` over the
    first 1000 examples, maximised over verbalizer strings. The prompt is truncated
    to the model's position limit minus that value.

    Raises
    ------
    ValueError
        if ``X_text`` is empty
    AssertionError
        if no room is left for the prompt
    """
    if len(X_text) == 0:
        raise ValueError("X_text must contain at least one example")

    self.tokenizer.truncation_side = "left"
    self.tokenizer.padding = True
    self.tokenizer.pad_token = self.tokenizer.eos_token

    template = self.args.template_data_demonstrations

    # Configs expose the position limit under one of two names; fall back only
    # when the first attribute is missing instead of swallowing every exception.
    try:
        max_total_len = self.model.config.n_positions
    except AttributeError:
        max_total_len = self.model.config.max_position_embeddings

    # The percentile budget has been applied unconditionally (the original guard was
    # hard-wired to True), so the exact-maximum pass that used to precede it only
    # tokenized everything a second time and is no longer performed.
    print("max len prompt less than 0, truncating to the left")
    max_len_input = -1
    for v in self.verbalizer.values():
        lengths = [
            len(self.tokenizer.encode(template % (s, v))) for s in X_text[:1000]
        ]
        max_len_input = max(max_len_input, np.percentile(lengths, 95))
    max_len_input = int(math.ceil(max_len_input))
    max_len_prompt = max_total_len - max_len_input
    self.max_len_input = max_len_input
    print(
        f"max_len_prompt: {max_len_prompt}, max_len_input: {max_len_input}")
    assert max_len_prompt > 0
    inputs = self.tokenizer(
        [self.prompt],
        return_tensors="pt",
        padding=False,
        truncation=True,
        max_length=max_len_prompt,
        return_attention_mask=True,
    ).to(self.model.device)

    with torch.no_grad():
        outputs = self.model(**inputs)
    return outputs["past_key_values"]
def predict(self, X_text: List[str]) ‑> numpy.ndarray[int]
Expand source code
def predict(self, X_text: List[str]) -> np.ndarray[int]:
    """Return, for each text, the column index with the highest probability.

    The probabilities come from ``self.predict_proba``; the argmax is taken
    along axis 1.
    """
    return np.argmax(self.predict_proba(X_text), axis=1)
def predict_proba(self, X_text: List[str]) ‑> numpy.ndarray[float]
Expand source code
def predict_proba(self, X_text: List[str]) -> np.ndarray[float]:
    """Return softmax-normalised scores of the verbalizer tokens for each text.

    Each text is inserted into ``self.prompt_template`` together with
    ``self.prompt``; the logits come from
    ``self._get_logit_for_target_tokens_batched``. The result has one row per
    text and one column per verbalizer entry.
    """
    verbalizer_strs = list(self.verbalizer.values())

    # Each class is scored through the first token of its verbalizer string,
    # so those first tokens must be distinct.
    first_token_ids = [self._get_first_token_id(s) for s in verbalizer_strs]
    assert len(set(first_token_ids)) == len(
        set(verbalizer_strs)
    ), f"error: target_token_ids {set(first_token_ids)} not unique to target strings {set(verbalizer_strs)}"

    filled_prompts = []
    for example in X_text:
        filled_prompts.append(
            self.prompt_template.format(example=example, prompt=self.prompt)
        )

    logits = self._get_logit_for_target_tokens_batched(
        filled_prompts,
        first_token_ids,
        batch_size=self.batch_size,
    )
    expected_shape = (len(X_text), len(first_token_ids))
    assert logits.shape == expected_shape, (
        "preds shape was"
        + str(logits.shape)
        + " but should have been "
        + str(expected_shape)
    )

    # Normalise the logits of each row into probabilities.
    return softmax(logits, axis=1)
def predict_proba_with_cache(self, X_text: List[str], past_key_values) ‑> numpy.ndarray[float]
Expand source code
def predict_proba_with_cache(
    self, X_text: List[str], past_key_values
) -> np.ndarray[float]:
    """Return softmax-normalised verbalizer-token scores, reusing cached key values.

    Works like ``predict_proba`` except that ``past_key_values`` is handed to
    ``self._get_logit_for_target_tokens_batched_with_cache`` ahead of the
    filled-in prompts. The result has one row per text and one column per
    verbalizer entry.
    """
    verbalizer_strs = list(self.verbalizer.values())

    # Each class is scored through the first token of its verbalizer string,
    # so those first tokens must be distinct.
    first_token_ids = [self._get_first_token_id(s) for s in verbalizer_strs]
    assert len(set(first_token_ids)) == len(
        set(verbalizer_strs)
    ), f"error: target_token_ids {set(first_token_ids)} not unique to target strings {set(verbalizer_strs)}"

    filled_prompts = []
    for example in X_text:
        filled_prompts.append(
            self.prompt_template.format(example=example, prompt=self.prompt)
        )

    logits = self._get_logit_for_target_tokens_batched_with_cache(
        past_key_values,
        filled_prompts,
        first_token_ids,
        batch_size=self.batch_size,
    )

    expected_shape = (len(X_text), len(first_token_ids))
    assert logits.shape == expected_shape, (
        "preds shape was"
        + str(logits.shape)
        + " but should have been "
        + str(expected_shape)
    )

    # Normalise the logits of each row into probabilities.
    return softmax(logits, axis=1)
def predict_with_cache(self, X_text: List[str], past_key_values) ‑> numpy.ndarray[int]
Expand source code
def predict_with_cache(self, X_text: List[str], past_key_values) -> np.ndarray[int]:
    """Return, for each text, the column index with the highest cached-path probability.

    The probabilities come from ``self.predict_proba_with_cache``; the argmax is
    taken along axis 1.
    """
    return np.argmax(
        self.predict_proba_with_cache(X_text, past_key_values), axis=1
    )