Module imodelsx.iprompt.gumbel

Expand source code
from typing import Iterable, Optional, Tuple

import argparse

import torch
import torch.nn as nn
import transformers

from imodelsx.iprompt.utils import PrefixLoss, PrefixModel


class GumbelPrefixModel(PrefixModel):
    args: argparse.Namespace
    loss_func: PrefixLoss
    model: transformers.PreTrainedModel
    tokenizer: transformers.PreTrainedTokenizer
    prefix_embedding: nn.Parameter

    def __init__(self, args: argparse.Namespace, loss_func: PrefixLoss, model: transformers.PreTrainedModel, tokenizer: transformers.PreTrainedTokenizer, preprefix: str):
        super().__init__(args=args, loss_func=loss_func, model=model, tokenizer=tokenizer, preprefix=preprefix)
        self.word_weights = nn.Parameter(
            torch.randn((1, args.num_learned_tokens, self.vocab_size)), requires_grad=True
        )
        # TODO: argparse for tau
        # (lower tau -> more spiky)
        self.tau = 10
        # TODO: argparse for tau_anneal
        self.tau_anneal = 1.02

    @property
    def trainable_params(self) -> Iterable[nn.Parameter]:
        return [self.word_weights]
    
    def post_epoch(self, dataloader: torch.utils.data.DataLoader, possible_answer_mask: torch.Tensor) -> None:
        self.tau = self.tau / self.tau_anneal
        print(f"𝛕 = {self.tau:.2f}")

    def embed_input_ids(self, input_ids: torch.Tensor, prefix_ids: Optional[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
        assert prefix_ids is None, "cannot provide custom prefix IDs for Gumbel"
        # sample a relaxed (differentiable) one-hot distribution over the
        # vocabulary for each learned prefix position
        prefix_embedding_words_dist = nn.functional.gumbel_softmax(
            self.word_weights.repeat((len(input_ids), 1, 1)), tau=self.tau, dim=-1, hard=False
        )
        
        print(
            "trying words:", self.tokenizer.decode(
                prefix_embedding_words_dist[0].argmax(dim=1).tolist()),
            "with prob", prefix_embedding_words_dist[0].max().item()
        )
        # soft prefix embedding: a convex combination of token embeddings,
        # weighted by the relaxed one-hot distribution
        prefix_embedding = prefix_embedding_words_dist @ self.token_embedding.weight

        # embed the original example, then prepend the soft prefix embedding
        outputs = torch.cat(
            (prefix_embedding, self.token_embedding.forward(input_ids)), dim=1
        )
        # prepend the hard (argmax) prefix token ids so the returned ids
        # describe the same sequence as the returned embeddings
        input_ids = torch.cat(
            (prefix_embedding_words_dist.argmax(dim=-1), input_ids), dim=1
        )
        return input_ids, outputs
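
The key move above is that a Gumbel-softmax sample over the vocabulary, multiplied by the token-embedding matrix, acts as a differentiable stand-in for discrete prefix tokens. The toy sketch below is not part of the module (the vocabulary size, embedding size, and number of prefix positions are made up) but shows how lowering tau sharpens the relaxed distribution toward the hard argmax choice:

import torch
import torch.nn as nn

torch.manual_seed(0)
vocab_size, emb_dim, num_prefix = 50, 8, 3              # toy sizes, chosen arbitrarily
token_embedding = nn.Embedding(vocab_size, emb_dim)     # stand-in for the LM's embedding table
word_weights = torch.randn(1, num_prefix, vocab_size)   # analogue of GumbelPrefixModel.word_weights

for tau in (10.0, 3.0, 1.0, 0.3):
    dist = nn.functional.gumbel_softmax(word_weights, tau=tau, dim=-1, hard=False)
    soft_prefix = dist @ token_embedding.weight          # (1, num_prefix, emb_dim), differentiable
    hard_prefix = token_embedding(dist.argmax(dim=-1))   # embedding of the argmax tokens
    gap = (soft_prefix - hard_prefix).norm().item()
    print(f"tau={tau:4.1f}  max prob={dist.max().item():.2f}  ||soft - hard|| = {gap:.3f}")

As tau shrinks, the sampled distribution concentrates on a single token per position, so the soft prefix converges to the embedding of concrete vocabulary items.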

Classes

class GumbelPrefixModel (args: argparse.Namespace, loss_func: PrefixLoss, model: transformers.modeling_utils.PreTrainedModel, tokenizer: transformers.tokenization_utils.PreTrainedTokenizer, preprefix: str)

Prefix model that learns a discrete prompt through a Gumbel-softmax relaxation.

The model keeps one trainable tensor, word_weights, of shape (1, args.num_learned_tokens, vocab_size). On every forward pass a Gumbel-softmax sample of these weights yields a near-one-hot distribution over the vocabulary for each prefix position; multiplying that distribution by the token-embedding matrix gives a differentiable "soft" prefix embedding that is prepended to every example. The temperature tau starts at 10 and is divided by tau_anneal = 1.02 after each epoch (see post_epoch), so the relaxation gradually sharpens toward hard token choices; both constants are currently hard-coded (see the TODO comments in the source).

Because hard=False, the forward pass uses the soft sample itself rather than a straight-through estimate; the argmax decoded and printed during training is only a readable snapshot of the currently most likely prefix.

Arguments mirror PrefixModel: an argparse.Namespace (which must define num_learned_tokens), a PrefixLoss, the underlying transformers model and tokenizer, and a preprefix string passed through to the base class.

Expand source code
class GumbelPrefixModel(PrefixModel):
    args: argparse.Namespace
    loss_func: PrefixLoss
    model: transformers.PreTrainedModel
    tokenizer: transformers.PreTrainedTokenizer
    prefix_embedding: nn.Parameter

    def __init__(self, args: argparse.Namespace, loss_func: PrefixLoss, model: transformers.PreTrainedModel, tokenizer: transformers.PreTrainedTokenizer, preprefix: str):
        super().__init__(args=args, loss_func=loss_func, model=model, tokenizer=tokenizer, preprefix=preprefix)
        self.word_weights = nn.Parameter(
            torch.randn((1, args.num_learned_tokens, self.vocab_size)), requires_grad=True
        )
        # TODO: argparse for tau
        # (lower tau -> more spiky)
        self.tau = 10
        # TODO: argparse for tau_anneal
        self.tau_anneal = 1.02

    @property
    def trainable_params(self) -> Iterable[nn.Parameter]:
        return [self.word_weights]
    
    def post_epoch(self, dataloader: torch.utils.data.DataLoader, possible_answer_mask: torch.Tensor) -> None:
        self.tau = self.tau / self.tau_anneal
        print(f"𝛕 = {self.tau:.2f}")

    def embed_input_ids(self, input_ids: torch.Tensor, prefix_ids: Optional[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
        assert prefix_ids is None, "cannot provide custom prefix IDs for Gumbel"
        # sample a relaxed (differentiable) one-hot distribution over the
        # vocabulary for each learned prefix position
        prefix_embedding_words_dist = nn.functional.gumbel_softmax(
            self.word_weights.repeat((len(input_ids), 1, 1)), tau=self.tau, dim=-1, hard=False
        )
        
        print(
            "trying words:", self.tokenizer.decode(
                prefix_embedding_words_dist[0].argmax(dim=1).tolist()),
            "with prob", prefix_embedding_words_dist[0].max().item()
        )
        # soft prefix embedding: a convex combination of token embeddings,
        # weighted by the relaxed one-hot distribution
        prefix_embedding = prefix_embedding_words_dist @ self.token_embedding.weight

        # embed the original example, then prepend the soft prefix embedding
        outputs = torch.cat(
            (prefix_embedding, self.token_embedding.forward(input_ids)), dim=1
        )
        # prepend the hard (argmax) prefix token ids so the returned ids
        # describe the same sequence as the returned embeddings
        input_ids = torch.cat(
            (prefix_embedding_words_dist.argmax(dim=-1), input_ids), dim=1
        )
        return input_ids, outputs

Ancestors

imodelsx.iprompt.utils.PrefixModel
torch.nn.modules.module.Module

Class variables

var args : argparse.Namespace
var loss_func : PrefixLoss
var model : transformers.modeling_utils.PreTrainedModel
var prefix_embedding : torch.nn.parameter.Parameter
var tokenizer : transformers.tokenization_utils.PreTrainedTokenizer

Instance variables

var trainable_params : Iterable[torch.nn.parameter.Parameter]
Expand source code
@property
def trainable_params(self) -> Iterable[nn.Parameter]:
    return [self.word_weights]
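
Only word_weights is exposed for optimization here; the wrapped language model's parameters are not part of this list. A hypothetical training loop (not taken from this package) could feed the property straight into a standard optimizer:

import torch

# `prefix_model` is assumed to be an already-constructed GumbelPrefixModel
optimizer = torch.optim.AdamW(prefix_model.trainable_params, lr=1e-3)

# per batch: compute the loss with the learned prefix, then
# optimizer.zero_grad(); loss.backward(); optimizer.step()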

Methods

def post_epoch(self, dataloader: torch.utils.data.dataloader.DataLoader, possible_answer_mask: torch.Tensor) -> None
Expand source code
def post_epoch(self, dataloader: torch.utils.data.DataLoader, possible_answer_mask: torch.Tensor) -> None:
    self.tau = self.tau / self.tau_anneal
    print(f"𝛕 = {self.tau:.2f}")

Inherited members