Module imodelsx.iprompt.autoprompt
Classes
class AutoPrompt (args: argparse.Namespace,
loss_func: PrefixLoss,
model: transformers.modeling_utils.PreTrainedModel,
tokenizer: transformers.tokenization_utils.PreTrainedTokenizer,
preprefix: str = '')
class AutoPrompt(HotFlip):
    args: argparse.Namespace
    loss_func: PrefixLoss
    model: transformers.PreTrainedModel
    tokenizer: transformers.PreTrainedTokenizer
    prefix_ids: torch.Tensor
    prefix_embedding: torch.nn.Parameter
    preprefix: str

    def __init__(
        self,
        args: argparse.Namespace,
        loss_func: PrefixLoss,
        model: transformers.PreTrainedModel,
        tokenizer: transformers.PreTrainedTokenizer,
        preprefix: str = ''
    ):
        super().__init__(
            args=args, loss_func=loss_func, model=model,
            tokenizer=tokenizer, preprefix=preprefix
        )
        self._do_final_reranking = args.iprompt_do_final_reranking

        # AutoPrompt-specific parameters.
        self._num_candidates_per_prefix_token = 32  # V_cand in autoprompt paper

        # This helps us know which were the best prefixes to return over time
        self._prefix_pool = PrefixPool(
            tokenizer=self.tokenizer,
            criterion='loss'  # in ['loss', 'acc', 'combined']
        )
        self._autoprompt_verbose = True
        self._num_min_occurrences = 1
        # Will rank and save this many prefixes at the end of training.
        self._num_prefixes_to_test = 64

    def test_prefixes(
            self,
            prefixes: List[Tuple[int]],
            eval_dataloader: torch.utils.data.DataLoader,
            possible_answer_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Computes loss & accuracy for each prefix on data in dataloader.

        Used to rank prefixes at the end of training.
        """
        all_candidate_losses = torch.zeros(len(prefixes), dtype=torch.float32)
        all_candidate_n_correct = torch.zeros(
            len(prefixes), dtype=torch.float32)
        total_n = 0
        for batch in tqdm.tqdm(eval_dataloader, desc=f'evaluating {len(prefixes)} prefixes'):
            if (self.args.n_shots > 1) and (self.args.single_shot_loss):
                batch['input'] = batch['last_input']
            x_text, y_text = self.prepare_batch(batch=batch)
            tok = functools.partial(
                self.tokenizer, return_tensors='pt', padding='longest',
                truncation=True, max_length=self.args.max_length  # TODO set max_length on self
            )
            x_tokenized = tok(x_text).to(device)
            y_tokenized = tok(y_text).to(device)
            total_n += len(x_tokenized.input_ids)
            next_token_ids = y_tokenized.input_ids

            for i in range(len(prefixes)):
                with torch.no_grad():
                    _cand_input_ids, cand_loss, cand_n_correct = (
                        self._compute_loss_with_set_prefix(
                            original_input_ids=x_tokenized.input_ids,
                            next_token_ids=next_token_ids,
                            possible_answer_mask=possible_answer_mask,
                            prefix_ids=torch.tensor(prefixes[i]).to(device),
                        )
                    )
                all_candidate_losses[i] += cand_loss.item()
                all_candidate_n_correct[i] += cand_n_correct.item()
        return all_candidate_losses.cpu().tolist(), (all_candidate_n_correct / total_n).cpu().tolist()

    def serialize(self, eval_dataloader: torch.utils.data.DataLoader,
                  possible_answer_mask: torch.Tensor) -> Dict[str, Any]:
        """Writes stuff to disk. Saves other stuff to save as full results file."""
        # Uncomment following lines to save all the prefixes we tested.
        # save_dir = self.args.save_dir_unique
        # os.makedirs(save_dir, exist_ok=True)
        # pickle.dump(self._prefix_pool, open(os.path.join(save_dir, 'prefix_pool.p'), 'wb'))

        all_prefixes = self._prefix_pool.topk_all(
            k=self._num_prefixes_to_test, min_occurrences=3)
        if not len(all_prefixes):
            # In the case where we get no prefixes here (i.e. prompt generation
            # only ran for a single step) just take anything from prefix pool.
            all_prefixes = random.choices(
                list(self._prefix_pool.prefixes), k=self._num_prefixes_to_test)

        if self._do_final_reranking:
            all_losses, all_accuracies = self.test_prefixes(
                prefixes=all_prefixes,
                eval_dataloader=eval_dataloader,
                possible_answer_mask=possible_answer_mask
            )
            df = pd.DataFrame(
                zip(*[all_prefixes, all_losses, all_accuracies]),
                columns=['prefix', 'loss', 'accuracy']
            )
            df = df.sort_values(by=['accuracy', 'loss'],
                                ascending=[False, True]).reset_index()
        else:
            all_prefixes = list(self._prefix_pool.prefixes)
            all_losses = [self._prefix_pool._avg_loss.get(p, -1)
                          for p in all_prefixes]
            all_accuracies = [self._prefix_pool._avg_accuracy.get(p, -1)
                              for p in all_prefixes]
            df = pd.DataFrame(
                zip(*[all_prefixes, all_losses, all_accuracies]),
                columns=['prefix', 'loss', 'accuracy']
            )
            df = df.sort_values(by='accuracy', ascending=False).reset_index()

        df['prefix_str'] = df['prefix'].map(self.tokenizer.decode)
        df['n_queries'] = df['prefix'].map(
            lambda p_ids: len(self._prefix_pool._all_losses[p_ids]))

        print('Final prefixes')
        print(df.head())

        return {
            "prefix_ids": df['prefix'].tolist(),
            "prefixes": df['prefix_str'].tolist(),
            "prefix_train_acc": df['accuracy'].tolist(),
            "prefix_train_loss": df['loss'].tolist(),
            "prefix_n_queries": df['n_queries'].tolist(),
        }

    def compute_loss_and_call_backward(
            self,
            x_tokenized: transformers.BatchEncoding,
            y_tokenized: transformers.BatchEncoding,
            possible_answer_mask: torch.Tensor,
            full_text_tokenized: Optional[transformers.BatchEncoding] = None
    ) -> Tuple[torch.Tensor, int]:
        """Computes loss using `self.loss_func`.

        Returns:
            loss (float torch.Tensor) -- the loss
            num_correct (int): number of examples where prediction was correct
        """
        original_input_ids = x_tokenized.input_ids
        next_token_ids = y_tokenized.input_ids  # only compute loss over next token

        current_input_ids, current_loss, current_n_correct = self._compute_loss_with_set_prefix(
            original_input_ids=original_input_ids,
            next_token_ids=next_token_ids,
            possible_answer_mask=possible_answer_mask,
            prefix_ids=None,
        )
        current_loss.backward()

        if self._autoprompt_verbose:
            print(f'** {self.tokenizer.decode(self.prefix_ids)}: {current_loss:.2f}')

        # track running accuracy of this prefix.
        self._prefix_pool.update(
            prefix=self.prefix_ids,
            loss=current_loss,
            accuracy=(current_n_correct / len(original_input_ids))
        )
        # print an update.
        self._prefix_pool.print(topk=10, min_occurrences=1)

        #
        # Get top token replacements
        #
        token_grads = self._prefix_token_grad
        if self._is_t5:
            # t5 has extra vocab tokens for no reason:
            # https://github.com/huggingface/transformers/issues/4875#issuecomment-647634437
            assert token_grads.shape == (
                self._num_tokens, len(self.tokenizer.vocab) + 28)
            token_grads = token_grads[:, :-28]
        assert token_grads.shape == (
            self._num_tokens, len(self.tokenizer.vocab))
        top_tokens_per_position = (
            token_grads.topk(
                k=self._num_candidates_per_prefix_token, dim=1, largest=False).indices
        )
        assert top_tokens_per_position.shape == (
            self._num_tokens, self._num_candidates_per_prefix_token)

        top_swap_tokens = top_tokens_per_position[self._swap_token_idx, :]
        #
        # Get most likely tokens.
        # top_swap_tokens = token_grads.argsort(descending=False).flatten()
        top_swap_tokens = top_swap_tokens[0: self._num_candidates_per_prefix_token]

        # rank candidates
        mask = torch.nn.functional.one_hot(
            torch.tensor(self._swap_token_idx), num_classes=self._num_tokens
        ).bool().to(device)
        candidate_prefix_ids = torch.where(
            mask, top_swap_tokens[:, None], self.prefix_ids[None, :])
        is_current_prefix_mask = (
            candidate_prefix_ids == self.prefix_ids).all(dim=1)
        candidate_prefix_ids = candidate_prefix_ids[~is_current_prefix_mask]

        # get best prefix
        num_candidates = len(candidate_prefix_ids)
        all_candidate_losses = torch.zeros(
            num_candidates, dtype=float).to(device)
        all_n_correct = torch.zeros(num_candidates, dtype=int).to(device)
        for i in range(num_candidates):
            with torch.no_grad():
                cand_input_ids, cand_loss, cand_n_correct = (
                    self._compute_loss_with_set_prefix(
                        original_input_ids=original_input_ids,
                        next_token_ids=next_token_ids,
                        possible_answer_mask=possible_answer_mask,
                        prefix_ids=candidate_prefix_ids[i],
                    )
                )
            all_candidate_losses[i] = cand_loss
            all_n_correct[i] = cand_n_correct
            # if self._autoprompt_verbose:
            #     print(f'** \t{self.tokenizer.decode(candidate_prefix_ids[i])}: {cand_loss:.2f}')
            self._prefix_pool.update(
                prefix=candidate_prefix_ids[i],
                loss=cand_loss,
                accuracy=(cand_n_correct / len(original_input_ids))
            )

        # randomly change the token to swap
        self._swap_token_idx = random.randint(0, (self._num_tokens - 1))

        # get best prefix we've seen
        if all_candidate_losses.min() < current_loss:
            best_prefix = candidate_prefix_ids[all_candidate_losses.argmin()]
            best_prefix_loss = all_candidate_losses.min()
            best_prefix_n_correct = all_n_correct[all_candidate_losses.argmin()]
            if self._autoprompt_verbose:
                print("** set new prefix", best_prefix)
        else:
            best_prefix = self.prefix_ids
            best_prefix_loss = current_loss
            best_prefix_n_correct = current_n_correct
            if self._autoprompt_verbose:
                print("** set same prefix", best_prefix)

        self._set_prefix_ids(best_prefix)
        return best_prefix_loss, best_prefix_n_correct

    def post_epoch(self, dataloader: torch.utils.data.DataLoader,
                   possible_answer_mask: torch.Tensor) -> None:
        #
        # Get candidate IDs for every position.
        #
        pass
Base class for all neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes:
import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))
Submodules assigned in this way will be registered, and will also have their parameters converted when you call to(), etc.
Note: As per the example above, an __init__() call to the parent class must be made before assignment on the child.
training (bool): whether this module is in training or evaluation mode.
Initialize internal Module state, shared by both nn.Module and ScriptModule.
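Returning to AutoPrompt itself: the AutoPrompt-specific behavior lives in compute_loss_and_call_backward above. The gradient of the loss with respect to the prefix token embeddings is used to propose the lowest-gradient replacement tokens for one randomly chosen prefix position, each resulting candidate prefix is scored under torch.no_grad(), and the best-scoring prefix is kept via _set_prefix_ids. The snippet below is a minimal, self-contained sketch of just the candidate-construction step; the one_hot/where construction and the topk(largest=False) call mirror the source above, but the toy sizes, prefix ids, and the random gradient tensor are made up for illustration.

import torch

num_tokens = 3       # prefix length (toy value)
vocab_size = 10      # toy vocabulary size
num_candidates = 4   # V_cand; the class above uses 32
swap_token_idx = 1   # prefix position whose token gets replaced this step

prefix_ids = torch.tensor([5, 2, 7])               # toy current prefix
token_grads = torch.randn(num_tokens, vocab_size)  # stand-in for self._prefix_token_grad

# Smallest gradient values first (largest=False), matching the source above.
top_tokens_per_position = token_grads.topk(
    k=num_candidates, dim=1, largest=False).indices
top_swap_tokens = top_tokens_per_position[swap_token_idx, :]   # shape (num_candidates,)

# Build one candidate prefix per proposed token: only the swap position changes.
mask = torch.nn.functional.one_hot(
    torch.tensor(swap_token_idx), num_classes=num_tokens).bool()
candidate_prefix_ids = torch.where(
    mask, top_swap_tokens[:, None], prefix_ids[None, :])       # (num_candidates, num_tokens)

# Drop any candidate identical to the current prefix, as the source does.
is_current_prefix_mask = (candidate_prefix_ids == prefix_ids).all(dim=1)
candidate_prefix_ids = candidate_prefix_ids[~is_current_prefix_mask]
print(candidate_prefix_ids)

In the class above, each row of candidate_prefix_ids is then scored with _compute_loss_with_set_prefix, and the lowest-loss row replaces the current prefix if it beats the current loss.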
Ancestors
- HotFlip
- PrefixModel
- torch.nn.modules.module.Module
- abc.ABC
Subclasses
Class variables
var args : argparse.Namespace
var loss_func : PrefixLoss
var model : transformers.modeling_utils.PreTrainedModel
var prefix_embedding : torch.nn.parameter.Parameter
var prefix_ids : torch.Tensor
var preprefix : str
var tokenizer : transformers.tokenization_utils.PreTrainedTokenizer
Methods
def post_epoch(self,
dataloader: torch.utils.data.dataloader.DataLoader,
possible_answer_mask: torch.Tensor) -> None
def post_epoch(self, dataloader: torch.utils.data.DataLoader,
               possible_answer_mask: torch.Tensor) -> None:
    #
    # Get candidate IDs for every position.
    #
    pass
def serialize(self,
eval_dataloader: torch.utils.data.dataloader.DataLoader,
possible_answer_mask: torch.Tensor) -> Dict[str, Any]
def serialize(self, eval_dataloader: torch.utils.data.DataLoader,
              possible_answer_mask: torch.Tensor) -> Dict[str, Any]:
    """Writes stuff to disk. Saves other stuff to save as full results file."""
    # Uncomment following lines to save all the prefixes we tested.
    # save_dir = self.args.save_dir_unique
    # os.makedirs(save_dir, exist_ok=True)
    # pickle.dump(self._prefix_pool, open(os.path.join(save_dir, 'prefix_pool.p'), 'wb'))

    all_prefixes = self._prefix_pool.topk_all(
        k=self._num_prefixes_to_test, min_occurrences=3)
    if not len(all_prefixes):
        # In the case where we get no prefixes here (i.e. prompt generation
        # only ran for a single step) just take anything from prefix pool.
        all_prefixes = random.choices(
            list(self._prefix_pool.prefixes), k=self._num_prefixes_to_test)

    if self._do_final_reranking:
        all_losses, all_accuracies = self.test_prefixes(
            prefixes=all_prefixes,
            eval_dataloader=eval_dataloader,
            possible_answer_mask=possible_answer_mask
        )
        df = pd.DataFrame(
            zip(*[all_prefixes, all_losses, all_accuracies]),
            columns=['prefix', 'loss', 'accuracy']
        )
        df = df.sort_values(by=['accuracy', 'loss'],
                            ascending=[False, True]).reset_index()
    else:
        all_prefixes = list(self._prefix_pool.prefixes)
        all_losses = [self._prefix_pool._avg_loss.get(p, -1)
                      for p in all_prefixes]
        all_accuracies = [self._prefix_pool._avg_accuracy.get(p, -1)
                          for p in all_prefixes]
        df = pd.DataFrame(
            zip(*[all_prefixes, all_losses, all_accuracies]),
            columns=['prefix', 'loss', 'accuracy']
        )
        df = df.sort_values(by='accuracy', ascending=False).reset_index()

    df['prefix_str'] = df['prefix'].map(self.tokenizer.decode)
    df['n_queries'] = df['prefix'].map(
        lambda p_ids: len(self._prefix_pool._all_losses[p_ids]))

    print('Final prefixes')
    print(df.head())

    return {
        "prefix_ids": df['prefix'].tolist(),
        "prefixes": df['prefix_str'].tolist(),
        "prefix_train_acc": df['accuracy'].tolist(),
        "prefix_train_loss": df['loss'].tolist(),
        "prefix_n_queries": df['n_queries'].tolist(),
    }
Writes stuff to disk. Saves other stuff to save as full results file.
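A hedged sketch of how the returned dictionary might be consumed after training. Here trainer, eval_dataloader, and possible_answer_mask are hypothetical names standing in for an already-trained AutoPrompt instance and the objects built by the surrounding training loop; they are not defined by this module.

# Hypothetical usage; `trainer`, `eval_dataloader`, and `possible_answer_mask`
# are assumed to already exist in the surrounding training script.
results = trainer.serialize(
    eval_dataloader=eval_dataloader,
    possible_answer_mask=possible_answer_mask,
)

# Keys come from the return statement above; rows are sorted best-first
# when args.iprompt_do_final_reranking is enabled.
for prefix_str, acc, loss, n in zip(results["prefixes"],
                                    results["prefix_train_acc"],
                                    results["prefix_train_loss"],
                                    results["prefix_n_queries"]):
    print(f"acc={acc:.3f} loss={loss:.3f} n_queries={n} prefix={prefix_str!r}")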
def test_prefixes(self,
prefixes: List[Tuple[int]],
eval_dataloader: torch.utils.data.dataloader.DataLoader,
possible_answer_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]
def test_prefixes(
        self,
        prefixes: List[Tuple[int]],
        eval_dataloader: torch.utils.data.DataLoader,
        possible_answer_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Computes loss & accuracy for each prefix on data in dataloader.

    Used to rank prefixes at the end of training.
    """
    all_candidate_losses = torch.zeros(len(prefixes), dtype=torch.float32)
    all_candidate_n_correct = torch.zeros(
        len(prefixes), dtype=torch.float32)
    total_n = 0
    for batch in tqdm.tqdm(eval_dataloader, desc=f'evaluating {len(prefixes)} prefixes'):
        if (self.args.n_shots > 1) and (self.args.single_shot_loss):
            batch['input'] = batch['last_input']
        x_text, y_text = self.prepare_batch(batch=batch)
        tok = functools.partial(
            self.tokenizer, return_tensors='pt', padding='longest',
            truncation=True, max_length=self.args.max_length  # TODO set max_length on self
        )
        x_tokenized = tok(x_text).to(device)
        y_tokenized = tok(y_text).to(device)
        total_n += len(x_tokenized.input_ids)
        next_token_ids = y_tokenized.input_ids

        for i in range(len(prefixes)):
            with torch.no_grad():
                _cand_input_ids, cand_loss, cand_n_correct = (
                    self._compute_loss_with_set_prefix(
                        original_input_ids=x_tokenized.input_ids,
                        next_token_ids=next_token_ids,
                        possible_answer_mask=possible_answer_mask,
                        prefix_ids=torch.tensor(prefixes[i]).to(device),
                    )
                )
            all_candidate_losses[i] += cand_loss.item()
            all_candidate_n_correct[i] += cand_n_correct.item()
    return all_candidate_losses.cpu().tolist(), (all_candidate_n_correct / total_n).cpu().tolist()
Computes loss & accuracy for each prefix on data in dataloader. Used to rank prefixes at the end of training.
Inherited members