Module imodelsx.iprompt.utils
from typing import Any, Dict, List, Iterable, Optional, Tuple, Union
import abc
import argparse
import collections
import dataclasses
import functools
import heapq
import random
import pandas as pd
import transformers
import torch
from torch.utils.data import DataLoader
from torch import nn
import tqdm
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
DEBUG_VERBOSE = False
def get_token_replacements_single_mask(
    dataloader: DataLoader, model: transformers.AutoModelForMaskedLM,
    tokenizer: transformers.AutoTokenizer, init_prefix_template: str, num_candidates: int
) -> List[str]:
    """Given a template like `{mask} the numbers`, returns the `num_candidates` most likely
    single-token replacements for `{mask}` given `model`.
    """
    single_mask_prefix_str = init_prefix_template.format(mask=tokenizer.mask_token)
    all_mask_probs = torch.zeros((tokenizer.vocab_size,), dtype=float).to(device)
    for idx, batch in tqdm.tqdm(enumerate(dataloader), total=len(dataloader)):
        full_text = [f'{single_mask_prefix_str} {input_text}' for input_text in batch['text']]
        if idx == 0:
            print('Sample input: ', full_text[0])
        inputs = tokenizer(full_text, return_tensors='pt', padding='longest')
        with torch.no_grad():
            outputs = model(**inputs.to(device))
        mask_idxs = (inputs['input_ids'] == tokenizer.mask_token_id).nonzero()
        # TODO: how to do this better in torch?
        mask_probs = outputs.logits[mask_idxs[:, 0], mask_idxs[:, 1]].log_softmax(dim=1)
        all_mask_probs += mask_probs.sum(dim=0)
    prefix_idxs = all_mask_probs.topk(num_candidates).indices
    return [init_prefix_template.format(mask=tokenizer.decode(idx)) for idx in prefix_idxs]

def get_prefix_from_mlm(
    dataloader: DataLoader,
    mlm_name: str,
    num_candidates: int,
    template: str
) -> List[str]:
    """Gets candidate prefixes by asking a masked language model to fill in
    the single-token `{mask}` slot in `template`.
    """
    mlm = transformers.RobertaForMaskedLM.from_pretrained(mlm_name).to(device)
    mlm_tokenizer = transformers.AutoTokenizer.from_pretrained(mlm_name)
    # Example templates:
    # template = "{mask} the two numbers to get the answer."
    # template = "{mask} the input number to get the answer."
    # template = "Return the{mask} of the input."
    candidates = get_token_replacements_single_mask(
        dataloader=dataloader,
        model=mlm, tokenizer=mlm_tokenizer,
        init_prefix_template=template,
        num_candidates=num_candidates
    )
    mlm.to('cpu')  # no need for mlm on GPU anymore
    return candidates

def compute_log_ppl_loss(logits: torch.Tensor, input_ids: torch.Tensor) -> torch.Tensor:
    """Computes LM perplexity loss given logits for next tokens and original input IDs.
    Exponentiate this quantity if you want the actual perplexity.
    """
    # logits gives us the probability of each token that comes after each token in input_ids,
    # so they have the same shape. But we only want to compute ppl using the tokens we have,
    # i.e. not the first true token (which we don't have logits for) or the last predicted token
    # (which we don't know the true id for). So we have to shift each by one index.
    assert logits.shape[0:2] == input_ids.shape
    logits = logits[:, :-1, :]
    input_ids = input_ids[:, 1:]
    # now flatten along sequence length so we can compute cross-entropy.
    batch_size, sequence_length, vocab_size = logits.shape
    assert input_ids.shape == (batch_size, sequence_length)
    logits = logits.reshape((batch_size * sequence_length, vocab_size))
    input_ids = input_ids.reshape((batch_size * sequence_length, ))
    loss = torch.nn.functional.cross_entropy(
        input=logits,
        target=input_ids,
        reduction='mean'
    )
    return loss

@dataclasses.dataclass
class PrefixLoss:
    """Computes next-token-prediction loss with optional language modeling component."""
    gamma: float
    tokenizer: transformers.PreTrainedTokenizer  # for debugging

    def _compute_fluency_loss(
        self, logits: torch.Tensor, input_ids: torch.Tensor
    ) -> torch.Tensor:
        if self.gamma == 0:
            return torch.tensor(0.0).to(device)
        return compute_log_ppl_loss(logits=logits, input_ids=input_ids)

    def _compute_token_loss(
        self, next_token_logits: torch.Tensor, next_token_idxs: torch.Tensor, answer_mask: torch.Tensor
    ) -> torch.Tensor:
        batch_size, vocab_size = next_token_logits.shape
        assert next_token_idxs.shape == (batch_size,)
        if answer_mask is not None:
            assert answer_mask.shape == (vocab_size,)
            next_token_logits = torch.where(
                answer_mask[None],
                next_token_logits, torch.tensor(float('-inf')).to(device)
            )
        return torch.nn.functional.cross_entropy(
            input=next_token_logits,
            target=next_token_idxs,
            reduction='mean'
        )

    def __call__(
        self,
        input_ids: torch.Tensor,
        next_token_ids: torch.Tensor,
        logits: torch.Tensor,
        answer_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Computes loss.

        Args:
            input_ids (int torch.Tensor): array of token IDs for inputs
            next_token_ids (int torch.Tensor): array of token IDs for the word
                that comes after the input
            logits (float torch.Tensor): logits for all output tokens, including
                the next one
            answer_mask (bool torch.Tensor): mask over tokens to remove irrelevant ones

        Returns: float torch.Tensor scalar, loss value (lower is better).
        """
        fluency_loss = self._compute_fluency_loss(
            logits=logits,
            input_ids=input_ids
        )
        token_loss = self._compute_token_loss(
            next_token_logits=logits[:, -1, :],
            next_token_idxs=next_token_ids,
            answer_mask=answer_mask,
        )
        loss = token_loss + (self.gamma * fluency_loss)
        if DEBUG_VERBOSE:
            print(f">> loss for input string: {self.tokenizer.decode(input_ids[0])}")
            print(f"\tLoss = {loss:.3f}")
        return loss

class PrefixModel(nn.Module, abc.ABC):
    args: argparse.Namespace
    loss_func: PrefixLoss
    model: transformers.PreTrainedModel
    tokenizer: transformers.PreTrainedTokenizer

    def __init__(self, args: argparse.Namespace, loss_func: PrefixLoss, model: transformers.PreTrainedModel, tokenizer: transformers.PreTrainedTokenizer, preprefix: str):
        super().__init__()
        self.args = args
        self.loss_func = loss_func
        self.model = model
        self.tokenizer = tokenizer

    @property
    def id_to_word(self) -> Dict[int, str]:
        # track token-to-word mapping
        return {num: word for word, num in self.tokenizer.vocab.items()}

    @property
    def _is_gpt_neox(self) -> bool:
        return isinstance(self.model, (transformers.GPTNeoXModel, transformers.GPTNeoXForCausalLM))

    @property
    def _is_t5(self) -> bool:
        return isinstance(self.model, transformers.T5ForConditionalGeneration)

    @property
    def _is_opt(self) -> bool:
        return isinstance(self.model, transformers.OPTForCausalLM)

    @property
    def transformer(self) -> nn.Module:
        if self._is_gpt_neox:
            return self.model._modules['gpt_neox']
        elif self._is_t5:
            return self.model.encoder
        elif self._is_opt:
            return self.model._modules['model'].decoder
        else:
            return self.model._modules['transformer']

    @property
    def token_embedding(self) -> nn.Embedding:
        if self._is_gpt_neox:
            return self.transformer.embed_in
        elif self._is_t5:
            return self.model.encoder.embed_tokens
        elif self._is_opt:
            return self.transformer.embed_tokens
        else:
            return self.transformer.wte

    @property
    def vocab_size(self) -> int:
        return self.token_embedding.weight.shape[0]  # 50_257 for GPT-2

    @property
    def token_embedding_dim(self) -> int:
        return self.token_embedding.weight.shape[1]  # often 768, or 2560 for some larger models

    def prepare_batch(self, batch: Dict[str, str]) -> Tuple[List[str], List[str]]:
        """Preprocesses text from `batch['input']` and `batch['output']` for inputting into prefix model."""
        if self.prefix_before_input:
            x_text = [f'. {prompt}' for prompt in batch['input']]
            y_text = [answer for answer in batch['output']]
        else:
            x_text = [prompt for prompt in batch['input']]
            y_text = [answer.rstrip().rstrip('.') for answer in batch['output']]  # strip trailing whitespace and periods
        return x_text, y_text

    def forward(
        self,
        input_ids: torch.Tensor,
        prefix_ids: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        raise NotImplementedError()

    def pre_epoch(self) -> None:
        return

    def post_epoch(self, dataloader: torch.utils.data.DataLoader, possible_answer_mask: torch.Tensor) -> None:
        return

    def compute_metrics(self) -> Dict[str, Any]:
        return {}

    def serialize(self, eval_dataloader: torch.utils.data.DataLoader, possible_answer_mask: torch.Tensor) -> Dict[str, Any]:
        """Returns a dict of model state to write to disk after training."""
        return {}

    @property
    @abc.abstractmethod
    def trainable_params(self) -> Iterable[nn.Parameter]:
        raise NotImplementedError()

    @abc.abstractmethod
    def embed_input_ids(self, input_ids: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """To be implemented by subclasses -- embeds input ids and includes some sort of prefix,
        for example, in the case of prompt-tuning, by prepending a continuous embedding.
        """
        raise NotImplementedError()

    def init_continuous_prefix(self, num_tokens: int) -> nn.Parameter:
        return nn.Parameter(
            self.token_embedding.weight.mean(dim=0, keepdim=True)[None].repeat(1, num_tokens, 1), requires_grad=True
        )

    def init_discrete_prefix(self, num_tokens: int) -> torch.Tensor:
        if self.args.autoprompt_init_strategy == 'random':
            return torch.randint(low=0, high=self.tokenizer.vocab_size, size=(num_tokens,))
        else:
            start_word_id = torch.tensor([self.tokenizer.vocab['the']], dtype=int)
            print(f"start_word_id = {start_word_id}")
            return start_word_id.repeat((num_tokens,))

    def _compute_loss_with_set_prefix(
        self,
        original_input_ids: torch.Tensor,
        next_token_ids: torch.Tensor,
        possible_answer_mask: torch.Tensor,
        prefix_ids: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # Feed into the model; prefix-handling is implemented in PrefixModel::forward,
        # and the huggingface LM automatically computes language modeling loss.
        if self._is_t5:
            assert possible_answer_mask is None, "not implemented with t5 yet"
            blank_next_token_ids = torch.zeros(
                (len(original_input_ids), 0), dtype=torch.long, device=device)
            new_input_ids, embeddings = self.embed_input_ids(
                input_ids=original_input_ids,
                next_token_ids=blank_next_token_ids,
                prefix_ids=prefix_ids,
            )
            attention_mask = ~(new_input_ids == self.tokenizer.pad_token_id)
            outputs = self.model(
                inputs_embeds=embeddings,
                attention_mask=attention_mask,
                labels=next_token_ids
            )
            next_token_logits = outputs.logits
            loss = outputs.loss
        else:
            new_input_ids, embeddings = self.embed_input_ids(
                input_ids=original_input_ids,
                next_token_ids=next_token_ids,
                prefix_ids=prefix_ids,
            )
            attention_mask = ~(new_input_ids == self.tokenizer.pad_token_id)
            # Mask labels before feeding into the model: huggingface ignores labels of -100.
            # See huggingface.co/docs/transformers/model_doc/gpt2#transformers.GPT2DoubleHeadsModel.forward.labels
            S = new_input_ids.shape[1]
            LS = next_token_ids.shape[1]
            labels = torch.where(
                torch.arange(S, device=device) < (S - LS), -100, new_input_ids
            )
            labels = torch.where(
                labels == self.tokenizer.pad_token_id, -100, labels
            )
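            # e.g. with S = 5 total positions and LS = 2 answer tokens, one row becomes:
            #   new_input_ids = [p1, x1, x2, y1, y2]  ->  labels = [-100, -100, -100, y1, y2]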
            outputs = self.model(
                inputs_embeds=embeddings,
                attention_mask=attention_mask,
                labels=labels,
            )
            next_token_logits = outputs.logits[:, -LS-1:-1]
            if possible_answer_mask is not None:
                BIG_NEGATIVE_NUMBER = torch.tensor(-10**10, dtype=next_token_logits.dtype, device=device)
                next_token_logits = torch.where(possible_answer_mask, next_token_logits, BIG_NEGATIVE_NUMBER)
            B, S, _V = next_token_logits.shape
            loss = torch.nn.functional.cross_entropy(
                input=next_token_logits.reshape((B * S, -1)),
                target=next_token_ids.reshape((B * S),),
                ignore_index=self.tokenizer.pad_token_id,
                # reduction=None
            )
        # a prediction counts as correct if every non-pad answer token is right
        n_correct = (
            (next_token_logits.argmax(dim=-1) == next_token_ids)
            |
            (self.tokenizer.pad_token_id == next_token_ids)
        ).all(dim=1).sum()
        if DEBUG_VERBOSE:
            print(f">> loss for input string: {self.tokenizer.decode(new_input_ids[0])}")
            print(f"\tLoss = {outputs.loss:.3f}")
        return new_input_ids, loss, n_correct

    def compute_loss_and_call_backward(
        self,
        x_tokenized: transformers.BatchEncoding,
        y_tokenized: transformers.BatchEncoding,
        possible_answer_mask: Optional[torch.Tensor],
        full_text_tokenized: Optional[transformers.BatchEncoding] = None
    ) -> Tuple[torch.Tensor, int]:
        """Computes loss using `self.loss_func`.

        Returns:
            loss (float torch.Tensor): the loss
            num_correct (int): number of examples where prediction was correct
        """
        original_input_ids = x_tokenized.input_ids
        next_token_ids = y_tokenized.input_ids[:, 0]  # only compute loss over next token
        input_ids, outputs = self.forward(input_ids=original_input_ids, prefix_ids=None)
        next_token_logits = outputs.logits[:, -1, :]
        n_correct = (
            next_token_logits.argmax(dim=-1) == next_token_ids
        ).int().sum()
        loss = self.loss_func(
            input_ids=input_ids,
            next_token_ids=next_token_ids,
            logits=outputs['logits'],
            answer_mask=possible_answer_mask
        )
        loss.backward()
        return loss, n_correct

    def check_early_stop(self) -> bool:
        """Allow prefix models to stop early."""
        return False

def mean(_list: List[Union[int, float]]) -> float:
    return sum(_list) / len(_list)

def get_preprefix_from_args(args: argparse.Namespace) -> str:
    preprefix = ''
    if args.use_preprefix or args.iprompt_preprefix_str != '':
        if args.iprompt_preprefix_str == '':
            # NOTE: relies on a `data` module (providing get_init_suffix) that is not imported above.
            preprefix = data.get_init_suffix(
                args.task_name, args.use_generic_query, args.template_num_init_string)
        else:
            preprefix = args.iprompt_preprefix_str
    return preprefix

def load_lm_from_checkpoint(
        checkpoint: str, float16: bool) -> transformers.AutoModel:
    print(f"loading lm '{checkpoint}'")
    tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
    llm_cls = transformers.AutoModelForSeq2SeqLM if 't5' in checkpoint else transformers.AutoModelForCausalLM
    if float16:
        if checkpoint == "EleutherAI/gpt-j-6B":
            print(f"loading {checkpoint} in float16...")
            lm = llm_cls.from_pretrained(
                checkpoint, output_hidden_states=False, pad_token_id=tokenizer.eos_token_id,
                revision="float16", torch_dtype=torch.float16, low_cpu_mem_usage=True
            )
        else:
            # (only certain models are pre-float16ed)
            print(f"trying to convert {checkpoint} to float16...")
            lm = llm_cls.from_pretrained(
                checkpoint,
                torch_dtype=torch.float16,
                device_map="auto",
                low_cpu_mem_usage=True
            )
            # lm = lm.half()
    else:
        lm = llm_cls.from_pretrained(
            checkpoint,
            output_hidden_states=False,
            pad_token_id=tokenizer.eos_token_id,
            device_map="auto",
            # low_cpu_mem_usage=True
        )
    return lm

class PrefixPool:
    """Tracks a pool of candidate prefixes and their associated metrics over time."""
    criterion: str
    tokenizer: transformers.PreTrainedTokenizer
    verbose: bool
    _all_losses: Dict[Tuple[int], List[float]]
    _avg_loss: Dict[Tuple[int], float]
    _all_accuracy: Dict[Tuple[int], List[float]]
    _avg_accuracy: Dict[Tuple[int], float]
    _best_prefix_by_start_token: Dict[int, Tuple[Tuple[int], float]]

    def __init__(self, tokenizer: transformers.PreTrainedTokenizer, criterion: str,
                 topk_strategy: str = 'different_start_token', verbose: bool = False):
        self.tokenizer = tokenizer
        self.criterion = criterion
        # tuple (input_ids) -> float (loss)
        self._avg_loss = {}
        self._all_losses = collections.defaultdict(list)
        # tuple (input_ids) -> int (n_correct)
        self._avg_accuracy = {}
        self._all_accuracy = collections.defaultdict(list)
        self._best_prefix_by_start_token = {}
        self._topk_strategy = topk_strategy  # one of ['different_start_token', 'all']
        self.verbose = verbose

    @property
    def prefixes(self) -> List[Tuple[int]]:
        return list(self._avg_loss.keys())

    @property
    def num_start_tokens(self) -> int:
        """Number of different start tokens seen across all prefixes."""
        return len(self._best_prefix_by_start_token)

    def print(self, topk: int, min_occurrences: int = 2) -> Optional[pd.DataFrame]:
        top_token_ids = self.topk(k=topk, min_occurrences=min_occurrences)
        ########################### Debugging code ##########################
        # vd = pd.DataFrame(self._avg_loss.items(), columns=['prefix', 'loss'])
        # vd['prefix_str'] = vd['prefix'].map(self.tokenizer.decode)
        # vd['n'] = vd['prefix'].map(lambda p: len(self._all_losses[p]))
        # vd.sort_values(by='loss')["prefix_str"].iloc[:25]
        # vd.sort_values(by=['n', 'loss'], ascending=[False, True])[["n", "prefix_str"]].iloc[:25]
        #####################################################################
        if not len(top_token_ids):
            return None
        print_str = " ".join(((" " * 45), ("*" * 20), "Population", ("*" * 20))) + "\n"
        output_rows = []
        for idx, token_ids in enumerate(top_token_ids):
            prefix = self.tokenizer.decode(list(token_ids))
            loss = self._avg_loss[token_ids]
            acc = self._avg_accuracy[token_ids]
            prefix_str = "{:>65}".format(prefix.replace("\n", "\\n"))
            loss_str = f"{loss:.3f}"
            acc_str = f"{acc*100:.1f}"
            print_str += " ".join((prefix_str, "\t\t", loss_str, "\t\t", acc_str)) + "\n"
            output_rows.append([idx, prefix, loss, acc])
        if self.verbose:
            print(print_str)
        return pd.DataFrame(output_rows, columns=['idx', 'prefix', 'loss', 'accuracy'])

    def initialize_prefix(self, prefix: torch.Tensor):
        prefix = tuple(prefix.cpu().tolist())
        self._avg_loss[prefix] = 10_000.0
        self._avg_accuracy[prefix] = 0
        self._best_prefix_by_start_token.setdefault(prefix[0], (prefix, (10_000.0,)))

    def topk(self, *args, **kwargs) -> List[Tuple[int]]:
        if self._topk_strategy == 'different_start_token':
            return self.topk_with_different_start_token(*args, **kwargs)
        elif self._topk_strategy == 'all':
            return self.topk_all(*args, **kwargs)
        else:
            raise ValueError(f'Unknown strategy {self._topk_strategy}')

    def topk_with_different_start_token(
        self,
        k: int,
        min_occurrences: Optional[int] = None
    ) -> List[Tuple[int]]:
        all_prefixes = [p for p, score in self._best_prefix_by_start_token.values()]
        top_prefixes = self._topk_from_prefixes(
            all_prefixes, k=k, min_occurrences=min_occurrences
        )
        if not len(top_prefixes):
            # get top prefixes the first time
            top_prefixes = self._topk_from_prefixes(
                all_prefixes, k=k, min_occurrences=0
            )
        n_so_far = len(top_prefixes)
        if n_so_far < k:
            # fallback if we don't have enough distinct first tokens yet:
            # pad with random repeats of the prefixes we do have.
            num_prefixes_to_add = k - len(top_prefixes)
            more_prefixes = [
                random.choice(top_prefixes) for _ in range(num_prefixes_to_add)
            ]
            top_prefixes += more_prefixes
        top_prefixes.sort(key=self._score)
        return top_prefixes

    def topk_all(self, k: int, min_occurrences: Optional[int] = None) -> List[Tuple[int]]:
        all_prefixes = self._avg_loss.keys()
        return self._topk_from_prefixes(
            all_prefixes, k=k, min_occurrences=min_occurrences
        )

    def _score(self, prefix: Tuple[int]) -> Tuple[float, ...]:
        criterion = self.criterion
        if criterion == 'loss':
            # sort by min loss
            return (self._avg_loss[prefix], )
        elif criterion == 'combined':
            # sort by max (rounded) accuracy, breaking ties by min loss
            return (-1 * round(self._avg_accuracy[prefix], 2), self._avg_loss[prefix])
        else:
            # sort by max accuracy
            return (-1 * self._avg_accuracy[prefix], )

    def _topk_from_prefixes(
        self,
        prefixes: Iterable[Tuple[int]],
        k: int,
        min_occurrences: Optional[int] = None
    ) -> List[Tuple[int]]:
        if min_occurrences:
            prefixes = {
                prefix for prefix in prefixes
                if len(self._all_accuracy[prefix]) > min_occurrences
            }
        population = [(self._score(p), p) for p in prefixes]
        topk_pop = heapq.nsmallest(k, population)
        topk_pop.sort(key=lambda t: t[0])
        return [prefix_ids for _, prefix_ids in topk_pop]

    def update(self, prefix: torch.Tensor, loss: torch.Tensor, accuracy: torch.Tensor):
        # TODO: abstract these data structures into a class
        prefix = tuple(prefix.cpu().flatten().tolist())
        self._all_losses[prefix].append(loss.item())
        self._avg_loss[prefix] = mean(self._all_losses[prefix])
        self._all_accuracy[prefix].append(accuracy.item())
        self._avg_accuracy[prefix] = mean(self._all_accuracy[prefix])
        # track best score for each starting token
        self._best_prefix_by_start_token.setdefault(prefix[0], (prefix, (10_000.0,)))
        score = self._score(prefix)
        best_prefix, best_score = self._best_prefix_by_start_token[prefix[0]]
        if score < best_score:
            self._best_prefix_by_start_token[prefix[0]] = (prefix, score)

    def __len__(self) -> int:
        return len(self._avg_loss)
Functions
def compute_log_ppl_loss(logits: torch.Tensor, input_ids: torch.Tensor) -> torch.Tensor
Computes LM perplexity loss given logits for next tokens and original input IDs. Exponentiate this quantity if you want the actual perplexity.
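As a minimal usage sketch (the model name and sentence are illustrative), scoring a string's fluency with a causal LM:

import torch
import transformers

tokenizer = transformers.AutoTokenizer.from_pretrained('gpt2')
model = transformers.AutoModelForCausalLM.from_pretrained('gpt2')
input_ids = tokenizer('The quick brown fox', return_tensors='pt').input_ids
with torch.no_grad():
    logits = model(input_ids).logits          # (batch, seq_len, vocab)
loss = compute_log_ppl_loss(logits=logits, input_ids=input_ids)
print(loss.exp())                             # perplexity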
def get_prefix_from_mlm(dataloader: torch.utils.data.dataloader.DataLoader, mlm_name: str, num_candidates: int, template: str) -> List[str]
Gets candidate prefixes by asking a masked language model to fill in the single-token `{mask}` slot in `template`.
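A hedged usage sketch: the toy data, MLM name, and template below are illustrative; the dataloader just needs to yield batches exposing a 'text' key, as get_token_replacements_single_mask expects.

from torch.utils.data import DataLoader

examples = [{'text': '2 5 => 7'}, {'text': '1 3 => 4'}]   # toy addition data
loader = DataLoader(examples, batch_size=2)                # default collate -> {'text': [...]}
candidates = get_prefix_from_mlm(
    dataloader=loader,
    mlm_name='roberta-base',
    num_candidates=4,
    template='{mask} the two numbers to get the answer.',
)
# candidates is a list of 4 filled-in templates, e.g. 'Add the two numbers ...'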
def get_preprefix_from_args(args: argparse.Namespace) -> str
Returns the preprefix specified in args, or a task-specific default if `iprompt_preprefix_str` is empty.
def get_token_replacements_single_mask(dataloader: torch.utils.data.dataloader.DataLoader, model: transformers.models.auto.modeling_auto.AutoModelForMaskedLM, tokenizer: transformers.models.auto.tokenization_auto.AutoTokenizer, init_prefix_template: str, num_candidates: int) -> List[str]
Given a template like `{mask} the numbers`, returns the `num_candidates` most likely single-token replacements for `{mask}` given `model`.
def load_lm_from_checkpoint(checkpoint: str, float16: bool) -> transformers.models.auto.modeling_auto.AutoModel
Loads a causal or seq2seq LM from a huggingface checkpoint, optionally in float16.
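Usage sketch (checkpoint names are illustrative; note that device_map="auto" in this function requires the accelerate package):

lm = load_lm_from_checkpoint('gpt2', float16=False)                      # full precision
lm_fp16 = load_lm_from_checkpoint('EleutherAI/gpt-j-6B', float16=True)   # uses the float16 revision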
def mean(_list: List[Union[int, float]]) -> float
Returns the arithmetic mean of a list of numbers.
Classes
class PrefixLoss (gamma: float, tokenizer: transformers.tokenization_utils.PreTrainedTokenizer)
Computes next-token-prediction loss with an optional language modeling (fluency) component, weighted by gamma.
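A hedged usage sketch with GPT-2 (the gamma weight and strings are illustrative); passing answer_mask=None skips the answer masking:

import torch
import transformers

tokenizer = transformers.AutoTokenizer.from_pretrained('gpt2')
model = transformers.AutoModelForCausalLM.from_pretrained('gpt2')
loss_fn = PrefixLoss(gamma=0.1, tokenizer=tokenizer)

input_ids = tokenizer('Add the numbers: 2 5', return_tensors='pt').input_ids
next_token_ids = tokenizer(' 7', return_tensors='pt').input_ids[:, 0]
logits = model(input_ids).logits

loss = loss_fn(
    input_ids=input_ids,
    next_token_ids=next_token_ids,
    logits=logits,
    answer_mask=None,
)
loss.backward()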
Class variables
var gamma : float
var tokenizer : transformers.tokenization_utils.PreTrainedTokenizer
class PrefixModel (args: argparse.Namespace, loss_func: PrefixLoss, model: transformers.modeling_utils.PreTrainedModel, tokenizer: transformers.tokenization_utils.PreTrainedTokenizer, preprefix: str)
Abstract base class for prefix models: wraps a pretrained language model and tokenizer, and defines the interface (embed_input_ids, trainable_params, forward) that concrete prefix-learning methods implement.
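A hedged sketch of a minimal concrete subclass (prompt-tuning style). Only the abstract members above are implemented; the class name and the 4-token prefix length are assumptions, not part of the library.

class ContinuousPrefixModel(PrefixModel):
    """Sketch: learns a 4-token soft prefix prepended to the input embeddings."""

    def __init__(self, args, loss_func, model, tokenizer, preprefix=''):
        super().__init__(args=args, loss_func=loss_func, model=model,
                         tokenizer=tokenizer, preprefix=preprefix)
        self.prefix_embedding = self.init_continuous_prefix(num_tokens=4)

    @property
    def trainable_params(self):
        return [self.prefix_embedding]  # only the soft prefix is optimized

    def embed_input_ids(self, input_ids):
        # Prepend the learned prefix embedding to the token embeddings.
        # (The internal _compute_loss_with_set_prefix caller also passes
        # next_token_ids/prefix_ids, which a full implementation would handle.)
        token_embeddings = self.token_embedding(input_ids)
        prefix = self.prefix_embedding.repeat(input_ids.shape[0], 1, 1)
        return input_ids, torch.cat((prefix, token_embeddings), dim=1)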
Ancestors
- torch.nn.modules.module.Module
- abc.ABC
Class variables
var args : argparse.Namespace
var loss_func : PrefixLoss
var model : transformers.modeling_utils.PreTrainedModel
var tokenizer : transformers.tokenization_utils.PreTrainedTokenizer
Instance variables
var id_to_word : Dict[int, str]
Mapping from token ID to vocabulary word.
var token_embedding : torch.nn.modules.sparse.Embedding
The model's input token embedding module (architecture-specific).
var token_embedding_dim : int
Dimensionality of the token embeddings (often 768, or 2560 for some larger models).
var trainable_params : Iterable[torch.nn.parameter.Parameter]
Abstract property: the parameters optimized during prefix search.
var transformer : torch.nn.modules.module.Module
The underlying transformer stack (architecture-specific).
var vocab_size : int
Number of rows in the token embedding (50,257 for GPT-2).
Methods
def check_early_stop(self) -> bool
Allow prefix models to stop early.
def compute_loss_and_call_backward(self, x_tokenized: transformers.tokenization_utils_base.BatchEncoding, y_tokenized: transformers.tokenization_utils_base.BatchEncoding, possible_answer_mask: Optional[torch.Tensor], full_text_tokenized: Optional[transformers.tokenization_utils_base.BatchEncoding] = None) -> Tuple[torch.Tensor, int]
Computes loss using self.loss_func and calls loss.backward(). Returns the loss (float torch.Tensor) and num_correct (int), the number of examples where the prediction was correct.
def compute_metrics(self) -> Dict[str, Any]
Returns a dict of evaluation metrics; the base implementation returns an empty dict.
def embed_input_ids(self, input_ids: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]
To be implemented by subclasses -- embeds input ids and includes some sort of prefix, for example, in the case of prompt-tuning, by prepending a continuous embedding.
def forward(self, input_ids: torch.Tensor, prefix_ids: Optional[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]
Runs the model on input_ids with the given prefix; must be implemented by subclasses.
def init_continuous_prefix(self, num_tokens: int) -> torch.nn.parameter.Parameter
Initializes a soft prefix of num_tokens embeddings, each set to the mean token embedding.
def init_discrete_prefix(self, num_tokens: int) -> torch.Tensor
Initializes a discrete prefix of num_tokens token IDs: random IDs under the 'random' strategy, otherwise the token ID for 'the' repeated.
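As a shape sketch of the two initializers (the embedding sizes assume a GPT-2-scale model):

import torch

emb = torch.randn(50_257, 768)                           # stand-in token embedding matrix
soft_prefix = emb.mean(dim=0, keepdim=True)[None].repeat(1, 4, 1)
print(soft_prefix.shape)                                 # torch.Size([1, 4, 768])

hard_prefix = torch.randint(low=0, high=50_257, size=(4,))
print(hard_prefix.shape)                                 # torch.Size([4])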
def post_epoch(self, dataloader: torch.utils.data.dataloader.DataLoader, possible_answer_mask: torch.Tensor) -> None
Hook called after each training epoch; a no-op by default.
def pre_epoch(self) -> None
Hook called before each training epoch; a no-op by default.
def prepare_batch(self, batch: Dict[str, str]) -> Tuple[List[str], List[str]]
Preprocesses text from batch['input'] and batch['output'] for input into the prefix model.
def serialize(self, eval_dataloader: torch.utils.data.dataloader.DataLoader, possible_answer_mask: torch.Tensor) -> Dict[str, Any]
Returns a dict of model state to write to disk after training.
class PrefixPool (tokenizer: transformers.tokenization_utils.PreTrainedTokenizer, criterion: str, topk_strategy: str = 'different_start_token', verbose: bool = False)
Tracks a pool of candidate prefixes and their associated metrics over time.
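A hedged usage sketch (the tokenizer, token IDs, and numbers are illustrative):

import torch
import transformers

tokenizer = transformers.AutoTokenizer.from_pretrained('gpt2')
pool = PrefixPool(tokenizer=tokenizer, criterion='loss')

prefix = torch.tensor([464, 749])              # some candidate prefix token IDs
pool.initialize_prefix(prefix)
pool.update(prefix, loss=torch.tensor(2.31), accuracy=torch.tensor(0.5))
pool.update(prefix, loss=torch.tensor(2.10), accuracy=torch.tensor(0.75))

best = pool.topk(k=1, min_occurrences=1)       # [(464, 749)], since it has >1 updates
print(len(pool), best)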
Class variables
var criterion : str
var tokenizer : transformers.tokenization_utils.PreTrainedTokenizer
var verbose : bool
Instance variables
var num_start_tokens : int
Number of different start tokens seen across all prefixes.
var prefixes : List[Tuple[int]]
All prefixes (as tuples of token IDs) tracked by the pool.
Methods
def initialize_prefix(self, prefix: torch.Tensor)
Registers a new prefix with sentinel loss/accuracy values.
def print(self, topk: int, min_occurrences: int = 2) -> Optional[pandas.core.frame.DataFrame]
Prints the top prefixes (if verbose) and returns them as a DataFrame with loss and accuracy columns.
def topk(self, *args, **kwargs) -> List[Tuple[int]]
Returns the top k prefixes according to the configured topk_strategy.
def topk_all(self, k: int, min_occurrences: Optional[int] = None) -> List[Tuple[int]]
Returns the top k prefixes across the whole pool.
def topk_with_different_start_token(self, k: int, min_occurrences: Optional[int] = None) -> List[Tuple[int]]
Returns the top k prefixes, keeping at most one per distinct start token and padding with repeats if there are not yet k distinct start tokens.
def update(self, prefix: torch.Tensor, loss: torch.Tensor, accuracy: torch.Tensor)
Records a new loss/accuracy observation for prefix and updates its running averages and the per-start-token best score.