Module imodelsx.iprompt.gumbel
Expand source code
from typing import Iterable, Optional, Tuple
import argparse

import torch
import torch.nn as nn
import transformers

from imodelsx.iprompt.utils import PrefixLoss, PrefixModel


class GumbelPrefixModel(PrefixModel):
    args: argparse.Namespace
    loss_func: PrefixLoss
    model: transformers.PreTrainedModel
    tokenizer: transformers.PreTrainedTokenizer
    prefix_embedding: nn.Parameter

    def __init__(self, args: argparse.Namespace, loss_func: PrefixLoss, model: transformers.PreTrainedModel, tokenizer: transformers.PreTrainedTokenizer, preprefix: str):
        super().__init__(args=args, loss_func=loss_func, model=model, tokenizer=tokenizer, preprefix=preprefix)
        self.word_weights = nn.Parameter(
            torch.randn((1, args.num_learned_tokens, self.vocab_size)), requires_grad=True
        )
        # TODO: argparse for tau
        # (lower tau -> more spiky)
        self.tau = 10
        # TODO: argparse for tau_anneal
        self.tau_anneal = 1.02

    @property
    def trainable_params(self) -> Iterable[nn.Parameter]:
        return [self.word_weights]

    def post_epoch(self, dataloader: torch.utils.data.DataLoader, possible_answer_mask: torch.Tensor) -> None:
        self.tau = self.tau / self.tau_anneal
        print(f"𝛕 = {self.tau:.2f}")

    def embed_input_ids(self, input_ids: torch.Tensor, prefix_ids: Optional[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
        assert prefix_ids is None, "cannot provide custom prefix IDs for Gumbel"
        # word_weights_dist = (word_weights * 1000).softmax(dim=-1)
        prefix_embedding_words_dist = nn.functional.gumbel_softmax(
            self.word_weights.repeat((len(input_ids), 1, 1)), tau=self.tau, dim=-1, hard=False
        )
        print(
            "trying words:", self.tokenizer.decode(
                prefix_embedding_words_dist[0].argmax(dim=1).tolist()),
            "with prob", prefix_embedding_words_dist[0].max().item()
        )
        prefix_embedding = prefix_embedding_words_dist @ self.token_embedding.weight
        input_ids = torch.cat(
            (prefix_embedding_words_dist.argmax(dim=-1), input_ids), dim=1
        )
        outputs = torch.cat(
            # concatenate prefix + example
            (prefix_embedding, self.token_embedding.forward(input_ids)), dim=1
        )
        return input_ids, outputs
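The core trick in embed_input_ids is the Gumbel-softmax relaxation: each learned prefix position holds a vector of logits over the vocabulary, a soft (differentiable) one-hot sample is drawn from those logits, and the prefix embedding is the probability-weighted mix of token embeddings. A minimal standalone sketch of that relaxation, with made-up sizes (vocab_size=100, num_tokens=3, hidden_dim=16) rather than any real model's dimensions:

import torch
import torch.nn as nn

torch.manual_seed(0)
vocab_size, num_tokens, hidden_dim = 100, 3, 16

# Learnable logits over the vocabulary for each prefix position (the analogue of word_weights).
word_weights = nn.Parameter(torch.randn(1, num_tokens, vocab_size))
token_embedding = nn.Embedding(vocab_size, hidden_dim)

# Lower tau -> the soft one-hot samples concentrate their mass on fewer tokens.
for tau in (10.0, 1.0, 0.1):
    dist = nn.functional.gumbel_softmax(word_weights, tau=tau, dim=-1, hard=False)
    print(f"tau={tau:5.1f}  max prob per position: {dist[0].max(dim=-1).values.tolist()}")

# The differentiable prefix embedding is the probability-weighted mix of token embeddings,
# the same matrix product used in embed_input_ids above.
dist = nn.functional.gumbel_softmax(word_weights, tau=1.0, dim=-1, hard=False)
prefix_embedding = dist @ token_embedding.weight   # shape: (1, num_tokens, hidden_dim)
print(prefix_embedding.shape)                      # torch.Size([1, 3, 16])

Lower tau makes the samples spikier (closer to a hard one-hot), which is why the class anneals tau downward after every epoch.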
Classes
class GumbelPrefixModel (args: argparse.Namespace, loss_func: PrefixLoss, model: transformers.modeling_utils.PreTrainedModel, tokenizer: transformers.tokenization_utils.PreTrainedTokenizer, preprefix: str)
-
Base class for all neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes:

    import torch.nn as nn
    import torch.nn.functional as F

    class Model(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv1 = nn.Conv2d(1, 20, 5)
            self.conv2 = nn.Conv2d(20, 20, 5)

        def forward(self, x):
            x = F.relu(self.conv1(x))
            return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will have their parameters converted too when you call to(), etc.
Note: As per the example above, an __init__() call to the parent class must be made before assignment on the child.
training (bool): whether this module is in training or evaluation mode.
Initialize internal Module state, shared by both nn.Module and ScriptModule.
Expand source code
class GumbelPrefixModel(PrefixModel):
    args: argparse.Namespace
    loss_func: PrefixLoss
    model: transformers.PreTrainedModel
    tokenizer: transformers.PreTrainedTokenizer
    prefix_embedding: nn.Parameter

    def __init__(self, args: argparse.Namespace, loss_func: PrefixLoss, model: transformers.PreTrainedModel, tokenizer: transformers.PreTrainedTokenizer, preprefix: str):
        super().__init__(args=args, loss_func=loss_func, model=model, tokenizer=tokenizer, preprefix=preprefix)
        self.word_weights = nn.Parameter(
            torch.randn((1, args.num_learned_tokens, self.vocab_size)), requires_grad=True
        )
        # TODO: argparse for tau
        # (lower tau -> more spiky)
        self.tau = 10
        # TODO: argparse for tau_anneal
        self.tau_anneal = 1.02

    @property
    def trainable_params(self) -> Iterable[nn.Parameter]:
        return [self.word_weights]

    def post_epoch(self, dataloader: torch.utils.data.DataLoader, possible_answer_mask: torch.Tensor) -> None:
        self.tau = self.tau / self.tau_anneal
        print(f"𝛕 = {self.tau:.2f}")

    def embed_input_ids(self, input_ids: torch.Tensor, prefix_ids: Optional[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
        assert prefix_ids is None, "cannot provide custom prefix IDs for Gumbel"
        # word_weights_dist = (word_weights * 1000).softmax(dim=-1)
        prefix_embedding_words_dist = nn.functional.gumbel_softmax(
            self.word_weights.repeat((len(input_ids), 1, 1)), tau=self.tau, dim=-1, hard=False
        )
        print(
            "trying words:", self.tokenizer.decode(
                prefix_embedding_words_dist[0].argmax(dim=1).tolist()),
            "with prob", prefix_embedding_words_dist[0].max().item()
        )
        prefix_embedding = prefix_embedding_words_dist @ self.token_embedding.weight
        input_ids = torch.cat(
            (prefix_embedding_words_dist.argmax(dim=-1), input_ids), dim=1
        )
        outputs = torch.cat(
            # concatenate prefix + example
            (prefix_embedding, self.token_embedding.forward(input_ids)), dim=1
        )
        return input_ids, outputs
Ancestors
- PrefixModel
- torch.nn.modules.module.Module
- abc.ABC
Class variables
var args : argparse.Namespace
var loss_func : PrefixLoss
var model : transformers.modeling_utils.PreTrainedModel
var prefix_embedding : torch.nn.parameter.Parameter
var tokenizer : transformers.tokenization_utils.PreTrainedTokenizer
Instance variables
var trainable_params : Iterable[torch.nn.parameter.Parameter]
-
Expand source code
@property
def trainable_params(self) -> Iterable[nn.Parameter]:
    return [self.word_weights]
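Only the prefix logits (word_weights) are exposed for optimization; the wrapped language model's weights are not returned here. A minimal sketch of how these parameters might be handed to an optimizer — the Adam choice, learning rate, and parameter shape below are illustrative assumptions, not taken from this module:

import torch
import torch.nn as nn

# Stand-in for word_weights: logits of shape (1, num_learned_tokens, vocab_size).
word_weights = nn.Parameter(torch.randn(1, 6, 50257), requires_grad=True)

# Only the prefix logits are optimized; the frozen LM's parameters are never passed in.
optimizer = torch.optim.Adam([word_weights], lr=1e-3)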
Methods
def post_epoch(self, dataloader: torch.utils.data.dataloader.DataLoader, possible_answer_mask: torch.Tensor) ‑> None
-
Expand source code
def post_epoch(self, dataloader: torch.utils.data.DataLoader, possible_answer_mask: torch.Tensor) -> None:
    self.tau = self.tau / self.tau_anneal
    print(f"𝛕 = {self.tau:.2f}")
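With the defaults set in __init__ (tau = 10, tau_anneal = 1.02), the temperature decays geometrically by about 2% per epoch, so roughly 35 epochs halve it. A quick check of that schedule:

tau, tau_anneal = 10.0, 1.02
for epoch in range(35):
    tau /= tau_anneal            # same update as post_epoch above
print(f"tau after 35 epochs: {tau:.2f}")   # prints 5.00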
Inherited members