Module imodelsx.iprompt.ipromptx
Classes
class iPrompt (loss_func: PrefixLoss,
model: transformers.modeling_utils.PreTrainedModel,
tokenizer: transformers.tokenization_utils.PreTrainedTokenizer,
preprefix_str: str = '',
prefix_before_input: bool = True,
pop_criterion: str = 'loss',
pop_topk_strategy: str = 'different_start_token',
pop_size: int = 8,
num_mutations: int = 4,
num_random_generations: int = 4,
generation_repetition_penalty: float = 2.0,
generation_temp: float = 1.0,
generation_top_p: float = 1.0,
do_final_reranking: bool = False,
early_stopping_steps: int = -1,
num_learned_tokens: int = 1,
max_length: int = 128,
verbose: int = 0,
llm_float16: bool = True,
generation_checkpoint: str = '',
n_shots: int = 1,
single_shot_loss: bool = True,
llm_candidate_regeneration_prompt_start: str = 'Data:',
llm_candidate_regeneration_prompt_end: str = 'Prompt:')
class iPrompt(AutoPrompt):
    def __init__(
        self,
        loss_func: PrefixLoss,
        model: transformers.PreTrainedModel,
        tokenizer: transformers.PreTrainedTokenizer,
        preprefix_str: str = '',
        prefix_before_input: bool = True,
        pop_criterion: str = 'loss',
        pop_topk_strategy: str = 'different_start_token',
        pop_size: int = 8,
        num_mutations: int = 4,
        num_random_generations: int = 4,
        generation_repetition_penalty: float = 2.0,
        generation_temp: float = 1.0,
        generation_top_p: float = 1.0,
        do_final_reranking: bool = False,
        early_stopping_steps: int = -1,
        num_learned_tokens: int = 1,
        max_length: int = 128,
        verbose: int = 0,
        llm_float16: bool = True,
        generation_checkpoint: str = '',
        n_shots: int = 1,
        single_shot_loss: bool = True,
        llm_candidate_regeneration_prompt_start: str = 'Data:',
        llm_candidate_regeneration_prompt_end: str = 'Prompt:',
    ):
        args = argparse.Namespace()
        args.prefix_before_input = prefix_before_input
        args.num_learned_tokens = num_learned_tokens
        args.hotflip_num_candidates = None
        args.autoprompt_init_strategy = None
        args.save_dir_unique = '.'
        args.n_shots = n_shots
        args.single_shot_loss = single_shot_loss
        args.max_length = max_length
        args.iprompt_do_final_reranking = do_final_reranking
        super().__init__(
            args=args, loss_func=loss_func, model=model, tokenizer=tokenizer, preprefix=''
        )
        self.tokenizer = tokenizer
        self.tokenizer.add_special_tokens = False

        ####################################################################
        # iPrompt-specific parameters
        self._pop_size = pop_size
        self._topk_pop_sample = (self._pop_size + 4)  # sample next population from this num of top things. set higher for more randomness.
        self._num_mutations_per_ex = num_mutations  # num mutations for each population item
        self._num_random_generations = num_random_generations  # extra random examples to throw in there (won't get mutated)
        self._generation_temp = generation_temp
        self._generation_top_p = generation_top_p
        self._generation_repetition_penalty = generation_repetition_penalty  # 1 means no penalty
        self._pop_initialized = False
        self._generation_bad_words_ids = [
            self.tokenizer.encode('\n'),
            self.tokenizer.encode('\n\n'),
            self.tokenizer.encode('\n\n\n')
        ]

        ####################################################################
        self.conditioning_strategy = ''  # This arg is only used for ablations.
        self.other_generation_model = None
        if generation_checkpoint:
            self.other_generation_model = load_lm_from_checkpoint(
                generation_checkpoint, float16=llm_float16
            )

        ####################################################################
        self._prefix_pool = PrefixPool(
            tokenizer=self.tokenizer,
            criterion=pop_criterion,  # in ['loss', 'acc', 'combined']
            topk_strategy=pop_topk_strategy,
            verbose=verbose,
        )
        # Stuff to track for early stopping
        self._early_stopping_steps = early_stopping_steps
        self._last_population = None
        self._steps_since_new_population = 0

        ####################################################################
        self.prefix_ids = None
        if len(preprefix_str):
            self.preprefix_ids = torch.tensor(
                self.tokenizer.encode(preprefix_str, add_special_tokens=False), dtype=int
            ).to(device)
        else:
            self.preprefix_ids = torch.tensor([], dtype=int).to(device)
        prompt_str = preprefix_str.lstrip()
        prompt_str = (' ' + prompt_str) if len(prompt_str) else ''
        self._pre_data_token_ids = self.tokenizer(
            f"{llm_candidate_regeneration_prompt_start}\n\n",
            return_tensors='pt').input_ids.to(device)
        self._post_data_token_ids = self.tokenizer(
            f"\n\n{llm_candidate_regeneration_prompt_end}" + prompt_str,
            return_tensors='pt').input_ids.to(device)
        self.llm_candidate_regeneration_prompt_start = llm_candidate_regeneration_prompt_start
        self.llm_candidate_regeneration_prompt_end = llm_candidate_regeneration_prompt_end

        ####################################################################
        self._iprompt_verbose = verbose
        self._step = 0

    def serialize(self, eval_dataloader: torch.utils.data.DataLoader,
                  possible_answer_mask: torch.Tensor) -> Dict[str, Any]:
        r = super().serialize(eval_dataloader=eval_dataloader, possible_answer_mask=possible_answer_mask)
        r["topk_pop_sample"] = self._topk_pop_sample
        r["pop_size"] = self._pop_size
        r["num_mutations_per_ex"] = self._num_mutations_per_ex
        r["num_random_generations"] = self._num_random_generations
        r["generation_temp"] = self._generation_temp
        r["generation_top_p"] = self._generation_top_p
        r["generation_repetition_penalty"] = self._generation_repetition_penalty
        r["generation_bad_words_ids"] = self._generation_bad_words_ids
        r["pre_data_prompt_str"] = self.tokenizer.decode(self._pre_data_token_ids.flatten())
        r["post_data_prompt_str"] = self.tokenizer.decode(self._post_data_token_ids.flatten())
        return r

    def _initialize_pop_once(self, full_text_ids: torch.Tensor):
        if self._pop_initialized:
            return
        while len(self._prefix_pool) < self._pop_size:
            conditional_input_ids = random.choice(full_text_ids)[None]
            num_conditional_tokens = conditional_input_ids.numel()
            input_ids = self._generate(
                input_ids=conditional_input_ids,
                num_conditional_tokens=num_conditional_tokens
            ).squeeze()
            assert input_ids.numel() == self._num_tokens
            self._prefix_pool.initialize_prefix(input_ids)
        self._pop_initialized = True

    @property
    def _generation_model(self) -> transformers.AutoModelForCausalLM:
        """Returns the model to use for generation.

        We optionally support using different models for generation and
        discrimination. However, by default, we use the same model for both.
        """
        if self.other_generation_model:
            return self.other_generation_model
        else:
            return self.model

    def _generate(self, input_ids: torch.Tensor, num_conditional_tokens: int) -> torch.Tensor:
        """Generates some text using the model and preset hparams.

        If `num_conditional_tokens` > 0, generates extra text because there
        was an additional prefix set.
        """
        attention_mask = ~(input_ids == self.tokenizer.pad_token_id)
        assert attention_mask.shape == input_ids.shape

        if self._is_t5:
            output_length = self._num_tokens + 1  # will add pad token
        else:
            output_length = self._num_tokens + num_conditional_tokens
        # print("iPrompt._generate", input_ids.shape, "//", self.tokenizer.decode(input_ids[0]))
        g = self._generation_model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            min_length=output_length,
            max_length=output_length,
            temperature=self._generation_temp,
            top_p=self._generation_top_p,
            repetition_penalty=self._generation_repetition_penalty,
            bad_words_ids=self._generation_bad_words_ids,
            do_sample=True
        )

        if self._is_t5:
            assert (g[:, 0] == 0).all()
            g = g[:, 1:]
        else:
            # Split off the conditional part, we only want the prefix part, which
            # starts after the conditional part.
            g = g[:, num_conditional_tokens:]

        if self._iprompt_verbose:
            # Print a random one (but remove padded tokens and newlines)
            idx = random.choice(range(len(input_ids)))
            # idx_attention_mask = torch.cat(
            #     (attention_mask[idx], torch.ones(self._num_tokens).to(device)), dim=0
            # ).bool()
            random_sentence_ids = g[idx]
            # print(">>", self.tokenizer.decode(random_sentence_ids).replace('\n', '\\n'))
        return g

    def _select_pop_topk(self, k: int, min_occurrences: int = None) -> List[Tuple[int]]:
        return self._prefix_pool.topk(k=k, min_occurrences=min_occurrences)

    def _track_early_stopping(self):
        """Track changes in population to tell when to stop early."""
        __n_early_stop = 5
        population = set(self._select_pop_topk(k=__n_early_stop, min_occurrences=3))
        if (len(population) == __n_early_stop) and (self._last_population == population):
            self._steps_since_new_population += 1
            if self._iprompt_verbose:
                print("self._steps_since_new_population:", self._steps_since_new_population)
        else:
            self._last_population = population
            self._steps_since_new_population = 0
            if self._iprompt_verbose:
                print("new population:", [self.tokenizer.decode(p) for p in sorted(population)])

    def check_early_stop(self) -> bool:
        """Allow prefix models to stop early."""
        if self._early_stopping_steps == -1:
            return False
        return self._steps_since_new_population >= self._early_stopping_steps

    def _get_population_and_random_generations(self, full_text_ids: torch.Tensor) -> torch.Tensor:
        population_pool = self._select_pop_topk(k=self._topk_pop_sample)
        # if self._iprompt_verbose:
        #     print("population_pool:", [self.tokenizer.decode(p) for p in population_pool])
        population = random.sample(population_pool, self._pop_size)
        population = torch.tensor(population).to(device)

        if self._num_random_generations > 0:
            random_idxs = torch.randint(
                low=0, high=len(full_text_ids), size=(self._num_random_generations,)
            )
            random_full_text_ids = full_text_ids[random_idxs]
            num_conditional_tokens = full_text_ids.shape[1]
            random_population = self._generate(
                input_ids=random_full_text_ids,
                num_conditional_tokens=num_conditional_tokens
            )
            full_population = torch.cat((population, random_population), dim=0)
        else:
            # Support case where _num_random_generations is set to 0.
            full_population = population

        assert full_population.shape == (
            self._pop_size + self._num_random_generations,
            self._num_tokens
        )
        return full_population

    def _mutate(self, population_input_ids: torch.Tensor, full_text_ids: torch.Tensor) -> List[torch.Tensor]:
        """Mutates a population of prefixes.

        Truncates to a random place and then generates new options to try.

        Args:
            population_input_ids (int torch.Tensor): input IDs for each prefix in population
            full_text_ids (int torch.Tensor): input IDs for each data item in the batch.
                Intended to be used to do prefix generation conditioned on data
        """
        assert population_input_ids.shape[1] == self._num_tokens
        input_ids = population_input_ids.repeat((self._num_mutations_per_ex, 1))

        self._roll_before_truncation = False
        if self._roll_before_truncation:
            roll_amount = random.randint(0, self._num_tokens - 1)
            input_ids = torch.roll(input_ids, roll_amount, dims=[1])

        truncate_position = random.randint(0, self._num_tokens - 1)
        truncated_input_ids = input_ids[:, :truncate_position]

        random_idxs = torch.randint(low=0, high=len(full_text_ids), size=(len(input_ids),))
        random_full_text_ids = full_text_ids[random_idxs]
        conditional_input_ids = torch.cat((random_full_text_ids, truncated_input_ids), dim=1)
        num_conditional_tokens = full_text_ids.shape[1]
        new_input_ids = self._generate(
            input_ids=conditional_input_ids,
            num_conditional_tokens=num_conditional_tokens
        )
        return new_input_ids

    def _score_population(
        self,
        x_tokenized: transformers.BatchEncoding,
        y_tokenized: transformers.BatchEncoding,
        population_input_ids: torch.Tensor,
        possible_answer_mask: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Scores a population of prefixes and updates `self._prefix_pool`."""
        pop_size = len(population_input_ids)
        all_candidate_losses = torch.zeros(pop_size, dtype=float).to(device)
        all_accuracy = torch.zeros(pop_size, dtype=float).to(device)
        all_candidate_n_correct = torch.zeros(pop_size, dtype=int).to(device)
        for i in range(pop_size):
            with torch.no_grad():
                _cand_input_ids, cand_loss, cand_n_correct = (
                    self._compute_loss_with_set_prefix(
                        original_input_ids=x_tokenized.input_ids,
                        next_token_ids=y_tokenized.input_ids,
                        possible_answer_mask=possible_answer_mask,
                        prefix_ids=population_input_ids[i],
                    )
                )
            cand_accuracy = cand_n_correct / len(x_tokenized.input_ids)
            all_candidate_n_correct[i] += cand_n_correct
            all_candidate_losses[i] += cand_loss
            all_accuracy[i] += cand_accuracy

        for i in range(pop_size):
            new_pop_input_ids = tuple(population_input_ids[i].cpu().tolist())
            assert len(new_pop_input_ids) == self._num_tokens
            self._prefix_pool.update(
                population_input_ids[i], all_candidate_losses[i], all_accuracy[i]
            )
        return all_candidate_losses, all_candidate_n_correct

    def _create_full_text_ids(
            self, full_text_input_ids: torch.Tensor) -> torch.Tensor:
        """Creates input for generating explanation.

        Takes tokenized inputs (like: "Input: 7 8 Output: 15") and makes a full
        string that looks like "Data:\n\n Input: .... 15 \n\nExplanation:\n\n",
        using whatever template is defined by pre-data and post-data.
        """
        B = len(full_text_input_ids)
        pre_data = self._pre_data_token_ids.repeat((B, 1)).to(device)    # Like "Data:\n\n"
        post_data = self._post_data_token_ids.repeat((B, 1)).to(device)  # Like "\n\nPrompt:"
        output = torch.cat((pre_data, full_text_input_ids, post_data), dim=1)
        return output

    def compute_loss_and_call_backward(
        self,
        x_tokenized: transformers.BatchEncoding,
        y_tokenized: transformers.BatchEncoding,
        possible_answer_mask: torch.Tensor,
        full_text_tokenized: Optional[transformers.BatchEncoding] = None
    ) -> Tuple[torch.Tensor, int]:
        """Returns the loss from the best example in the population

        Note: does not call loss.backward()

        Returns:
            loss (float torch.Tensor): the loss
            num_correct (int): number of examples where prediction was correct
        """
        self.model.eval()

        # allow for conditioning only on x or y. This is mainly just used for ablations.
        if self.conditioning_strategy == "x_only":
            full_text_tokenized = y_tokenized
        elif self.conditioning_strategy == "y_only":
            full_text_tokenized = y_tokenized
        elif self.conditioning_strategy == "unconditional":
            full_text_tokenized['input_ids'] = torch.full(
                size=(len(y_tokenized), 1),
                fill_value=self.tokenizer.bos_token_id,
                device=device,
            )
            full_text_tokenized['attention_mask'] = torch.ones_like(
                full_text_tokenized['input_ids']
            )

        # logic here is that we want to see a sample multiple times before
        # we actually have a good estimate of its loss.
        num_min_occurrences = 2

        full_text_ids = self._create_full_text_ids(
            full_text_input_ids=full_text_tokenized.input_ids,
        )
        self._initialize_pop_once(full_text_ids=full_text_ids)

        prefix_save_folder = os.path.join(self.args.save_dir_unique, 'prefix')
        df_to_print = self._prefix_pool.print(topk=10, min_occurrences=num_min_occurrences)
        os.makedirs(prefix_save_folder, exist_ok=True)

        log_prefixes = False
        if log_prefixes:
            prefix_out_file = os.path.join(prefix_save_folder, f'prefix_{self._step}.p')
            df_to_print.to_pickle(prefix_out_file)
            print(f'wrote {len(df_to_print)} prefixes to {prefix_out_file}')

        # Grab new population
        population_input_ids = self._get_population_and_random_generations(
            full_text_ids=full_text_ids,
        )

        if self._num_mutations_per_ex > 0:
            mutated_population_input_ids = self._mutate(
                population_input_ids=population_input_ids,
                full_text_ids=full_text_ids
            )
            full_population_input_ids = torch.cat(
                (population_input_ids, mutated_population_input_ids), dim=0
            )
        else:
            # Support skipping mutation step by setting _num_mutations_per_ex to 0
            full_population_input_ids = population_input_ids

        # Re-score new candidates
        all_candidate_losses, all_candidate_n_correct = self._score_population(
            x_tokenized=x_tokenized,
            y_tokenized=y_tokenized,
            population_input_ids=full_population_input_ids,
            possible_answer_mask=possible_answer_mask
        )

        # Track changes in population to enable early stopping.
        self._track_early_stopping()

        # Reset prefix IDs so that the model can be readily used for eval.
        best_prefix_ids = min(self._prefix_pool._avg_loss, key=self._prefix_pool._avg_loss.get)
        self._set_prefix_ids(torch.tensor(best_prefix_ids).to(device))
        self.prefix_embedding.requires_grad = False

        self._step += 1
        return all_candidate_losses.min(), all_candidate_n_correct.max()

    def post_epoch(self, dataloader: torch.utils.data.DataLoader, possible_answer_mask: torch.Tensor) -> None:
        # Get candidate IDs for every position.
        pass
Base class for all neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes:

    import torch.nn as nn
    import torch.nn.functional as F

    class Model(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.conv1 = nn.Conv2d(1, 20, 5)
            self.conv2 = nn.Conv2d(20, 20, 5)

        def forward(self, x):
            x = F.relu(self.conv1(x))
            return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will also have their parameters converted when you call :meth:`to`, etc.

Note: As per the example above, an `__init__()` call to the parent class must be made before assignment on the child.

:ivar training: Boolean represents whether this module is in training or evaluation mode.
:vartype training: bool
Initialize internal Module state, shared by both nn.Module and ScriptModule.
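A rough usage sketch follows. The checkpoint name, the PrefixLoss import path and arguments, and the hyperparameter values are illustrative assumptions rather than values prescribed by the library; see the constructor signature above for the full set of options.

    import transformers

    # PrefixLoss is referenced in the constructor signature above; the import path
    # and constructor arguments used here are assumptions for illustration.
    from imodelsx.iprompt.ipromptx import iPrompt, PrefixLoss

    checkpoint = 'gpt2'  # hypothetical choice; any causal LM from transformers
    model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
    tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
    tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token by default

    loss_func = PrefixLoss(gamma=0.0, tokenizer=tokenizer)  # assumed signature

    prompt_model = iPrompt(
        loss_func=loss_func,
        model=model,
        tokenizer=tokenizer,
        preprefix_str='',
        pop_size=8,                # candidate prompts kept in the population
        num_mutations=4,           # mutated candidates generated per step
        num_random_generations=4,  # fresh candidates generated per step
        num_learned_tokens=16,     # length of the searched prompt, in tokens
        llm_float16=False,
    )

Each call to compute_loss_and_call_backward then runs one step of the evolutionary search: it proposes new candidates by truncating and re-generating members of the population (plus a few fresh generations), scores every candidate on the current batch, and sets the lowest-loss prefix on the model.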
Ancestors
- AutoPrompt
- HotFlip
- PrefixModel
- torch.nn.modules.module.Module
- abc.ABC
Methods
def compute_loss_and_call_backward(self,
x_tokenized: transformers.tokenization_utils_base.BatchEncoding,
y_tokenized: transformers.tokenization_utils_base.BatchEncoding,
possible_answer_mask: torch.Tensor,
full_text_tokenized: transformers.tokenization_utils_base.BatchEncoding | None = None) ‑> Tuple[torch.Tensor, int]
Returns the loss from the best example in the population
Note: does not call loss.backward()
Returns
loss (float torch.Tensor): the loss
num_correct (int): number of examples where prediction was correct
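Continuing the construction sketch above, a minimal driver for one search step might look like the following. The toy data, the uniform answer mask, and the padding setup are illustrative assumptions, and device placement is ignored here.

    import torch

    xs = ['Input: 2 5 Output:', 'Input: 3 4 Output:']
    ys = [' 7', ' 7']

    # Tokenize inputs, targets, and the concatenated full text.
    x_tok = tokenizer(xs, return_tensors='pt', padding=True)
    y_tok = tokenizer(ys, return_tensors='pt', padding=True)
    full_tok = tokenizer([x + y for x, y in zip(xs, ys)], return_tensors='pt', padding=True)

    # Allow every vocabulary item as a possible answer (no masking).
    answer_mask = torch.ones(len(tokenizer), dtype=torch.bool)

    loss, n_correct = prompt_model.compute_loss_and_call_backward(
        x_tokenized=x_tok,
        y_tokenized=y_tok,
        possible_answer_mask=answer_mask,
        full_text_tokenized=full_tok,
    )
    print(float(loss), int(n_correct))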
def post_epoch(self,
dataloader: torch.utils.data.dataloader.DataLoader,
possible_answer_mask: torch.Tensor) ‑> None
def post_epoch(self, dataloader: torch.utils.data.DataLoader, possible_answer_mask: torch.Tensor) -> None:
    # Get candidate IDs for every position.
    pass
Inherited members