Module imodelsx.iprompt.ipromptx

Source code
from typing import Any, Dict, Iterable, List, Optional, Tuple

import argparse
import collections
import os
import random

import torch
import transformers

from imodelsx.iprompt.autoprompt import AutoPrompt
from imodelsx.iprompt.hotflip import HotFlip
from imodelsx.iprompt.utils import device, PrefixLoss, PrefixModel, PrefixPool


"""
Explaining Patterns in Data with Language Models via Interpretable Autoprompting

Chandan Singh*, John X. Morris*, Jyoti Aneja, Alexander M. Rush, Jianfeng Gao
https://arxiv.org/abs/2210.01848
"""


class iPrompt(AutoPrompt):
    def __init__(
        self,
        loss_func: PrefixLoss,
        model: transformers.PreTrainedModel,
        tokenizer: transformers.PreTrainedTokenizer,
        preprefix_str: str = '',
        prefix_before_input: bool = True,
        pop_criterion: str = 'loss',
        pop_topk_strategy: str = 'different_start_token',
        pop_size: int = 8,
        num_mutations: int = 4,
        num_random_generations: int = 4,
        generation_repetition_penalty: float = 2.0,
        generation_temp: float = 1.0,
        generation_top_p: float = 1.0,
        do_final_reranking: bool = False,
        early_stopping_steps: int = -1,
        num_learned_tokens: int = 1,
        max_length: int = 128,
        verbose: int = 0,
        llm_float16: bool = True,
        generation_checkpoint: str = '',
        n_shots: int = 1,
        single_shot_loss: bool = True,
        llm_candidate_regeneration_prompt_start: str = 'Data:',
        llm_candidate_regeneration_prompt_end: str = 'Prompt:',
    ):
        args = argparse.Namespace()
        args.prefix_before_input = prefix_before_input
        args.num_learned_tokens = num_learned_tokens
        args.hotflip_num_candidates = None
        args.autoprompt_init_strategy = None
        args.save_dir_unique = '.'
        args.n_shots = n_shots
        args.single_shot_loss = single_shot_loss
        args.max_length = max_length
        args.iprompt_do_final_reranking = do_final_reranking
        super().__init__(
            args=args, loss_func=loss_func, model=model, tokenizer=tokenizer, preprefix=''
        )
        self.tokenizer = tokenizer
        self.tokenizer.add_special_tokens = False
        ####################################################################
        # iPrompt-specific parameters
        self._pop_size = pop_size
        self._topk_pop_sample = (self._pop_size + 4)  # sample the next population from this many top candidates; set higher for more randomness
        self._num_mutations_per_ex = num_mutations  # number of mutations for each population item
        self._num_random_generations = num_random_generations  # extra randomly generated candidates to add (these won't be mutated)
        self._generation_temp = generation_temp
        self._generation_top_p = generation_top_p
        self._generation_repetition_penalty = generation_repetition_penalty # 1 means no penalty
        self._pop_initialized = False
        self._generation_bad_words_ids = [
            self.tokenizer.encode('\n'),
            self.tokenizer.encode('\n\n'),
            self.tokenizer.encode('\n\n\n')
        ]
        ####################################################################
        self.conditioning_strategy = '' # This arg is only used for ablations.
        self.other_generation_model = None
        if generation_checkpoint:
            self.other_generation_model = load_lm_from_checkpoint(
                generation_checkpoint, float16=llm_float16
            )
        ####################################################################
        self._prefix_pool = PrefixPool(
            tokenizer=self.tokenizer,
            criterion=pop_criterion,  # one of ['loss', 'acc', 'combined']
            topk_strategy=pop_topk_strategy,
            verbose=verbose,
        )
        # Stuff to track for early stopping
        self._early_stopping_steps = early_stopping_steps
        self._last_population = None
        self._steps_since_new_population = 0
        ####################################################################
        self.prefix_ids = None
        if len(preprefix_str):
            self.preprefix_ids = torch.tensor(
                self.tokenizer.encode(preprefix_str, add_special_tokens=False), dtype=int
            ).to(device)
        else:
            self.preprefix_ids = torch.tensor([], dtype=int).to(device)
        
        prompt_str = preprefix_str.lstrip()
        prompt_str = (' ' + prompt_str) if len(prompt_str) else ''
        self._pre_data_token_ids = self.tokenizer(
           f"{llm_candidate_regeneration_prompt_start}\n\n", return_tensors='pt').input_ids.to(device)
        self._post_data_token_ids = self.tokenizer(
           f"\n\n{llm_candidate_regeneration_prompt_end}" + prompt_str, return_tensors='pt').input_ids.to(device)
        
        self.llm_candidate_regeneration_prompt_start = llm_candidate_regeneration_prompt_start
        self.llm_candidate_regeneration_prompt_end = llm_candidate_regeneration_prompt_end
        ####################################################################
        self._iprompt_verbose = verbose
        self._step = 0
        
    
    def serialize(self, eval_dataloader: torch.utils.data.DataLoader, possible_answer_mask: torch.Tensor) -> Dict[str, Any]:
        r = super().serialize(eval_dataloader=eval_dataloader, possible_answer_mask=possible_answer_mask)
        r["topk_pop_sample"] = self._topk_pop_sample
        r["pop_size"] = self._pop_size
        r["num_mutations_per_ex"] = self._num_mutations_per_ex
        r["num_random_generations"] = self._num_random_generations
        r["generation_temp"] = self._generation_temp
        r["generation_top_p"] = self._generation_top_p
        r["generation_repetition_penalty"] = self._generation_repetition_penalty
        r["generation_bad_words_ids"] = self._generation_bad_words_ids
        r["pre_data_prompt_str"] = self.tokenizer.decode(self._pre_data_token_ids.flatten())
        r["post_data_prompt_str"] = self.tokenizer.decode(self._post_data_token_ids.flatten())
        return r
    
    def _initialize_pop_once(self, full_text_ids: torch.Tensor):
        if self._pop_initialized: return

        while len(self._prefix_pool) < self._pop_size:
            conditional_input_ids = random.choice(full_text_ids)[None]
            num_conditional_tokens = conditional_input_ids.numel()
            input_ids = self._generate(
                input_ids=conditional_input_ids,
                num_conditional_tokens=num_conditional_tokens
            ).squeeze()
            assert input_ids.numel() == self._num_tokens
            self._prefix_pool.initialize_prefix(input_ids)

        self._pop_initialized = True
    
    @property
    def _generation_model(self) -> transformers.AutoModelForCausalLM:
        """Returns the model to use for generation.

        We optionally support using different models for generation and discrimination.
        However, by default, we use the same model for both.
        """
        if self.other_generation_model:
            return self.other_generation_model
        else:
            return self.model
    
    def _generate(self, input_ids: torch.Tensor, num_conditional_tokens: int) -> torch.Tensor:
        """Generates some text using the model and preset hparams.

        If `num_conditional_tokens` > 0, generates extra text because there was an additional
        prefix set.
        """
        attention_mask = ~(input_ids == self.tokenizer.pad_token_id)
        assert attention_mask.shape == input_ids.shape

        if self._is_t5:
            output_length = self._num_tokens + 1 # will add pad token
        else:
            output_length = self._num_tokens + num_conditional_tokens
        
        # print("iPrompt._generate", input_ids.shape, "//", self.tokenizer.decode(input_ids[0]))
        
        g = self._generation_model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            min_length=output_length,
            max_length=output_length,
            temperature=self._generation_temp,
            top_p=self._generation_top_p,
            repetition_penalty=self._generation_repetition_penalty,
            bad_words_ids=self._generation_bad_words_ids,
            do_sample=True
        )
        
        if self._is_t5:
            assert (g[:, 0] == 0).all()
            g = g[:, 1:]
        else:
            # Split off the conditional part; we only want the prefix part, which
            # starts after the conditional part.
            g = g[:, num_conditional_tokens:]

        if self._iprompt_verbose:
            # Print a random one (but remove padded tokens and newlines)
            idx = random.choice(range(len(input_ids)))
            # idx_attention_mask = torch.cat(
            #     (attention_mask[idx], torch.ones(self._num_tokens).to(device)), dim=0
            # ).bool()
            random_sentence_ids = g[idx]
            # print(">>", self.tokenizer.decode(random_sentence_ids).replace('\n', '\\n'))
        
        return g
    
    def _select_pop_topk(self, k: int, min_occurrences: Optional[int] = None) -> List[Tuple[int, ...]]:
        return self._prefix_pool.topk(k=k, min_occurrences=min_occurrences)

    def _track_early_stopping(self):
        """Track changes in population to tell when to stop early."""
        __n_early_stop = 5
        population = set(self._select_pop_topk(k=__n_early_stop, min_occurrences=3))
        if (len(population) == __n_early_stop) and (self._last_population == population):
            self._steps_since_new_population += 1
            if self._iprompt_verbose:
                print("self._steps_since_new_population:", self._steps_since_new_population)
        else:
            self._last_population = population
            self._steps_since_new_population = 0
            if self._iprompt_verbose:
                print("new population:", [self.tokenizer.decode(p) for p in sorted(population)])

    def check_early_stop(self) -> bool:
        """Allow prefix models to stop early."""
        if self._early_stopping_steps == -1:
            return False
        return self._steps_since_new_population >= self._early_stopping_steps
    
    def _get_population_and_random_generations(self, full_text_ids: torch.Tensor) -> torch.Tensor:
        population_pool = self._select_pop_topk(k=self._topk_pop_sample)
        # if self._iprompt_verbose:
            # print("population_pool:", [self.tokenizer.decode(p) for p in population_pool])
        population = random.sample(population_pool, self._pop_size)
        population = torch.tensor(population).to(device)

        if self._num_random_generations > 0:
            random_idxs = torch.randint(
                low=0, high=len(full_text_ids), size=(self._num_random_generations,)
            )
            random_full_text_ids = full_text_ids[random_idxs]
            num_conditional_tokens = full_text_ids.shape[1]
            random_population = self._generate(
                input_ids=random_full_text_ids,
                num_conditional_tokens=num_conditional_tokens
            )

            full_population = torch.cat((population, random_population), dim=0)
        else:
            # Support case where _num_random_generations is set to 0.
            full_population = population

        assert full_population.shape == (
            self._pop_size + self._num_random_generations,
            self._num_tokens
        )
        return full_population
    
    def _mutate(self, population_input_ids: torch.Tensor, full_text_ids: torch.Tensor) -> torch.Tensor:
        """Mutates a population of prefixes.

        Truncates to a random place and then generates new options
        to try.

        Args:
            population_input_ids (int torch.Tensor): input IDs for each prefix in population
            full_text_ids (int torch.Tensor): input IDs for each data item in the batch. Intended
                to be used for prefix generation conditioned on data.
        """
        assert population_input_ids.shape[1] == self._num_tokens
        input_ids = population_input_ids.repeat((self._num_mutations_per_ex, 1))

        self._roll_before_truncation = False
        if self._roll_before_truncation:
            roll_amount = random.randint(0, self._num_tokens-1)
            input_ids = torch.roll(input_ids, roll_amount, dims=[1])

        truncate_position = random.randint(0, self._num_tokens-1)
        truncated_input_ids = input_ids[:, :truncate_position]

        random_idxs = torch.randint(low=0, high=len(full_text_ids), size=(len(input_ids), ))
        random_full_text_ids = full_text_ids[random_idxs]
        conditional_input_ids = torch.cat((random_full_text_ids, truncated_input_ids), dim=1)

        num_conditional_tokens = full_text_ids.shape[1]
        new_input_ids = self._generate(
            input_ids=conditional_input_ids,
            num_conditional_tokens=num_conditional_tokens
        )
        return new_input_ids
    
    def _score_population(
            self, 
            x_tokenized: transformers.BatchEncoding,
            y_tokenized: transformers.BatchEncoding,
            population_input_ids: torch.Tensor,
            possible_answer_mask: torch.Tensor
        ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Scores a population of prefixes and updates `self._genetic_pool`."""
        pop_size = len(population_input_ids)
        all_candidate_losses = torch.zeros(pop_size, dtype=float).to(device)
        all_accuracy = torch.zeros(pop_size, dtype=float).to(device)
        all_candidate_n_correct = torch.zeros(pop_size, dtype=int).to(device)
        for i in range(pop_size):
            with torch.no_grad():
                _cand_input_ids, cand_loss, cand_n_correct = (
                    self._compute_loss_with_set_prefix(
                        original_input_ids=x_tokenized.input_ids,
                        next_token_ids=y_tokenized.input_ids,
                        possible_answer_mask=possible_answer_mask,
                        prefix_ids=population_input_ids[i],
                    )
                )
                cand_accuracy = cand_n_correct / len(x_tokenized.input_ids)
            all_candidate_n_correct[i] += cand_n_correct
            all_candidate_losses[i] += cand_loss
            all_accuracy[i] += cand_accuracy
        
        for i in range(pop_size):
            new_pop_input_ids = tuple(population_input_ids[i].cpu().tolist())
            assert len(new_pop_input_ids) == (self._num_tokens)
            self._prefix_pool.update(
                population_input_ids[i], all_candidate_losses[i], all_accuracy[i]
            )
        return all_candidate_losses, all_candidate_n_correct
    
    def _create_full_text_ids(
        self, full_text_input_ids: torch.Tensor) -> torch.Tensor:
        """Creates input for generating explanation.

        Takes tokenized inputs (like: "Input: 7 8 Output: 15")
        and builds a full sequence that looks like "Data:\n\nInput: 7 8 Output: 15\n\nPrompt:",
        using whatever template is defined by the pre-data and post-data strings.
        """
        B = len(full_text_input_ids)
        pre_data = self._pre_data_token_ids.repeat((B, 1)).to(device)  # Like "Data:\n\n"
        post_data = self._post_data_token_ids.repeat((B, 1)).to(device) # Like "\n\nPrompt:"
        output = torch.cat((pre_data, full_text_input_ids, post_data), dim=1)
        return output

    def compute_loss_and_call_backward(
            self,
            x_tokenized: transformers.BatchEncoding,
            y_tokenized: transformers.BatchEncoding,
            possible_answer_mask: torch.Tensor,
            full_text_tokenized: Optional[transformers.BatchEncoding] = None
        ) -> Tuple[torch.Tensor, int]:
        """Returns the loss from the best example in the population

        Note: does not call loss.backward()
        
        Returns:
            loss (float torch.Tensor) -- the loss
            num_correct (int): number of examples where prediction was correct
        """
        self.model.eval()

        # Allow conditioning on only x or y. This is mainly just used for ablations.
        if self.conditioning_strategy == "x_only":
            full_text_tokenized = x_tokenized
        elif self.conditioning_strategy == "y_only":
            full_text_tokenized = y_tokenized
        elif self.conditioning_strategy == "unconditional":
            full_text_tokenized['input_ids'] = torch.full(
                size=(len(y_tokenized), 1),
                fill_value=self.tokenizer.bos_token_id,
                device=device,
            )
            full_text_tokenized['attention_mask'] = torch.ones_like(
                full_text_tokenized['input_ids']
            )

        # logic here is that we want to see a sample multiple times before
        # we actually have a good estimate of its loss.
        num_min_occurrences = 2

        full_text_ids = self._create_full_text_ids(
            full_text_input_ids=full_text_tokenized.input_ids,
        )
        self._initialize_pop_once(full_text_ids=full_text_ids)

        prefix_save_folder = os.path.join(self.args.save_dir_unique, 'prefix')
        df_to_print = self._prefix_pool.print(topk=10, min_occurrences=num_min_occurrences)
        os.makedirs(prefix_save_folder, exist_ok=True)

        log_prefixes = False
        if log_prefixes:
            prefix_out_file = os.path.join(prefix_save_folder, f'prefix_{self._step}.p')
            df_to_print.to_pickle(prefix_out_file)
            print(f'wrote {len(df_to_print)} prefixes to {prefix_out_file}')

        # Grab new population
        population_input_ids = self._get_population_and_random_generations(
            full_text_ids=full_text_ids,
        )

        if self._num_mutations_per_ex > 0:
            mutated_population_input_ids = self._mutate(
                population_input_ids=population_input_ids, full_text_ids=full_text_ids
            )
            full_population_input_ids = torch.cat(
                (population_input_ids, mutated_population_input_ids), dim=0
            )
        else:
            # Support skipping the mutation step by setting _num_mutations_per_ex to 0
            full_population_input_ids = population_input_ids

        # Re-score the new candidates
        all_candidate_losses, all_candidate_n_correct = self._score_population(
            x_tokenized=x_tokenized,
            y_tokenized=y_tokenized,
            population_input_ids=full_population_input_ids,
            possible_answer_mask=possible_answer_mask
        )

        # Track changes in population to enable early stopping.
        self._track_early_stopping()

        # Reset prefix IDs so that the model can be readily used for eval.
        best_prefix_ids = min(self._prefix_pool._avg_loss, key=self._prefix_pool._avg_loss.get)
        self._set_prefix_ids(torch.tensor(best_prefix_ids).to(device))
        self.prefix_embedding.requires_grad = False

        self._step += 1
        return all_candidate_losses.min(), all_candidate_n_correct.max()
        
    def post_epoch(self, dataloader: torch.utils.data.DataLoader, possible_answer_mask: torch.Tensor) -> None:
        # 
        # Get candidate IDs for every position.
        # 
        pass
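
The candidate-regeneration prompt that _create_full_text_ids builds wraps each data example between the pre-data and post-data token IDs. Below is a minimal string-level sketch of that template, assuming a generic Hugging Face tokenizer; the checkpoint name and example string are illustrative, not part of the module:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # illustrative checkpoint

example_str = "Input: 7 8 Output: 15"  # one rendered data example (illustrative)
pre_data = "Data:\n\n"                 # llm_candidate_regeneration_prompt_start + "\n\n"
post_data = "\n\nPrompt:"              # "\n\n" + llm_candidate_regeneration_prompt_end (+ optional preprefix)

full_text = pre_data + example_str + post_data
full_text_ids = tokenizer(full_text, return_tensors="pt").input_ids
# The LLM is asked to continue this text; its continuation (num_learned_tokens tokens)
# is kept as a candidate prompt and later scored against the data.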

Classes

class iPrompt (loss_func: PrefixLoss, model: transformers.modeling_utils.PreTrainedModel, tokenizer: transformers.tokenization_utils.PreTrainedTokenizer, preprefix_str: str = '', prefix_before_input: bool = True, pop_criterion: str = 'loss', pop_topk_strategy: str = 'different_start_token', pop_size: int = 8, num_mutations: int = 4, num_random_generations: int = 4, generation_repetition_penalty: float = 2.0, generation_temp: float = 1.0, generation_top_p: float = 1.0, do_final_reranking: bool = False, early_stopping_steps: int = -1, num_learned_tokens: int = 1, max_length: int = 128, verbose: int = 0, llm_float16: bool = True, generation_checkpoint: str = '', n_shots: int = 1, single_shot_loss: bool = True, llm_candidate_regeneration_prompt_start: str = 'Data:', llm_candidate_regeneration_prompt_end: str = 'Prompt:')

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes:

import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will have their parameters converted too when you call to(), etc.

Note

As per the example above, an __init__() call to the parent class must be made before assignment on the child.

training (bool): whether this module is in training or evaluation mode.

Initialize internal Module state, shared by both nn.Module and ScriptModule.
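
A quick usage sketch. The iPrompt arguments below follow the signature above; the GPT-2 checkpoint and the PrefixLoss arguments are assumptions for illustration (check imodelsx.iprompt.utils.PrefixLoss for its actual constructor):

import transformers

from imodelsx.iprompt.ipromptx import iPrompt
from imodelsx.iprompt.utils import PrefixLoss

checkpoint = "gpt2"  # assumed checkpoint, for illustration only
tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token by default
model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)

loss_func = PrefixLoss(gamma=0.0, tokenizer=tokenizer)  # assumed signature

prefix_model = iPrompt(
    loss_func=loss_func,
    model=model,
    tokenizer=tokenizer,
    pop_size=8,
    num_mutations=4,
    num_random_generations=4,
    num_learned_tokens=16,
    max_length=128,
    verbose=1,
)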


Ancestors

imodelsx.iprompt.autoprompt.AutoPrompt

Methods

def compute_loss_and_call_backward(self, x_tokenized: transformers.tokenization_utils_base.BatchEncoding, y_tokenized: transformers.tokenization_utils_base.BatchEncoding, possible_answer_mask: torch.Tensor, full_text_tokenized: Optional[transformers.tokenization_utils_base.BatchEncoding] = None) ‑> Tuple[torch.Tensor, int]

Returns the loss from the best example in the population

Note: does not call loss.backward()

Returns

loss (float torch.Tensor): the loss
num_correct (int): number of examples where prediction was correct
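
A minimal sketch of calling this method on one batch, assuming prefix_model and tokenizer were built as in the usage sketch above; the batch strings and the all-ones answer mask are illustrative assumptions:

import torch

from imodelsx.iprompt.utils import device

xs = ["Input: 7 8 Output:", "Input: 2 5 Output:"]  # illustrative inputs
ys = [" 15", " 7"]                                  # illustrative outputs

x_tokenized = tokenizer(xs, return_tensors="pt", padding=True).to(device)
y_tokenized = tokenizer(ys, return_tensors="pt", padding=True).to(device)
full_text_tokenized = tokenizer(
    [x + y for x, y in zip(xs, ys)], return_tensors="pt", padding=True
).to(device)

# Vocabulary-sized boolean mask allowing every token as an answer (shape is an assumption).
possible_answer_mask = torch.ones(len(tokenizer), dtype=torch.bool, device=device)

loss, n_correct = prefix_model.compute_loss_and_call_backward(
    x_tokenized=x_tokenized,
    y_tokenized=y_tokenized,
    possible_answer_mask=possible_answer_mask,
    full_text_tokenized=full_text_tokenized,
)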

def post_epoch(self, dataloader: torch.utils.data.dataloader.DataLoader, possible_answer_mask: torch.Tensor) ‑> None

Inherited members