Module imodelsx.iprompt.hotflip

Classes

class HotFlip (args: argparse.Namespace,
loss_func: PrefixLoss,
model: transformers.modeling_utils.PreTrainedModel,
tokenizer: transformers.tokenization_utils.PreTrainedTokenizer,
preprefix: str = '')
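
A minimal construction sketch, not taken from the source: the `args` fields are the ones the class reads (`num_learned_tokens`, `hotflip_num_candidates`, `prefix_before_input`, `early_stopping_steps`, `save_dir_unique`), while the model choice, the `PrefixLoss` import path, and its constructor arguments are assumptions.

import argparse
import torch
import transformers
from imodelsx.iprompt.hotflip import HotFlip
from imodelsx.iprompt.utils import PrefixLoss  # assumed import path

device = 'cuda' if torch.cuda.is_available() else 'cpu'

args = argparse.Namespace(
    num_learned_tokens=6,            # length of the learned prefix
    hotflip_num_candidates=10,       # candidates evaluated per swap position
    prefix_before_input=True,        # put the learned prefix before each example
    early_stopping_steps=-1,         # -1 disables early stopping
    save_dir_unique='/tmp/hotflip',  # where hotflip_grads_data.p is written
)
model = transformers.AutoModelForCausalLM.from_pretrained('gpt2').to(device)
tokenizer = transformers.AutoTokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token by default
hotflip = HotFlip(
    args=args,
    loss_func=PrefixLoss(gamma=0.0, tokenizer=tokenizer),  # assumed signature
    model=model,
    tokenizer=tokenizer,
    preprefix='',
)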
class HotFlip(PrefixModel):
    args: argparse.Namespace
    loss_func: PrefixLoss
    model: transformers.PreTrainedModel
    tokenizer: transformers.PreTrainedTokenizer
    prefix_ids: torch.Tensor
    prefix_embedding: nn.Parameter
    preprefix: str
    def __init__(
            self,
            args: argparse.Namespace,
            loss_func: PrefixLoss,
            model: transformers.PreTrainedModel,
            tokenizer: transformers.PreTrainedTokenizer,
            preprefix: str = ''
        ):
        super().__init__(
            args=args, loss_func=loss_func, model=model, tokenizer=tokenizer, preprefix=preprefix
        )
        # HotFlip-specific parameters.
        self._min_loss = float('inf')
        self._num_tokens = args.num_learned_tokens # TODO argparse for n_tokens
        self._num_candidates_per_prefix_token = args.hotflip_num_candidates # TODO argparse for this too
        self._swap_token_idx = 0

        self._tested_prefix_ids = collections.defaultdict(lambda: 0)
        # Support both a version with a preprefix ("The function to compute is") and a version
        # where the full prefix is discovered by HotFlip without any assistance.
        preprefix_ids = [self.tokenizer.bos_token_id] if self.tokenizer.bos_token_id else []
        if preprefix:
            preprefix_ids.extend(self.tokenizer.encode(preprefix))
        self.preprefix_ids = torch.tensor(preprefix_ids, dtype=int).to(device)
        self.prefix_ids = None
        self._set_prefix_ids(
            self.init_discrete_prefix(num_tokens=self._num_tokens)
        )
        print(f"preprefix: '{preprefix}'")

        # disable grads to model
        for p in self.model.parameters(): p.requires_grad = False

        # track data specific to HotFlip
        self._epoch = 0
        self._data = []
        self._loss_for_prefix = {}
        # 
        self.prefix_before_input = args.prefix_before_input

    def check_early_stop(self) -> bool:
        """Allow prefix models to stop early."""
        if self.args.early_stopping_steps == -1:
            return False
        return self._steps_since_new_prefix >= self.args.early_stopping_steps
    
    def _set_prefix_ids(self, new_ids: torch.Tensor) -> None:
        assert new_ids.ndim == 1, "cannot set prefix with more than 1 dim (need list of IDs)"

        # Track steps since new prefix to enable early stopping
        if (self.prefix_ids is not None) and (self.prefix_ids == new_ids.to(device)).all():
            self._steps_since_new_prefix += 1
        else:
            self._steps_since_new_prefix = 0
        

        self.prefix_ids = new_ids.to(device)
        self.prefix_embedding = nn.Parameter(
            self.token_embedding.to(device).forward(self.prefix_ids), requires_grad=True
        )
        # track prefixes we've tried
        self._tested_prefix_ids[(tuple(new_ids.flatten().tolist()), self._swap_token_idx)] += 1

    def pre_epoch(self) -> None:
        # Print closest tokens at the beginning of each epoch.
        if VERBOSE:
            prefix_str = self.tokenizer.decode(self.prefix_ids.tolist())
            emb_dim = self.token_embedding.weight.shape[1]
            print("*" * 30)
            print(f"Epoch {self._epoch}. Closest tokens to '{prefix_str}':")
            # squared distance from the prefix embedding to every vocab embedding
            # (NOTE: the reshape assumes a single learned prefix token)
            word_distances = ((self.token_embedding.weight - self.prefix_embedding.reshape(1, emb_dim)) ** 2).sum(1)
            assert word_distances.shape == (self.token_embedding.weight.shape[0],)
            topk_closest_words = word_distances.topk(k=TOP_K, largest=False)
            for _id, _dist in zip(topk_closest_words.indices.cpu().tolist(), topk_closest_words.values.cpu().tolist()):
                print(f'\t{self.id_to_word[_id]} ({_id}): {_dist:.3f}')
            print("*" * 30)
    
    @property
    def _prefix_token_grad(self) -> torch.Tensor:
        """Gradient of the prefix tokens wrt the token embedding matrix."""
        return torch.einsum('nd,vd->nv', self.prefix_embedding.grad, self.token_embedding.weight)
    
    def compute_loss_and_call_backward(
            self,
            x_tokenized: transformers.BatchEncoding,
            y_tokenized: transformers.BatchEncoding,
            possible_answer_mask: torch.Tensor,
            full_text_tokenized: Optional[transformers.BatchEncoding] = None
        ) -> Tuple[torch.Tensor, int]:
        """Computes loss using `self.loss_func`.
        
        Returns:
            loss (float torch.Tensor) -- the loss
            num_correct (int): number of examples where prediction was correct
        """
        original_input_ids = x_tokenized.input_ids
        next_token_ids = y_tokenized.input_ids # only compute loss over next token

        _input_ids, loss, n_correct = self._compute_loss_with_set_prefix(
            original_input_ids=original_input_ids,
            next_token_ids=next_token_ids, # only compute loss over next token
            possible_answer_mask=possible_answer_mask
        )

        loss.backward()

        # self._set_prefix_ids(best_prefix)
        return loss, n_correct
        
    def post_epoch(self, dataloader: torch.utils.data.DataLoader, possible_answer_mask: torch.Tensor) -> None:
        # 
        # Get candidate IDs for every position.
        # 
        token_idx = self._swap_token_idx
        token_grads = self._prefix_token_grad
        top_tokens_per_position = (
            token_grads.topk(k=self._num_candidates_per_prefix_token, dim=1, largest=False).indices
        )
        assert top_tokens_per_position.shape == (self._num_tokens, self._num_candidates_per_prefix_token)

        top_swap_tokens = top_tokens_per_position[token_idx, :]
        #
        # Get most likely tokens.
        #
        prefix_until_swap_ids = torch.cat(
            (self.preprefix_ids.to(device), self.prefix_ids[:token_idx].to(device)), dim=0
        )[None].to(device)
        with torch.no_grad():
            all_preprefix_logits = self.model(prefix_until_swap_ids)
            swap_token_logits = all_preprefix_logits.logits[:, -1, :]

        rvocab = {v: k for k,v in self.tokenizer.vocab.items()}
        # dist_sum = (swap_token_logits.log_softmax(dim=1) * .7 + (-1 * token_grads).log_softmax(dim=1))
        # for v in (swap_token_logits.log_softmax(dim=1) * .7 + (-1 * token_grads).log_softmax(dim=1)).topk(10).indices.flatten(): print(rvocab[v.item()])

        alpha = 0.0 # TODO argparse for this alpha
        print(f"HotFlip alpha = {alpha}")
        token_losses = (
            (swap_token_logits.log_softmax(dim=1) * alpha + (-1 * token_grads).log_softmax(dim=1))
        )
        top_swap_tokens = token_losses.argsort(descending=True).flatten()

        # if we've already tried this (prefix, swap_token_idx) combo, then let's try the next n candidates.
        _n = self._tested_prefix_ids[tuple(self.prefix_ids.flatten().tolist()), token_idx] - 1
        assert _n >= 0, "something went wrong"
        top_swap_tokens = top_swap_tokens[(_n * self._num_candidates_per_prefix_token) : (_n+1) * self._num_candidates_per_prefix_token]
        # 
        # Evaluate candidates.
        # 
        all_candidate_losses = torch.zeros(self._num_candidates_per_prefix_token, dtype=float).to(device)
        all_n_correct = torch.zeros(self._num_candidates_per_prefix_token, dtype=int).to(device)
        best_loss = self._min_loss

        mask = torch.nn.functional.one_hot(
            torch.tensor(token_idx), num_classes=self._num_tokens
        ).bool().to(device)

        # Evaluate each prefix.
        for batch in tqdm.tqdm(dataloader, desc='evaluating HotFlip candidates', colour='red', leave=False):
            # Loop in this order so we only tokenize each thing once.
            x_text, y_text = self.prepare_batch(batch=batch)
            input_ids = self.tokenizer(x_text, return_tensors='pt', padding='longest')['input_ids'].to(device)
            next_token_ids = self.tokenizer(y_text, return_tensors='pt', padding='longest')['input_ids'].to(device)
            # only evaluate on single next-token
            next_token_ids = next_token_ids[:, 0]
            for candidate_idx in range(self._num_candidates_per_prefix_token):
                new_token_id = top_swap_tokens[candidate_idx]
                prefix_ids = torch.where(
                    mask, new_token_id, self.prefix_ids.to(device)
                ).to(device)
                with torch.no_grad():
                    _input_ids, loss, n_correct = (
                        self._compute_loss_with_set_prefix(
                            original_input_ids=input_ids,
                            next_token_ids=next_token_ids,
                            possible_answer_mask=possible_answer_mask,
                            prefix_ids=prefix_ids
                        )
                    )
                all_candidate_losses[candidate_idx] += loss
                all_n_correct[candidate_idx] += n_correct

        ##################################################################################################################
        hotflip_out_path = os.path.join(self.args.save_dir_unique, 'hotflip_grads_data.p')
        for _i in range(self._num_candidates_per_prefix_token):
            token_id = top_swap_tokens[_i].item()
            # rank, prefix, token_id, token_grad, loss_with_this_token, n_correct_with_this_token
            self._data.append(
                (_i, self.prefix_ids.tolist(), token_id, token_grads.flatten()[token_id].item(), all_candidate_losses[_i].item(), all_n_correct[_i].item())
            )
        with open(hotflip_out_path, 'wb') as f:
            pickle.dump(self._data, f)
        ##################################################################################################################

        #
        # Collect losses for all prefixes. Then set prefix to best one we haven't seen before.
        #
        for candidate_idx in range(self._num_candidates_per_prefix_token):
            new_token_id = top_swap_tokens[candidate_idx]
            prefix_ids = tuple(
                torch.where(
                    mask, new_token_id, self.prefix_ids.to(device)
                ).tolist()
            )
            self._loss_for_prefix[prefix_ids] = (
                all_candidate_losses[candidate_idx].item(),
                all_n_correct[candidate_idx].item()
            )
        
        # next prefix is the one we know about with the min loss that we haven't tried
        # so far.
        best_prefix_ids = min(self._loss_for_prefix, key=lambda p: self._loss_for_prefix.get(p)[0])
        best_loss, best_n_correct =  self._loss_for_prefix[best_prefix_ids]

        # if loss < self._min_loss:
        #     self._min_loss = loss
        #     best_prefix_ids = prefix_ids

        # 
        # Pick top candidate and reset self._min_loss. (TODO: Support beam width > 1.)
        # 
        old_prefix_str = self.tokenizer.decode(self.preprefix_ids.tolist() + self.prefix_ids.tolist())
        new_prefix_str = self.tokenizer.decode(self.preprefix_ids.tolist() + list(best_prefix_ids))
        print(f'[Loss = {best_loss/len(dataloader):.2f}] // Old prefix: {old_prefix_str} // New prefix: {new_prefix_str} // New n_correct = {best_n_correct}')

        self._swap_token_idx = (self._swap_token_idx + 1) % self._num_tokens
        # self._swap_token_idx = random.randint(0, (self._num_tokens-1))

        self._set_prefix_ids(torch.tensor(best_prefix_ids))

        return

    @property
    def prefix_embedding_token_ids(self) -> torch.Tensor:
        return self.prefix_embedding.argmax(dim=-1)

    @property
    def trainable_params(self) -> Iterable[nn.Parameter]:
        return [self.prefix_embedding]

    def embed_input_ids(
        self, input_ids: torch.Tensor, next_token_ids: torch.Tensor, prefix_ids: Optional[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
        """Gets token embeddings for tokens given by `input_ids` prefixed by `prefix_ids`.

        If `prefix_ids` is not provided, `self.prefix_ids` is used for every
        example in the batch.

        Args:
            input_ids (int torch.Tensor) -- IDs for batch of sentences
            next_token_ids (int torch.Tensor) -- IDs of the target tokens,
                appended after each input
            prefix_ids (Optional int torch.Tensor) -- IDs for a single prefix
                to be prepended before each input ID. If not provided,
                will be overridden with prefix from `self.prefix_ids`.

        Returns:
            input_ids (int torch.Tensor) -- IDs of all tokens, including prefix
            outputs (float torch.Tensor): embedded tokens
        """
        batch_size = len(input_ids)
        if prefix_ids is None:
            prefix_ids = self.prefix_ids
            prefix_embedding = self.prefix_embedding
            
        else:
            prefix_embedding = self.token_embedding.forward(prefix_ids)

        # concatenate preprefix (fixed) + prefix (learned) + example
        prefix_ids = prefix_ids[None].to(device).repeat((batch_size, 1))
        preprefix_ids = self.preprefix_ids[None].to(device).repeat((batch_size, 1))

        if self.prefix_before_input:
            full_input_ids = torch.cat(
                (preprefix_ids, prefix_ids, input_ids, next_token_ids), dim=1
            )
            outputs = torch.cat(
                (
                    self.token_embedding.forward(preprefix_ids),
                    prefix_embedding[None].repeat((batch_size, 1, 1)),
                    self.token_embedding.forward(input_ids),
                    self.token_embedding.forward(next_token_ids),
                ), dim=1
            )
        else:
            full_input_ids = torch.cat(
                (input_ids, preprefix_ids, prefix_ids, next_token_ids), dim=1
            )
            outputs = torch.cat(
                (
                    self.token_embedding.forward(input_ids),
                    self.token_embedding.forward(preprefix_ids),
                    prefix_embedding[None].repeat((batch_size, 1, 1)),
                    self.token_embedding.forward(next_token_ids),
                ), dim=1
            )
        return full_input_ids, outputs
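
A sketch of a training run built on the instance above, under assumed inputs: a dataloader whose batches `prepare_batch` understands, and a boolean `possible_answer_mask` over the vocabulary (here trivially all-true). It only mirrors the call order the class exposes (`pre_epoch`, per-batch loss/backward, `post_epoch`, `check_early_stop`); it is not the package's own training loop.

possible_answer_mask = torch.ones(len(tokenizer), dtype=torch.bool, device=device)

for epoch in range(10):
    hotflip.pre_epoch()
    for batch in dataloader:  # assumed to yield batches prepare_batch accepts
        x_text, y_text = hotflip.prepare_batch(batch=batch)
        x_tok = tokenizer(x_text, return_tensors='pt', padding='longest').to(device)
        y_tok = tokenizer(y_text, return_tensors='pt', padding='longest').to(device)
        loss, n_correct = hotflip.compute_loss_and_call_backward(
            x_tokenized=x_tok,
            y_tokenized=y_tok,
            possible_answer_mask=possible_answer_mask,
        )
    # the gradients accumulated above drive the token-swap search
    hotflip.post_epoch(dataloader=dataloader, possible_answer_mask=possible_answer_mask)
    if hotflip.check_early_stop():
        break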

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::

import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will also have their parameters converted when you call `.to()`, etc.

Note

As per the example above, an __init__() call to the parent class must be made before assignment on the child.

training (bool) – whether this module is in training or evaluation mode.

Initialize internal Module state, shared by both nn.Module and ScriptModule.

Ancestors

PrefixModel
torch.nn.modules.module.Module

Subclasses

Class variables

var args : argparse.Namespace
var loss_func : PrefixLoss
var model : transformers.modeling_utils.PreTrainedModel
var prefix_embedding : torch.nn.parameter.Parameter
var prefix_ids : torch.Tensor
var preprefix : str
var tokenizer : transformers.tokenization_utils.PreTrainedTokenizer

Instance variables

prop prefix_embedding_token_ids : torch.Tensor
@property
def prefix_embedding_token_ids(self) -> torch.Tensor:
    return self.prefix_embedding.argmax(dim=-1)
prop trainable_params : Iterable[torch.nn.parameter.Parameter]
@property
def trainable_params(self) -> Iterable[nn.Parameter]:
    return [self.prefix_embedding]

Methods

def embed_input_ids(self,
input_ids: torch.Tensor,
next_token_ids: torch.Tensor,
prefix_ids: torch.Tensor | None) ‑> Tuple[torch.Tensor, torch.Tensor]
def embed_input_ids(
    self, input_ids: torch.Tensor, next_token_ids: torch.Tensor, prefix_ids: Optional[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
    """Gets token embeddings for tokens given by `input_ids` prefixed by `prefix_ids`.

    If `prefix_ids` is not provided, `self.prefix_ids` is used for every
    example in the batch.

    Args:
        input_ids (int torch.Tensor) -- IDs for batch of sentences
        next_token_ids (int torch.Tensor) -- IDs of the target tokens,
            appended after each input
        prefix_ids (Optional int torch.Tensor) -- IDs for a single prefix
            to be prepended before each input ID. If not provided,
            will be overridden with prefix from `self.prefix_ids`.

    Returns:
        input_ids (int torch.Tensor) -- IDs of all tokens, including prefix
        outputs (float torch.Tensor): embedded tokens
    """
    batch_size = len(input_ids)
    if prefix_ids is None:
        prefix_ids = self.prefix_ids
        prefix_embedding = self.prefix_embedding
        
    else:
        prefix_embedding = self.token_embedding.forward(prefix_ids)

    # concatenate preprefix (fixed) + prefix (learned) + example
    prefix_ids = prefix_ids[None].to(device).repeat((batch_size, 1))
    preprefix_ids = self.preprefix_ids[None].to(device).repeat((batch_size, 1))

    if self.prefix_before_input:
        full_input_ids = torch.cat(
            (preprefix_ids, prefix_ids, input_ids, next_token_ids), dim=1
        )
        outputs = torch.cat(
            (
                self.token_embedding.forward(preprefix_ids),
                prefix_embedding[None].repeat((batch_size, 1, 1)),
                self.token_embedding.forward(input_ids),
                self.token_embedding.forward(next_token_ids),
            ), dim=1
        )
    else:
        full_input_ids = torch.cat(
            (input_ids, preprefix_ids, prefix_ids, next_token_ids), dim=1
        )
        outputs = torch.cat(
            (
                self.token_embedding.forward(input_ids),
                self.token_embedding.forward(preprefix_ids),
                prefix_embedding[None].repeat((batch_size, 1, 1)),
                self.token_embedding.forward(next_token_ids),
            ), dim=1
        )
    return full_input_ids, outputs

Gets token embeddings for tokens given by input_ids prefixed by prefix_ids.

If prefix_ids is not provided, self.prefix_ids is used for every example in the batch.

Args

input_ids (int torch.Tensor) – IDs for batch of sentences
next_token_ids (int torch.Tensor) – IDs of the target tokens, appended after each input
prefix_ids (Optional int torch.Tensor) – IDs for a single prefix to be prepended before each input ID. If not provided, will be overridden with prefix from self.prefix_ids.

Returns

input_ids (int torch.Tensor) – IDs of all tokens, including prefix
outputs (float torch.Tensor) – embedded tokens
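
A hypothetical shape check reusing the instance sketched near the top of this page; the example strings are made up.

input_ids = tokenizer(['2 + 2 ='], return_tensors='pt')['input_ids'].to(device)
next_token_ids = tokenizer([' 4'], return_tensors='pt')['input_ids'].to(device)
full_ids, embedded = hotflip.embed_input_ids(
    input_ids=input_ids, next_token_ids=next_token_ids, prefix_ids=None,
)
# full_ids: (batch, n_preprefix + num_learned_tokens + n_input + n_next)
# embedded: the same sequence axis plus a trailing embedding axis (768 for GPT-2)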

def post_epoch(self,
dataloader: torch.utils.data.dataloader.DataLoader,
possible_answer_mask: torch.Tensor) ‑> None
def post_epoch(self, dataloader: torch.utils.data.DataLoader, possible_answer_mask: torch.Tensor) -> None:
    # 
    # Get candidate IDs for every position.
    # 
    token_idx = self._swap_token_idx
    token_grads = self._prefix_token_grad
    top_tokens_per_position = (
        token_grads.topk(k=self._num_candidates_per_prefix_token, dim=1, largest=False).indices
    )
    assert top_tokens_per_position.shape == (self._num_tokens, self._num_candidates_per_prefix_token)

    top_swap_tokens = top_tokens_per_position[token_idx, :]
    #
    # Get most likely tokens.
    #
    prefix_until_swap_ids = torch.cat(
        (self.preprefix_ids.to(device), self.prefix_ids[:token_idx].to(device)), dim=0
    )[None].to(device)
    with torch.no_grad():
        all_preprefix_logits = self.model(prefix_until_swap_ids)
        swap_token_logits = all_preprefix_logits.logits[:, -1, :]

    rvocab = {v: k for k,v in self.tokenizer.vocab.items()}
    # dist_sum = (swap_token_logits.log_softmax(dim=1) * .7 + (-1 * token_grads).log_softmax(dim=1))
    # for v in (swap_token_logits.log_softmax(dim=1) * .7 + (-1 * token_grads).log_softmax(dim=1)).topk(10).indices.flatten(): print(rvocab[v.item()])

    alpha = 0.0 # TODO argparse for this alpha
    print(f"HotFlip alpha = {alpha}")
    token_losses = (
        (swap_token_logits.log_softmax(dim=1) * alpha + (-1 * token_grads).log_softmax(dim=1))
    )
    top_swap_tokens = token_losses.argsort(descending=True).flatten()

    # if we've already tried this (prefix, swap_token_idx) combo, then let's try the next n candidates.
    _n = self._tested_prefix_ids[tuple(self.prefix_ids.flatten().tolist()), token_idx] - 1
    assert _n >= 0, "something went wrong"
    top_swap_tokens = top_swap_tokens[(_n * self._num_candidates_per_prefix_token) : (_n+1) * self._num_candidates_per_prefix_token]
    # 
    # Evaluate candidates.
    # 
    all_candidate_losses = torch.zeros(self._num_candidates_per_prefix_token, dtype=float).to(device)
    all_n_correct = torch.zeros(self._num_candidates_per_prefix_token, dtype=int).to(device)
    best_loss = self._min_loss

    mask = torch.nn.functional.one_hot(
        torch.tensor(token_idx), num_classes=self._num_tokens
    ).bool().to(device)

    # Evaluate each prefix.
    for batch in tqdm.tqdm(dataloader, desc='evaluating HotFlip candidates', colour='red', leave=False):
        # Loop in this order so we only tokenize each thing once.
        x_text, y_text = self.prepare_batch(batch=batch)
        input_ids = self.tokenizer(x_text, return_tensors='pt', padding='longest')['input_ids'].to(device)
        next_token_ids = self.tokenizer(y_text, return_tensors='pt', padding='longest')['input_ids'].to(device)
        # only evaluate on single next-token
        next_token_ids = next_token_ids[:, 0]
        for candidate_idx in range(self._num_candidates_per_prefix_token):
            new_token_id = top_swap_tokens[candidate_idx]
            prefix_ids = torch.where(
                mask, new_token_id, self.prefix_ids.to(device)
            ).to(device)
            with torch.no_grad():
                _input_ids, loss, n_correct = (
                    self._compute_loss_with_set_prefix(
                        original_input_ids=input_ids,
                        next_token_ids=next_token_ids,
                        possible_answer_mask=possible_answer_mask,
                        prefix_ids=prefix_ids
                    )
                )
            all_candidate_losses[candidate_idx] += loss
            all_n_correct[candidate_idx] += n_correct

    ##################################################################################################################
    hotflip_out_path = os.path.join(self.args.save_dir_unique, 'hotflip_grads_data.p')
    for _i in range(self._num_candidates_per_prefix_token):
        token_id = top_swap_tokens[_i].item()
        # rank, prefix, token_id, token_grad, loss_with_this_token, n_correct_with_this_token
        self._data.append(
            (_i, self.prefix_ids.tolist(), token_id, token_grads.flatten()[token_id].item(), all_candidate_losses[_i].item(), all_n_correct[_i].item())
        )
    with open(hotflip_out_path, 'wb') as f:
        pickle.dump(self._data, f)
    ##################################################################################################################

    #
    # Collect losses for all prefixes. Then set prefix to best one we haven't seen before.
    #
    for candidate_idx in range(self._num_candidates_per_prefix_token):
        new_token_id = top_swap_tokens[candidate_idx]
        prefix_ids = tuple(
            torch.where(
                mask, new_token_id, self.prefix_ids.to(device)
            ).tolist()
        )
        self._loss_for_prefix[prefix_ids] = (
            all_candidate_losses[candidate_idx].item(),
            all_n_correct[candidate_idx].item()
        )
    
    # next prefix is the one we know about with the min loss that we haven't tried
    # so far.
    best_prefix_ids = min(self._loss_for_prefix, key=lambda p: self._loss_for_prefix.get(p)[0])
    best_loss, best_n_correct =  self._loss_for_prefix[best_prefix_ids]

    # if loss < self._min_loss:
    #     self._min_loss = loss
    #     best_prefix_ids = prefix_ids

    # 
    # Pick top candidate and reset self._min_loss. (TODO: Support beam width > 1.)
    # 
    old_prefix_str = self.tokenizer.decode(self.preprefix_ids.tolist() + self.prefix_ids.tolist())
    new_prefix_str = self.tokenizer.decode(self.preprefix_ids.tolist() + list(best_prefix_ids))
    print(f'[Loss = {best_loss/len(dataloader):.2f}] // Old prefix: {old_prefix_str} // New prefix: {new_prefix_str} // New n_correct = {best_n_correct}')

    self._swap_token_idx = (self._swap_token_idx + 1) % self._num_tokens
    # self._swap_token_idx = random.randint(0, (self._num_tokens-1))

    self._set_prefix_ids(torch.tensor(best_prefix_ids))

    return
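
The swap itself is a masked torch.where over the current prefix: a one-hot boolean mask selects the position being searched and the candidate token is written there. A self-contained toy illustration of that mechanic (the token IDs are made up):

import torch

prefix_ids = torch.tensor([101, 102, 103])
swap_token_idx = 1
mask = torch.nn.functional.one_hot(
    torch.tensor(swap_token_idx), num_classes=3
).bool()
new_token_id = torch.tensor(999)
print(torch.where(mask, new_token_id, prefix_ids))  # tensor([101, 999, 103])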
def pre_epoch(self) ‑> None
def pre_epoch(self) -> None:
    # Print closest tokens at the beginning of each epoch.
    if VERBOSE:
        prefix_str = self.tokenizer.decode(self.prefix_ids.tolist())
        emb_dim = self.token_embedding.weight.shape[1]
        print("*" * 30)
        print(f"Epoch {self._epoch}. Closest tokens to '{prefix_str}':")
        # squared distance from the prefix embedding to every vocab embedding
        # (NOTE: the reshape assumes a single learned prefix token)
        word_distances = ((self.token_embedding.weight - self.prefix_embedding.reshape(1, emb_dim)) ** 2).sum(1)
        assert word_distances.shape == (self.token_embedding.weight.shape[0],)
        topk_closest_words = word_distances.topk(k=TOP_K, largest=False)
        for _id, _dist in zip(topk_closest_words.indices.cpu().tolist(), topk_closest_words.values.cpu().tolist()):
            print(f'\t{self.id_to_word[_id]} ({_id}): {_dist:.3f}')
        print("*" * 30)

Inherited members