Module imodelsx.iprompt.autoprompt

Expand source code
from typing import Any, Dict, List, Optional, Tuple

import argparse
import functools
import os
import pickle
import random

import pandas as pd
import torch
import tqdm
import transformers

from .hotflip import HotFlip
from .utils import device, PrefixLoss, PrefixModel, PrefixPool


class AutoPrompt(HotFlip):
    """Prefix search that tries single-token swaps ranked by prefix-token gradient.

    Each call to `compute_loss_and_call_backward` scores the current prefix,
    proposes replacement tokens for one prefix position, scores every
    proposal, and keeps whichever prefix has the lowest loss. Every scored
    prefix is recorded in a `PrefixPool`, which `serialize` later ranks.
    """
    args: argparse.Namespace
    loss_func: PrefixLoss
    model: transformers.PreTrainedModel
    tokenizer: transformers.PreTrainedTokenizer
    prefix_ids: torch.Tensor
    prefix_embedding: torch.nn.Parameter
    preprefix: str

    def __init__(
        self,
        args: argparse.Namespace,
        loss_func: PrefixLoss,
        model: transformers.PreTrainedModel,
        tokenizer: transformers.PreTrainedTokenizer,
        preprefix: str = ''
    ):
        """Forward the shared arguments to `HotFlip` and set AutoPrompt-specific state.

        Args:
            args: run configuration; `args.iprompt_do_final_reranking` is read here.
            loss_func: passed through to the parent class.
            model: passed through to the parent class.
            tokenizer: passed through to the parent class; also given to the prefix pool.
            preprefix: passed through to the parent class.
        """
        super().__init__(
            args=args, loss_func=loss_func, model=model, tokenizer=tokenizer, preprefix=preprefix
        )
        # Whether `serialize` re-scores prefixes on the eval data before ranking.
        self._do_final_reranking = args.iprompt_do_final_reranking
        # AutoPrompt-specific parameters.
        self._num_candidates_per_prefix_token = 32  # V_cand in autoprompt paper
        # This helps us know which were the best prefixes to return over time
        self._prefix_pool = PrefixPool(
            tokenizer=self.tokenizer,
            criterion='loss'  # in ['loss', 'acc', 'combined']
        )
        self._autoprompt_verbose = True
        # NOTE(review): not read anywhere in this class; `serialize` passes a
        # literal min_occurrences=3 instead.
        self._num_min_occurrences = 1
        # Will rank and save this many prefixes at the end of training.
        self._num_prefixes_to_test = 64

    def test_prefixes(
        self,
        prefixes: List[Tuple[int]],
        eval_dataloader: torch.utils.data.DataLoader,
        possible_answer_mask: torch.Tensor) -> Tuple[List[float], List[float]]:
        """Computes loss & accuracy for each prefix on data in dataloader. Used to rank
        prefixes at the end of training.

        Args:
            prefixes: token-id sequences to score; each is converted to a tensor.
            eval_dataloader: yields batches accepted by `self.prepare_batch`.
            possible_answer_mask: forwarded unchanged to the loss computation.

        Returns:
            (losses, accuracies): two Python lists aligned with `prefixes`.
            Each loss is the sum of the per-batch loss values (not divided by
            the number of batches); each accuracy is the number of correct
            examples divided by the number of examples seen. With an empty
            dataloader the accuracy division is by zero.
        """
        all_candidate_losses = torch.zeros(len(prefixes), dtype=torch.float32)
        all_candidate_n_correct = torch.zeros(
            len(prefixes), dtype=torch.float32)
        total_n = 0

        # Loop-invariant setup: the tokenizer call and the prefix tensors do
        # not depend on the batch, so build them once.
        tok = functools.partial(
            self.tokenizer, return_tensors='pt', padding='longest',
            truncation=True, max_length=self.args.max_length  # TODO set max_length on self
        )
        prefix_tensors = [torch.tensor(p).to(device) for p in prefixes]

        for batch in tqdm.tqdm(eval_dataloader, desc=f'evaluating {len(prefixes)} prefixes'):
            if (self.args.n_shots > 1) and (self.args.single_shot_loss):
                # Score only the final shot's input in this configuration.
                batch['input'] = batch['last_input']
            x_text, y_text = self.prepare_batch(batch=batch)
            x_tokenized = tok(x_text).to(device)
            y_tokenized = tok(y_text).to(device)
            total_n += len(x_tokenized.input_ids)

            next_token_ids = y_tokenized.input_ids
            for i, prefix_tensor in enumerate(prefix_tensors):
                with torch.no_grad():
                    _cand_input_ids, cand_loss, cand_n_correct = (
                        self._compute_loss_with_set_prefix(
                            original_input_ids=x_tokenized.input_ids,
                            next_token_ids=next_token_ids,
                            possible_answer_mask=possible_answer_mask,
                            prefix_ids=prefix_tensor,
                        )
                    )
                all_candidate_losses[i] += cand_loss.item()
                all_candidate_n_correct[i] += cand_n_correct.item()
        return all_candidate_losses.cpu().tolist(), (all_candidate_n_correct / total_n).cpu().tolist()

    def serialize(self, eval_dataloader: torch.utils.data.DataLoader, possible_answer_mask: torch.Tensor) -> Dict[str, Any]:
        """Rank the prefixes collected in the pool and return them as plain lists.

        Nothing is written to disk by the active code; the pool dump below is
        commented out.

        With `self._do_final_reranking` set, up to `self._num_prefixes_to_test`
        pool prefixes are re-scored with `test_prefixes` and ordered by
        accuracy (descending), ties broken by loss (ascending). Otherwise
        every prefix in the pool is reported with the pool's stored average
        loss and accuracy (-1 where missing), ordered by accuracy (descending).

        Returns:
            dict of parallel lists under the keys "prefix_ids", "prefixes"
            (decoded strings), "prefix_train_acc", "prefix_train_loss" and
            "prefix_n_queries" (number of losses recorded per prefix).
        """

        # Uncomment following lines to save all the prefixes we tested.
        # save_dir = self.args.save_dir_unique
        # os.makedirs(save_dir, exist_ok=True)
        # pickle.dump(self._prefix_pool, open(os.path.join(save_dir, 'prefix_pool.p'), 'wb'))

        if self._do_final_reranking:
            all_prefixes = self._prefix_pool.topk_all(
                k=self._num_prefixes_to_test, min_occurrences=3)

            if len(all_prefixes) == 0:
                # In the case where we get no prefixes here (i.e. prompt generation
                # only ran for a single step) just take anything from prefix pool.
                # Sample WITHOUT replacement so no prefix is evaluated or
                # reported twice, and so an empty pool yields an empty result
                # instead of raising.
                pool_prefixes = list(self._prefix_pool.prefixes)
                all_prefixes = random.sample(
                    pool_prefixes,
                    min(self._num_prefixes_to_test, len(pool_prefixes)))

            all_losses, all_accuracies = self.test_prefixes(
                prefixes=all_prefixes,
                eval_dataloader=eval_dataloader,
                possible_answer_mask=possible_answer_mask
            )
            sort_columns, sort_ascending = ['accuracy', 'loss'], [False, True]
        else:
            all_prefixes = list(self._prefix_pool.prefixes)
            all_losses = [self._prefix_pool._avg_loss.get(p, -1) for p in all_prefixes]
            all_accuracies = [self._prefix_pool._avg_accuracy.get(p, -1) for p in all_prefixes]
            sort_columns, sort_ascending = ['accuracy'], [False]

        df = pd.DataFrame(
            list(zip(all_prefixes, all_losses, all_accuracies)),
            columns=['prefix', 'loss', 'accuracy']
        )
        # Sort exactly once: a second accuracy-only sort could reshuffle rows
        # that the (accuracy, loss) ordering had already tie-broken.
        df = df.sort_values(
            by=sort_columns, ascending=sort_ascending).reset_index(drop=True)

        df['prefix_str'] = df['prefix'].map(self.tokenizer.decode)
        df['n_queries'] = df['prefix'].map(
            lambda p_ids: len(self._prefix_pool._all_losses[p_ids]))

        print('Final prefixes')
        print(df.head())

        return {
            "prefix_ids": df['prefix'].tolist(),
            "prefixes": df['prefix_str'].tolist(),
            "prefix_train_acc": df['accuracy'].tolist(),
            "prefix_train_loss": df['loss'].tolist(),
            "prefix_n_queries": df['n_queries'].tolist(),
        }

    def compute_loss_and_call_backward(
        self,
        x_tokenized: transformers.BatchEncoding,
        y_tokenized: transformers.BatchEncoding,
        possible_answer_mask: torch.Tensor,
        full_text_tokenized: Optional[transformers.BatchEncoding] = None
    ) -> Tuple[torch.Tensor, int]:
        """Computes loss using `self.loss_func`.

        Scores the current prefix, calls backward on that loss, then tries
        replacing the token at `self._swap_token_idx` with each of the
        `self._num_candidates_per_prefix_token` tokens whose gradient entry at
        that position is smallest. The lowest-loss prefix (current or
        candidate) becomes the new prefix, and a new swap position is drawn at
        random for the next call. `full_text_tokenized` is accepted but unused.

        Returns:
            loss (float torch.Tensor) -- the loss
            num_correct (int): number of examples where prediction was correct
        """
        original_input_ids = x_tokenized.input_ids
        next_token_ids = y_tokenized.input_ids  # only compute loss over next token
        batch_size = len(original_input_ids)

        # prefix_ids=None: presumably the inherited helper then scores the
        # current prefix -- confirm against HotFlip.
        current_input_ids, current_loss, current_n_correct = self._compute_loss_with_set_prefix(
            original_input_ids=original_input_ids,
            next_token_ids=next_token_ids,
            possible_answer_mask=possible_answer_mask,
            prefix_ids=None,
        )
        current_loss.backward()

        # Was written as `self._autoprompt_verbose: print(...)`, which Python
        # parses as an annotation and never executes.
        if self._autoprompt_verbose:
            print(
                f'** {self.tokenizer.decode(self.prefix_ids)}: {current_loss:.2f}')

        # track running accuracy of this prefix.
        self._prefix_pool.update(
            prefix=self.prefix_ids,
            loss=current_loss,
            accuracy=(current_n_correct/batch_size)
        )

        # print an update.
        self._prefix_pool.print(topk=10, min_occurrences=1)

        #
        # Get top token replacements
        #
        token_grads = self._prefix_token_grad
        if self._is_t5:
            # t5 has extra vocab tokens for no reason:
            # https://github.com/huggingface/transformers/issues/4875#issuecomment-647634437
            assert token_grads.shape == (
                self._num_tokens, len(self.tokenizer.vocab) + 28
            )
            token_grads = token_grads[:, :-28]
        assert token_grads.shape == (
            self._num_tokens, len(self.tokenizer.vocab))
        top_tokens_per_position = (
            token_grads.topk(
                k=self._num_candidates_per_prefix_token, dim=1, largest=False).indices
        )
        assert top_tokens_per_position.shape == (
            self._num_tokens, self._num_candidates_per_prefix_token)

        # Candidates must come from the row of the position being swapped.
        # The previous code overwrote this with
        # `token_grads.argsort().flatten()[:k]`, which is always row 0.
        top_swap_tokens = top_tokens_per_position[self._swap_token_idx, :]

        # rank candidates
        mask = torch.nn.functional.one_hot(
            torch.tensor(self._swap_token_idx), num_classes=self._num_tokens
        ).bool().to(device)
        # One row per candidate: the current prefix with only the swap
        # position replaced by that candidate token.
        candidate_prefix_ids = torch.where(
            mask, top_swap_tokens[:, None], self.prefix_ids[None, :])
        # Drop the candidate that reproduces the current prefix, if any.
        is_current_prefix_mask = (
            candidate_prefix_ids == self.prefix_ids).all(dim=1)
        candidate_prefix_ids = candidate_prefix_ids[~is_current_prefix_mask]

        # get best prefix
        num_candidates = len(candidate_prefix_ids)
        all_candidate_losses = torch.zeros(
            num_candidates, dtype=torch.float64, device=device)
        all_n_correct = torch.zeros(
            num_candidates, dtype=torch.int64, device=device)
        for i, cand_prefix in enumerate(candidate_prefix_ids):
            with torch.no_grad():
                cand_input_ids, cand_loss, cand_n_correct = (
                    self._compute_loss_with_set_prefix(
                        original_input_ids=original_input_ids,
                        next_token_ids=next_token_ids,
                        possible_answer_mask=possible_answer_mask,
                        prefix_ids=cand_prefix,
                    )
                )
            all_candidate_losses[i] = cand_loss
            all_n_correct[i] = cand_n_correct

            self._prefix_pool.update(
                prefix=cand_prefix,
                loss=cand_loss,
                accuracy=(cand_n_correct / batch_size)
            )

        # randomly change the token to swap
        self._swap_token_idx = random.randint(0, (self._num_tokens-1))
        # get best prefix we've seen
        # Guard the empty case: min()/argmin() raise on an empty tensor.
        best_idx = all_candidate_losses.argmin() if num_candidates > 0 else None
        if best_idx is not None and all_candidate_losses[best_idx] < current_loss:
            best_prefix = candidate_prefix_ids[best_idx]
            best_prefix_loss = all_candidate_losses[best_idx]
            best_prefix_n_correct = all_n_correct[best_idx]
            if self._autoprompt_verbose:
                print("** set new prefix", best_prefix)
        else:
            best_prefix = self.prefix_ids
            best_prefix_loss = current_loss
            best_prefix_n_correct = current_n_correct
            if self._autoprompt_verbose:
                print("** set same prefix", best_prefix)

        self._set_prefix_ids(best_prefix)
        return best_prefix_loss, best_prefix_n_correct

    def post_epoch(self, dataloader: torch.utils.data.DataLoader, possible_answer_mask: torch.Tensor) -> None:
        """Per-epoch hook; intentionally a no-op for this class."""
        #
        # Get candidate IDs for every position.
        #
        pass

Classes

class AutoPrompt (args: argparse.Namespace, loss_func: PrefixLoss, model: transformers.modeling_utils.PreTrainedModel, tokenizer: transformers.tokenization_utils.PreTrainedTokenizer, preprefix: str = '')

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::

import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    # Example module from the inherited documentation: two convolution layers
    # stored as attributes and applied one after the other.
    def __init__(self):
        # The parent initializer must run before any submodule is assigned.
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        # Apply conv1 then conv2, with a ReLU after each.
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will have their parameters converted too when you call `to()`, etc.

Note

As per the example above, an __init__() call to the parent class must be made before assignment on the child.

Variables: training (bool) – whether this module is in training or evaluation mode.

Initialize internal Module state, shared by both nn.Module and ScriptModule.

Expand source code
class AutoPrompt(HotFlip):
    """Prefix search that tries single-token swaps ranked by prefix-token gradient.

    Each call to `compute_loss_and_call_backward` scores the current prefix,
    proposes replacement tokens for one prefix position, scores every
    proposal, and keeps whichever prefix has the lowest loss. Every scored
    prefix is recorded in a `PrefixPool`, which `serialize` later ranks.
    """
    args: argparse.Namespace
    loss_func: PrefixLoss
    model: transformers.PreTrainedModel
    tokenizer: transformers.PreTrainedTokenizer
    prefix_ids: torch.Tensor
    prefix_embedding: torch.nn.Parameter
    preprefix: str

    def __init__(
        self,
        args: argparse.Namespace,
        loss_func: PrefixLoss,
        model: transformers.PreTrainedModel,
        tokenizer: transformers.PreTrainedTokenizer,
        preprefix: str = ''
    ):
        """Forward the shared arguments to `HotFlip` and set AutoPrompt-specific state.

        Args:
            args: run configuration; `args.iprompt_do_final_reranking` is read here.
            loss_func: passed through to the parent class.
            model: passed through to the parent class.
            tokenizer: passed through to the parent class; also given to the prefix pool.
            preprefix: passed through to the parent class.
        """
        super().__init__(
            args=args, loss_func=loss_func, model=model, tokenizer=tokenizer, preprefix=preprefix
        )
        # Whether `serialize` re-scores prefixes on the eval data before ranking.
        self._do_final_reranking = args.iprompt_do_final_reranking
        # AutoPrompt-specific parameters.
        self._num_candidates_per_prefix_token = 32  # V_cand in autoprompt paper
        # This helps us know which were the best prefixes to return over time
        self._prefix_pool = PrefixPool(
            tokenizer=self.tokenizer,
            criterion='loss'  # in ['loss', 'acc', 'combined']
        )
        self._autoprompt_verbose = True
        # NOTE(review): not read anywhere in this class; `serialize` passes a
        # literal min_occurrences=3 instead.
        self._num_min_occurrences = 1
        # Will rank and save this many prefixes at the end of training.
        self._num_prefixes_to_test = 64

    def test_prefixes(
        self,
        prefixes: List[Tuple[int]],
        eval_dataloader: torch.utils.data.DataLoader,
        possible_answer_mask: torch.Tensor) -> Tuple[List[float], List[float]]:
        """Computes loss & accuracy for each prefix on data in dataloader. Used to rank
        prefixes at the end of training.

        Args:
            prefixes: token-id sequences to score; each is converted to a tensor.
            eval_dataloader: yields batches accepted by `self.prepare_batch`.
            possible_answer_mask: forwarded unchanged to the loss computation.

        Returns:
            (losses, accuracies): two Python lists aligned with `prefixes`.
            Each loss is the sum of the per-batch loss values (not divided by
            the number of batches); each accuracy is the number of correct
            examples divided by the number of examples seen. With an empty
            dataloader the accuracy division is by zero.
        """
        all_candidate_losses = torch.zeros(len(prefixes), dtype=torch.float32)
        all_candidate_n_correct = torch.zeros(
            len(prefixes), dtype=torch.float32)
        total_n = 0

        # Loop-invariant setup: the tokenizer call and the prefix tensors do
        # not depend on the batch, so build them once.
        tok = functools.partial(
            self.tokenizer, return_tensors='pt', padding='longest',
            truncation=True, max_length=self.args.max_length  # TODO set max_length on self
        )
        prefix_tensors = [torch.tensor(p).to(device) for p in prefixes]

        for batch in tqdm.tqdm(eval_dataloader, desc=f'evaluating {len(prefixes)} prefixes'):
            if (self.args.n_shots > 1) and (self.args.single_shot_loss):
                # Score only the final shot's input in this configuration.
                batch['input'] = batch['last_input']
            x_text, y_text = self.prepare_batch(batch=batch)
            x_tokenized = tok(x_text).to(device)
            y_tokenized = tok(y_text).to(device)
            total_n += len(x_tokenized.input_ids)

            next_token_ids = y_tokenized.input_ids
            for i, prefix_tensor in enumerate(prefix_tensors):
                with torch.no_grad():
                    _cand_input_ids, cand_loss, cand_n_correct = (
                        self._compute_loss_with_set_prefix(
                            original_input_ids=x_tokenized.input_ids,
                            next_token_ids=next_token_ids,
                            possible_answer_mask=possible_answer_mask,
                            prefix_ids=prefix_tensor,
                        )
                    )
                all_candidate_losses[i] += cand_loss.item()
                all_candidate_n_correct[i] += cand_n_correct.item()
        return all_candidate_losses.cpu().tolist(), (all_candidate_n_correct / total_n).cpu().tolist()

    def serialize(self, eval_dataloader: torch.utils.data.DataLoader, possible_answer_mask: torch.Tensor) -> Dict[str, Any]:
        """Rank the prefixes collected in the pool and return them as plain lists.

        Nothing is written to disk by the active code; the pool dump below is
        commented out.

        With `self._do_final_reranking` set, up to `self._num_prefixes_to_test`
        pool prefixes are re-scored with `test_prefixes` and ordered by
        accuracy (descending), ties broken by loss (ascending). Otherwise
        every prefix in the pool is reported with the pool's stored average
        loss and accuracy (-1 where missing), ordered by accuracy (descending).

        Returns:
            dict of parallel lists under the keys "prefix_ids", "prefixes"
            (decoded strings), "prefix_train_acc", "prefix_train_loss" and
            "prefix_n_queries" (number of losses recorded per prefix).
        """

        # Uncomment following lines to save all the prefixes we tested.
        # save_dir = self.args.save_dir_unique
        # os.makedirs(save_dir, exist_ok=True)
        # pickle.dump(self._prefix_pool, open(os.path.join(save_dir, 'prefix_pool.p'), 'wb'))

        if self._do_final_reranking:
            all_prefixes = self._prefix_pool.topk_all(
                k=self._num_prefixes_to_test, min_occurrences=3)

            if len(all_prefixes) == 0:
                # In the case where we get no prefixes here (i.e. prompt generation
                # only ran for a single step) just take anything from prefix pool.
                # Sample WITHOUT replacement so no prefix is evaluated or
                # reported twice, and so an empty pool yields an empty result
                # instead of raising.
                pool_prefixes = list(self._prefix_pool.prefixes)
                all_prefixes = random.sample(
                    pool_prefixes,
                    min(self._num_prefixes_to_test, len(pool_prefixes)))

            all_losses, all_accuracies = self.test_prefixes(
                prefixes=all_prefixes,
                eval_dataloader=eval_dataloader,
                possible_answer_mask=possible_answer_mask
            )
            sort_columns, sort_ascending = ['accuracy', 'loss'], [False, True]
        else:
            all_prefixes = list(self._prefix_pool.prefixes)
            all_losses = [self._prefix_pool._avg_loss.get(p, -1) for p in all_prefixes]
            all_accuracies = [self._prefix_pool._avg_accuracy.get(p, -1) for p in all_prefixes]
            sort_columns, sort_ascending = ['accuracy'], [False]

        df = pd.DataFrame(
            list(zip(all_prefixes, all_losses, all_accuracies)),
            columns=['prefix', 'loss', 'accuracy']
        )
        # Sort exactly once: a second accuracy-only sort could reshuffle rows
        # that the (accuracy, loss) ordering had already tie-broken.
        df = df.sort_values(
            by=sort_columns, ascending=sort_ascending).reset_index(drop=True)

        df['prefix_str'] = df['prefix'].map(self.tokenizer.decode)
        df['n_queries'] = df['prefix'].map(
            lambda p_ids: len(self._prefix_pool._all_losses[p_ids]))

        print('Final prefixes')
        print(df.head())

        return {
            "prefix_ids": df['prefix'].tolist(),
            "prefixes": df['prefix_str'].tolist(),
            "prefix_train_acc": df['accuracy'].tolist(),
            "prefix_train_loss": df['loss'].tolist(),
            "prefix_n_queries": df['n_queries'].tolist(),
        }

    def compute_loss_and_call_backward(
        self,
        x_tokenized: transformers.BatchEncoding,
        y_tokenized: transformers.BatchEncoding,
        possible_answer_mask: torch.Tensor,
        full_text_tokenized: Optional[transformers.BatchEncoding] = None
    ) -> Tuple[torch.Tensor, int]:
        """Computes loss using `self.loss_func`.

        Scores the current prefix, calls backward on that loss, then tries
        replacing the token at `self._swap_token_idx` with each of the
        `self._num_candidates_per_prefix_token` tokens whose gradient entry at
        that position is smallest. The lowest-loss prefix (current or
        candidate) becomes the new prefix, and a new swap position is drawn at
        random for the next call. `full_text_tokenized` is accepted but unused.

        Returns:
            loss (float torch.Tensor) -- the loss
            num_correct (int): number of examples where prediction was correct
        """
        original_input_ids = x_tokenized.input_ids
        next_token_ids = y_tokenized.input_ids  # only compute loss over next token
        batch_size = len(original_input_ids)

        # prefix_ids=None: presumably the inherited helper then scores the
        # current prefix -- confirm against HotFlip.
        current_input_ids, current_loss, current_n_correct = self._compute_loss_with_set_prefix(
            original_input_ids=original_input_ids,
            next_token_ids=next_token_ids,
            possible_answer_mask=possible_answer_mask,
            prefix_ids=None,
        )
        current_loss.backward()

        # Was written as `self._autoprompt_verbose: print(...)`, which Python
        # parses as an annotation and never executes.
        if self._autoprompt_verbose:
            print(
                f'** {self.tokenizer.decode(self.prefix_ids)}: {current_loss:.2f}')

        # track running accuracy of this prefix.
        self._prefix_pool.update(
            prefix=self.prefix_ids,
            loss=current_loss,
            accuracy=(current_n_correct/batch_size)
        )

        # print an update.
        self._prefix_pool.print(topk=10, min_occurrences=1)

        #
        # Get top token replacements
        #
        token_grads = self._prefix_token_grad
        if self._is_t5:
            # t5 has extra vocab tokens for no reason:
            # https://github.com/huggingface/transformers/issues/4875#issuecomment-647634437
            assert token_grads.shape == (
                self._num_tokens, len(self.tokenizer.vocab) + 28
            )
            token_grads = token_grads[:, :-28]
        assert token_grads.shape == (
            self._num_tokens, len(self.tokenizer.vocab))
        top_tokens_per_position = (
            token_grads.topk(
                k=self._num_candidates_per_prefix_token, dim=1, largest=False).indices
        )
        assert top_tokens_per_position.shape == (
            self._num_tokens, self._num_candidates_per_prefix_token)

        # Candidates must come from the row of the position being swapped.
        # The previous code overwrote this with
        # `token_grads.argsort().flatten()[:k]`, which is always row 0.
        top_swap_tokens = top_tokens_per_position[self._swap_token_idx, :]

        # rank candidates
        mask = torch.nn.functional.one_hot(
            torch.tensor(self._swap_token_idx), num_classes=self._num_tokens
        ).bool().to(device)
        # One row per candidate: the current prefix with only the swap
        # position replaced by that candidate token.
        candidate_prefix_ids = torch.where(
            mask, top_swap_tokens[:, None], self.prefix_ids[None, :])
        # Drop the candidate that reproduces the current prefix, if any.
        is_current_prefix_mask = (
            candidate_prefix_ids == self.prefix_ids).all(dim=1)
        candidate_prefix_ids = candidate_prefix_ids[~is_current_prefix_mask]

        # get best prefix
        num_candidates = len(candidate_prefix_ids)
        all_candidate_losses = torch.zeros(
            num_candidates, dtype=torch.float64, device=device)
        all_n_correct = torch.zeros(
            num_candidates, dtype=torch.int64, device=device)
        for i, cand_prefix in enumerate(candidate_prefix_ids):
            with torch.no_grad():
                cand_input_ids, cand_loss, cand_n_correct = (
                    self._compute_loss_with_set_prefix(
                        original_input_ids=original_input_ids,
                        next_token_ids=next_token_ids,
                        possible_answer_mask=possible_answer_mask,
                        prefix_ids=cand_prefix,
                    )
                )
            all_candidate_losses[i] = cand_loss
            all_n_correct[i] = cand_n_correct

            self._prefix_pool.update(
                prefix=cand_prefix,
                loss=cand_loss,
                accuracy=(cand_n_correct / batch_size)
            )

        # randomly change the token to swap
        self._swap_token_idx = random.randint(0, (self._num_tokens-1))
        # get best prefix we've seen
        # Guard the empty case: min()/argmin() raise on an empty tensor.
        best_idx = all_candidate_losses.argmin() if num_candidates > 0 else None
        if best_idx is not None and all_candidate_losses[best_idx] < current_loss:
            best_prefix = candidate_prefix_ids[best_idx]
            best_prefix_loss = all_candidate_losses[best_idx]
            best_prefix_n_correct = all_n_correct[best_idx]
            if self._autoprompt_verbose:
                print("** set new prefix", best_prefix)
        else:
            best_prefix = self.prefix_ids
            best_prefix_loss = current_loss
            best_prefix_n_correct = current_n_correct
            if self._autoprompt_verbose:
                print("** set same prefix", best_prefix)

        self._set_prefix_ids(best_prefix)
        return best_prefix_loss, best_prefix_n_correct

    def post_epoch(self, dataloader: torch.utils.data.DataLoader, possible_answer_mask: torch.Tensor) -> None:
        """Per-epoch hook; intentionally a no-op for this class."""
        #
        # Get candidate IDs for every position.
        #
        pass

Ancestors

Subclasses

Class variables

var args : argparse.Namespace
var loss_func : PrefixLoss
var model : transformers.modeling_utils.PreTrainedModel
var prefix_embedding : torch.nn.parameter.Parameter
var prefix_ids : torch.Tensor
var preprefix : str
var tokenizer : transformers.tokenization_utils.PreTrainedTokenizer

Methods

def post_epoch(self, dataloader: torch.utils.data.dataloader.DataLoader, possible_answer_mask: torch.Tensor) ‑> None
Expand source code
def post_epoch(self, dataloader: torch.utils.data.DataLoader, possible_answer_mask: torch.Tensor) -> None:
    """Per-epoch hook; intentionally does nothing and returns None."""
    # Placeholder kept from the original: "Get candidate IDs for every position."
    return None
def serialize(self, eval_dataloader: torch.utils.data.dataloader.DataLoader, possible_answer_mask: torch.Tensor) ‑> Dict[str, Any]

Writes stuff to disk. Saves other stuff to save as full results file.

Expand source code
def serialize(self, eval_dataloader: torch.utils.data.DataLoader, possible_answer_mask: torch.Tensor) -> Dict[str, Any]:
    """Rank the prefixes collected in the pool and return them as plain lists.

    Nothing is written to disk by the active code; the pool dump below is
    commented out.

    With `self._do_final_reranking` set, up to `self._num_prefixes_to_test`
    pool prefixes are re-scored with `test_prefixes` and ordered by accuracy
    (descending), ties broken by loss (ascending). Otherwise every prefix in
    the pool is reported with the pool's stored average loss and accuracy
    (-1 where missing), ordered by accuracy (descending).

    Returns:
        dict of parallel lists under the keys "prefix_ids", "prefixes"
        (decoded strings), "prefix_train_acc", "prefix_train_loss" and
        "prefix_n_queries" (number of losses recorded per prefix).
    """

    # Uncomment following lines to save all the prefixes we tested.
    # save_dir = self.args.save_dir_unique
    # os.makedirs(save_dir, exist_ok=True)
    # pickle.dump(self._prefix_pool, open(os.path.join(save_dir, 'prefix_pool.p'), 'wb'))

    if self._do_final_reranking:
        all_prefixes = self._prefix_pool.topk_all(
            k=self._num_prefixes_to_test, min_occurrences=3)

        if len(all_prefixes) == 0:
            # In the case where we get no prefixes here (i.e. prompt generation
            # only ran for a single step) just take anything from prefix pool.
            # Sample WITHOUT replacement so no prefix is evaluated or reported
            # twice, and so an empty pool yields an empty result instead of
            # raising.
            pool_prefixes = list(self._prefix_pool.prefixes)
            all_prefixes = random.sample(
                pool_prefixes,
                min(self._num_prefixes_to_test, len(pool_prefixes)))

        all_losses, all_accuracies = self.test_prefixes(
            prefixes=all_prefixes,
            eval_dataloader=eval_dataloader,
            possible_answer_mask=possible_answer_mask
        )
        sort_columns, sort_ascending = ['accuracy', 'loss'], [False, True]
    else:
        all_prefixes = list(self._prefix_pool.prefixes)
        all_losses = [self._prefix_pool._avg_loss.get(p, -1) for p in all_prefixes]
        all_accuracies = [self._prefix_pool._avg_accuracy.get(p, -1) for p in all_prefixes]
        sort_columns, sort_ascending = ['accuracy'], [False]

    df = pd.DataFrame(
        list(zip(all_prefixes, all_losses, all_accuracies)),
        columns=['prefix', 'loss', 'accuracy']
    )
    # Sort exactly once: a second accuracy-only sort could reshuffle rows that
    # the (accuracy, loss) ordering had already tie-broken.
    df = df.sort_values(
        by=sort_columns, ascending=sort_ascending).reset_index(drop=True)

    df['prefix_str'] = df['prefix'].map(self.tokenizer.decode)
    df['n_queries'] = df['prefix'].map(
        lambda p_ids: len(self._prefix_pool._all_losses[p_ids]))

    print('Final prefixes')
    print(df.head())

    return {
        "prefix_ids": df['prefix'].tolist(),
        "prefixes": df['prefix_str'].tolist(),
        "prefix_train_acc": df['accuracy'].tolist(),
        "prefix_train_loss": df['loss'].tolist(),
        "prefix_n_queries": df['n_queries'].tolist(),
    }
def test_prefixes(self, prefixes: List[Tuple[int]], eval_dataloader: torch.utils.data.dataloader.DataLoader, possible_answer_mask: torch.Tensor) ‑> Tuple[torch.Tensor, torch.Tensor]

Computes loss & accuracy for each prefix on data in dataloader. Used to rank prefixes at the end of training.

Expand source code
def test_prefixes(
    self,
    prefixes: List[Tuple[int]],
    eval_dataloader: torch.utils.data.DataLoader,
    possible_answer_mask: torch.Tensor) -> Tuple[List[float], List[float]]:
    """Computes loss & accuracy for each prefix on data in dataloader. Used to rank
    prefixes at the end of training.

    Args:
        prefixes: token-id sequences to score; each is converted to a tensor.
        eval_dataloader: yields batches accepted by `self.prepare_batch`.
        possible_answer_mask: forwarded unchanged to the loss computation.

    Returns:
        (losses, accuracies): two Python lists aligned with `prefixes`. Each
        loss is the sum of the per-batch loss values (not divided by the
        number of batches); each accuracy is the number of correct examples
        divided by the number of examples seen. With an empty dataloader the
        accuracy division is by zero.
    """
    all_candidate_losses = torch.zeros(len(prefixes), dtype=torch.float32)
    all_candidate_n_correct = torch.zeros(
        len(prefixes), dtype=torch.float32)
    total_n = 0

    # Loop-invariant setup: the tokenizer call and the prefix tensors do not
    # depend on the batch, so build them once.
    tok = functools.partial(
        self.tokenizer, return_tensors='pt', padding='longest',
        truncation=True, max_length=self.args.max_length  # TODO set max_length on self
    )
    prefix_tensors = [torch.tensor(p).to(device) for p in prefixes]

    for batch in tqdm.tqdm(eval_dataloader, desc=f'evaluating {len(prefixes)} prefixes'):
        if (self.args.n_shots > 1) and (self.args.single_shot_loss):
            # Score only the final shot's input in this configuration.
            batch['input'] = batch['last_input']
        x_text, y_text = self.prepare_batch(batch=batch)
        x_tokenized = tok(x_text).to(device)
        y_tokenized = tok(y_text).to(device)
        total_n += len(x_tokenized.input_ids)

        next_token_ids = y_tokenized.input_ids
        for i, prefix_tensor in enumerate(prefix_tensors):
            with torch.no_grad():
                _cand_input_ids, cand_loss, cand_n_correct = (
                    self._compute_loss_with_set_prefix(
                        original_input_ids=x_tokenized.input_ids,
                        next_token_ids=next_token_ids,
                        possible_answer_mask=possible_answer_mask,
                        prefix_ids=prefix_tensor,
                    )
                )
            all_candidate_losses[i] += cand_loss.item()
            all_candidate_n_correct[i] += cand_n_correct.item()
    return all_candidate_losses.cpu().tolist(), (all_candidate_n_correct / total_n).cpu().tolist()

Inherited members