Module imodelsx.iprompt.api

Expand source code
from typing import Callable, Dict, List, Tuple

import datasets
import functools
import os
import random
import string
import numpy as np
import time
import torch
import transformers
import argparse
from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM
from tqdm import tqdm
from collections import defaultdict
from imodelsx.iprompt import (
    AutoPrompt, iPrompt,
    PrefixLoss, PrefixModel,
    PromptTunedModel, HotFlip, GumbelPrefixModel
)
from imodelsx.iprompt.llm import get_llm
import pandas as pd
import logging
import pickle as pkl
from torch.utils.data import DataLoader
from datetime import datetime


"""
Explaining Patterns in Data with Language Models via Interpretable Autoprompting

Chandan Singh*, John X. Morris*, Jyoti Aneja, Alexander M. Rush, Jianfeng Gao
https://arxiv.org/abs/2210.01848
"""

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


model_cls_dict = {
    'autoprompt': AutoPrompt,
    'iprompt': iPrompt,
    'gumbel': GumbelPrefixModel,
    'hotflip': HotFlip,
    'prompt_tune': PromptTunedModel,
}

def get_prompts_api(
        data: List[str],
        llm: Callable,
        prompt_template: Callable,
    ):
    """Sample one data string and ask the LLM API to propose a candidate prompt."""
    data_str = random.choice(data)
    prompt = prompt_template(data=data_str).strip()
    answer = llm(prompt, max_new_tokens=24)
    return [answer]


def run_iprompt_api(
    r: Dict[str, List],
    input_strs: List[str],
    output_strs: List[str],
    model: PrefixModel,
    tokenizer: transformers.PreTrainedTokenizer,
    llm_api: str,
    save_dir: str = 'results',
    lr: float = 1e-4,
    batch_size: int = 64,
    max_length: int = 128,
    n_epochs: int = 100,
    n_shots: int = 1,
    single_shot_loss: bool = True,
    accum_grad_over_epoch: bool = False,
    max_n_datapoints: int = 10**4,
    max_n_steps: int = 10**4,
    epoch_save_interval: int = 1,
    mask_possible_answers: bool = False,
    verbose: int = 0,
):
    """
    Trains a model, either by optimizing continuous embeddings or finding an optimal discrete embedding.

    Params
    ------
    r: dict
        dictionary of things to save
    """


    # remove periods and newlines from the output so we actually use the tokens
    # for the reranking step in iPrompt
    output_strs = [s.rstrip().rstrip('.') for s in output_strs]

    r['train_start_time'] = time.time()
    model.train()

    logging.info("beginning iPrompt with n_shots = %d", n_shots)

    assert len(input_strs) == len(
        output_strs), "input and output must be same length to create input-output pairs"
    text_strs = list(map(lambda t: f'{t[0]}{t[1]}.', zip(input_strs, output_strs)))
    df = pd.DataFrame.from_dict({
        'input': input_strs,
        'output': output_strs,
        'text': text_strs,
    })
    if n_shots > 1:
        d2 = defaultdict(list)
        for i in range(max_n_datapoints):
            all_shots = df.sample(n=n_shots, replace=False)
            d2['text'].append('\n\n'.join(all_shots['text'].values))
            #
            last_input = all_shots.tail(n=1)['input'].values[0]
            d2['input'].append(
                ''.join(all_shots['text'].values[:-1]) + last_input)
            d2['last_input'].append(last_input)
            #
            last_output = all_shots.tail(n=1)['output'].values[0]
            d2['output'].append(last_output)
            #
        df = pd.DataFrame.from_dict(d2)
    
    # subsample rows if needed, then shuffle
    if max_n_datapoints < len(df):
        df = df.sample(n=max_n_datapoints, replace=False)
    dset = datasets.Dataset.from_pandas(df)
    dset = dset.shuffle()
    print(f'iPrompt got {len(dset)} datapoints, now loading model...')

    model = model.to(device)
    dataloader = DataLoader(
        dset, batch_size=batch_size, shuffle=True, drop_last=False)

    prompt_template = "{prompt_start}\n\n{data}\n\n{prompt_end}"
    prompt_template = functools.partial(
        prompt_template.format,
        prompt_start=model.llm_candidate_regeneration_prompt_start,
        prompt_end=model.llm_candidate_regeneration_prompt_end,
    )
    
    prompts = []
    # "gpt-3.5-turbo", "text-curie-001"
    llm = get_llm( 
        checkpoint=llm_api, role="user")
    stopping_early = False
    total_n = 0
    total_n_steps = 0
    total_n_datapoints = 0
    for epoch in range(n_epochs):       
        print(f'Beginning epoch {epoch}')
        pbar = tqdm(enumerate(dataloader), total=len(dataloader))
        for idx, batch in pbar:
            total_n_steps += 1
            if (n_shots > 1) and (single_shot_loss):
                batch['input'] = batch['last_input']
            x_text, y_text = model.prepare_batch(batch=batch)

            tok = functools.partial(
                model.tokenizer, return_tensors='pt', padding='longest',
                truncation=True, max_length=max_length)
            text_tokenized = tok(batch['text']).to(device)
            text_detokenized = model.tokenizer.batch_decode(
                text_tokenized['input_ids'], 
                skip_special_tokens=True,
            )
            
            prompts.extend(
                get_prompts_api(
                    data=text_detokenized, 
                    llm=llm, 
                    prompt_template=prompt_template, 
                )
            )

            total_n += len(x_text)
            total_n_datapoints += len(x_text)
            if (total_n_datapoints > max_n_datapoints) or (total_n_steps > max_n_steps):
                stopping_early = True
                break

        if stopping_early:
            print(f"Ending epoch {epoch} early...")

        # save stuff
        for key, val in model.compute_metrics().items():
            r[key].append(val)

        # r['losses'].append(avg_loss)
        if epoch % epoch_save_interval == 0:
            os.makedirs(save_dir, exist_ok=True)
            pkl.dump(r, open(os.path.join(save_dir, 'results.pkl'), 'wb'))

        # Early stopping, check after epoch
        if stopping_early:
            print(
                f"Stopping early after {total_n_steps} steps and {total_n_datapoints} datapoints")
            break

    
    # 
    #   Evaluate model on prefixes
    # 

    # Compute loss only over possible answers to make task easier
    possible_answer_ids = []
    for batch in dataloader:
        y_text = [answer for answer in batch['output']]
        y_tokenized = tokenizer(y_text, return_tensors='pt', padding='longest')
        # only test on the single next token
        true_next_token_ids = y_tokenized['input_ids'][:, 0]
        possible_answer_ids.extend(true_next_token_ids.tolist())
    
    possible_answer_ids = torch.tensor(possible_answer_ids)
    vocab_size = len(tokenizer.vocab)
    possible_answer_mask = (
            torch.arange(start=0, end=vocab_size)[:, None]
            ==
            possible_answer_ids[None, :]
        ).any(dim=1).to(device)
    n_eval = 256
    eval_dset = datasets.Dataset.from_dict(dset[:n_eval])
    eval_dataloader = DataLoader(
        eval_dset, batch_size=batch_size, shuffle=True, drop_last=False)   
    all_prefixes = model.tokenizer(
        [f" {prompt.strip()}" for prompt in prompts], 
        truncation=False, 
        padding=False,
    )["input_ids"]
    all_losses, all_accuracies = model.test_prefixes(
        prefixes=all_prefixes,
        eval_dataloader=eval_dataloader,
        possible_answer_mask=possible_answer_mask
    )

    # 
    #   Store prefix info 
    # 
    df = pd.DataFrame(
        zip(all_prefixes, all_losses, all_accuracies),
        columns=['prefix', 'loss', 'accuracy']
    )
    # sort by accuracy (descending), breaking ties by loss (ascending)
    df = df.sort_values(by=['accuracy', 'loss'], ascending=[
                        False, True]).reset_index()

    df['prefix_str'] = df['prefix'].map(
        functools.partial(model.tokenizer.decode, skip_special_tokens=True)
    )

    print('Final prefixes')
    print(df.head())
    r.update({
            "prefix_ids": df['prefix'].tolist(),
            "prefixes": df['prefix_str'].tolist(),
            "prefix_train_acc": df['accuracy'].tolist(),
            "prefix_train_loss": df['loss'].tolist(),
        })

    r['train_end_time'] = time.time()
    r['train_time_elapsed'] = r['train_end_time'] - r['train_start_time']

    pkl.dump(r, open(os.path.join(save_dir, 'results.pkl'), 'wb'))

    return r


def run_iprompt_local(
    r: Dict[str, List],
    input_strs: List[str],
    output_strs: List[str],
    model: PrefixModel,
    tokenizer: transformers.PreTrainedTokenizer,
    save_dir: str = 'results',
    lr: float = 1e-4,
    batch_size: int = 64,
    max_length: int = 128,
    n_epochs: int = 100,
    n_shots: int = 1,
    single_shot_loss: bool = True,
    accum_grad_over_epoch: bool = False,
    max_n_datapoints: int = 10**4,
    max_n_steps: int = 10**4,
    epoch_save_interval: int = 1,
    mask_possible_answers: bool = False,
    verbose: int = 0,
):
    """
    Trains a model locally, either by optimizing continuous embeddings or finding an optimal discrete prompt.

    Params
    ------
    r: dict
        dictionary of things to save
    """

    # remove periods and newlines from the output so we actually use the tokens
    # for the reranking step in iPrompt
    output_strs = [s.rstrip().rstrip('.') for s in output_strs]

    r['train_start_time'] = time.time()
    model.train()

    assert len(input_strs) == len(
        output_strs), "input and output must be same length to create input-output pairs"
    text_strs = list(map(lambda t: f'{t[0]}{t[1]}.', zip(input_strs, output_strs)))
    df = pd.DataFrame.from_dict({
        'input': input_strs,
        'output': output_strs,
        'text': text_strs,
    })
    if n_shots > 1:
        d2 = defaultdict(list)
        for i in range(max_n_datapoints):
            all_shots = df.sample(n=n_shots, replace=False)
            d2['text'].append('\n\n'.join(all_shots['text'].values))
            #
            last_input = all_shots.tail(n=1)['input'].values[0]
            d2['input'].append(
                ''.join(all_shots['text'].values[:-1]) + last_input)
            d2['last_input'].append(last_input)
            #
            last_output = all_shots.tail(n=1)['output'].values[0]
            d2['output'].append(last_output)
            #
        df = pd.DataFrame.from_dict(d2)
    # subsample rows if needed, then shuffle
    if max_n_datapoints < len(df):
        df = df.sample(n=max_n_datapoints, replace=False)
    dset = datasets.Dataset.from_pandas(df)
    dset = dset.shuffle()
    print(f'iPrompt got {len(dset)} datapoints, now loading model...')

    model = model.to(device)
    dataloader = DataLoader(
        dset, batch_size=batch_size, shuffle=True, drop_last=False)

    # optimizer
    optim = torch.optim.AdamW(model.trainable_params, lr=lr)

    assert model.training

    # Compute loss only over possible answers to make task easier
    possible_answer_ids = []
    for batch in dataloader:
        y_text = [answer for answer in batch['output']]
        y_tokenized = tokenizer(y_text, return_tensors='pt', padding='longest')
        # only test on the single next token
        true_next_token_ids = y_tokenized['input_ids'][:, 0]
        possible_answer_ids.extend(true_next_token_ids.tolist())

    possible_answer_ids = torch.tensor(possible_answer_ids)
    num_unique_answers = len(set(possible_answer_ids.tolist()))
    assert num_unique_answers > 0, "need at least one possible answer token"
    random_acc = 1 / num_unique_answers * 100.0
    majority_count = (
        possible_answer_ids[:, None] == possible_answer_ids[None, :]).sum(dim=1).max()
    majority_acc = majority_count * 100.0 / len(possible_answer_ids)
    print(
        f"Training with {num_unique_answers} possible answers / random acc {random_acc:.1f}% / majority acc {majority_acc:.1f}%")

    vocab_size = len(tokenizer.vocab)

    if mask_possible_answers:
        possible_answer_mask = (
            torch.arange(start=0, end=vocab_size)[:, None]
            ==
            possible_answer_ids[None, :]
        ).any(dim=1).to(device)
    else:
        possible_answer_mask = None

    stopping_early = False
    total_n_steps = 0
    total_n_datapoints = 0
    for epoch in range(n_epochs):
        model.pre_epoch()

        all_losses = []

        total_n = 0
        total_n_correct = 0
        print(f'Beginning epoch {epoch}')
        pbar = tqdm(enumerate(dataloader), total=len(dataloader))
        for idx, batch in pbar:
            total_n_steps += 1
            if (n_shots > 1) and (single_shot_loss):
                batch['input'] = batch['last_input']
            x_text, y_text = model.prepare_batch(batch=batch)

            tok = functools.partial(
                model.tokenizer, return_tensors='pt', padding='longest',
                truncation=True, max_length=max_length)
            x_tokenized = tok(x_text).to(device)
            y_tokenized = tok(y_text).to(device)
            full_text_tokenized = tok(batch['text']).to(device)

            loss, n_correct = model.compute_loss_and_call_backward(
                x_tokenized=x_tokenized,
                y_tokenized=y_tokenized,
                possible_answer_mask=possible_answer_mask,
                full_text_tokenized=full_text_tokenized,
            )

            r["all_losses"].append(loss)
            r["all_n_correct"].append(n_correct)

            total_n += len(x_text)
            total_n_datapoints += len(x_text)
            total_n_correct += n_correct

            all_losses.append(loss)
            pbar.set_description(f"Loss = {loss:.3f}")

            if not accum_grad_over_epoch:
                # if hotflip, autoprompt, etc., grad will be zero
                optim.step()
                optim.zero_grad()

            # Early stopping, check after step
            model_check_early_stop = model.check_early_stop()
            if model_check_early_stop:
                print("model_check_early_stop returned true")
            if (total_n_datapoints > max_n_datapoints) or (total_n_steps > max_n_steps) or model_check_early_stop:
                stopping_early = True
                break

        if stopping_early:
            print(f"Ending epoch {epoch} early...")
        avg_loss = sum(all_losses) / len(all_losses)
        print(f"Epoch {epoch}. average loss = {avg_loss:.3f} / {total_n_correct} / {total_n} correct ({total_n_correct/total_n*100:.2f}%)")

        # save stuff
        for key, val in model.compute_metrics().items():
            r[key].append(val)

        # r['losses'].append(avg_loss)
        if epoch % epoch_save_interval == 0:
            os.makedirs(save_dir, exist_ok=True)
            pkl.dump(r, open(os.path.join(save_dir, 'results.pkl'), 'wb'))

        model.post_epoch(dataloader=dataloader,
                         possible_answer_mask=possible_answer_mask)

        if accum_grad_over_epoch:
            optim.step()
            optim.zero_grad()

        # Early stopping, check after epoch
        if stopping_early:
            print(
                f"Stopping early after {total_n_steps} steps and {total_n_datapoints} datapoints")
            break

    # Serialize model-specific stuff (prefixes & losses for autoprompt, embeddings for prompt tuning, etc.)
    n_eval = 256
    eval_dset = datasets.Dataset.from_dict(dset[:n_eval])
    eval_dataloader = DataLoader(
        eval_dset, batch_size=batch_size, shuffle=True, drop_last=False)
    r.update(model.serialize(eval_dataloader, possible_answer_mask))
    # r.update(model.serialize())

    # save whether prefixes fit the template
    """
    if "prefixes" in r:
        r["prefixes__check_answer_func"] = list(
            map(check_answer_func, r["prefixes"]))
    """

    r['train_end_time'] = time.time()
    r['train_time_elapsed'] = r['train_end_time'] - r['train_start_time']

    pkl.dump(r, open(os.path.join(save_dir, 'results.pkl'), 'wb'))

    return r


def eval_model_with_set_prefix(
    dataloader: DataLoader,
    model: PrefixModel,
) -> Tuple[float, float]:
    """
    Evaluates a model based on set prefix. May be called multiple times with different prefixes

    Params
    ------
    r: dict
        dictionary of things to save

    Returns: Tuple[float, float]
        average loss, accuracy per sample over eval dataset
    """
    pbar = tqdm(enumerate(dataloader), total=len(dataloader),
                desc='evaluating data', colour='red', leave=False)
    total_loss = 0.0
    total_n = 0
    total_n_correct = 0
    for idx, batch in pbar:
        x_text, y_text = model.prepare_batch(batch=batch)

        tok = functools.partial(
            model.tokenizer, return_tensors='pt', padding='longest')
        x_tokenized = tok(x_text).to(device)
        y_tokenized = tok(y_text).to(device)
        # full_text_tokenized = tok(batch['text']).to(device)

        with torch.no_grad():
            _input_ids, loss, n_correct = model._compute_loss_with_set_prefix(
                original_input_ids=x_tokenized.input_ids,
                next_token_ids=y_tokenized.input_ids,
                possible_answer_mask=None,  # TODO: implement eval verbalizer
                prefix_ids=None,
            )

        total_loss += loss.item()
        total_n += len(x_text)
        total_n_correct += n_correct

        pbar.set_description(
            f"Acc = {total_n_correct}/{total_n} {(total_n_correct/total_n*100):.2f}%")

    return (total_loss / total_n), (total_n_correct / total_n)


def eval_model(
    r: Dict[str, List],
    dset: datasets.Dataset,
    model: PrefixModel,
    batch_size: int = 500,
    save_dir: str = 'results',
):
    """
    Evaluates a model based on the learned prefix(es).

    Params
    ------
    r: dict
        dictionary of things to save
    """
    r["test_start_time"] = time.time()
    model.eval()
    dataloader = DataLoader(
        dset, batch_size=batch_size, shuffle=False, drop_last=False)

    if r["prefixes"]:
        # if we specified multiple prefixes (autoprompt or iprompt), let's evaluate them all!
        for prefix_ids in tqdm(r["prefix_ids"], desc="evaluating prefixes"):
            model._set_prefix_ids(new_ids=torch.tensor(prefix_ids).to(device))

            loss, acc = eval_model_with_set_prefix(dataloader, model)

            r["prefix_test_loss"].append(loss)
            r["prefix_test_acc"].append(acc)
        r["num_prefixes_used_for_test"] = len(r["prefixes"])

    else:
        # otherwise, there's just one prefix (like for prompt tuning) so just run single eval loop.
        loss, acc = eval_model_with_set_prefix(dataloader, model)
        r["prefix_test_acc"] = loss
        r["prefix_test_loss"] = acc
        r["num_prefixes_used_for_test"] = 1

    r["test_end_time"] = time.time()
    r["test_time_elapsed"] = r["test_end_time"] - r["test_start_time"]
    pkl.dump(r, open(os.path.join(save_dir, 'results.pkl'), 'wb'))
    return r


def explain_dataset_iprompt(
    input_strings: List[str],
    output_strings: List[str],
    checkpoint: str='EleutherAI/gpt-j-6B',
    generation_checkpoint: str = '',
    num_learned_tokens=1,
    save_dir: str = './results',
    lr: float = 0.01,
    pop_size: int = 8,
    pop_criterion: str = 'loss',
    pop_topk_strategy: str = 'different_start_token',
    num_mutations: int = 4,
    prefix_before_input: bool = True,
    do_final_reranking: bool = False,
    num_random_generations: int = 4,
    generation_repetition_penalty: float = 2.0,
    generation_temp: float = 1.0,
    generation_top_p: float = 1.0,
    early_stopping_steps: int = -1,
    llm_float16=False,
    gamma: float = 0.0,
    batch_size: int = 64,
    max_length: int = 128,
    n_epochs: int = 100,
    n_shots: int = 1,
    preprefix: str = '',
    single_shot_loss: bool = True,
    accum_grad_over_epoch: bool = False,
    max_n_datapoints: int = 10**4,
    max_n_steps: int = 10**4,
    epoch_save_interval: int = 1,
    mask_possible_answers: bool = False,
    model_cls: str = 'iprompt',
    lm: transformers.PreTrainedModel = None,
    llm_candidate_regeneration_prompt_start: str = 'Data:',
    llm_candidate_regeneration_prompt_end: str = 'Prompt:',
    verbose: int = 0,  # verbosity level (0 for minimal)
    seed: int = 42,
    llm_api: str = "",
) -> Tuple[List[str], Dict]:
    """Explain the relationship between the input strings and the output strings

    Parameters
    ----------
    input_strings: List[str]
        list of input strings (e.g. "2 + 2")
    output_strings: List[str]
        list of output strings (e.g. "4")
    checkpoint: str
        name of model checkpoint to prompt (e.g. EleutherAI/gpt-j-6B)
    generation_checkpoint: str
        name of model used to generate prompts, *only if different from the
        checkpoint used for prompting*. defaults to '' (same model for both).
    prefix_before_input: bool
        whether to prompt the LLM with the prefix before or after the input data
    do_final_reranking: bool
        optionally rerank top prefixes using a single batch. helps prevent
        noisy prefixes from being top at the end, especially when run over a
        small number of iterations or with small batch size.
    generation_temp: float
        temperature for sampling from LLM (defaults to 1.0)
    generation_top_p: float
        p for sampling from LLM, if using top-p sampling (defaults to 1.0, no sampling)
    num_learned_tokens: int
        number of tokens to learn in prompt
    save_dir: str
        directory to save results
    lr: float
        learning rate for prompt tuning
    pop_size: int
        number of prompt candidates to evaluate for each iteration of iprompt
    pop_criterion: str
        criterion for getting top prefixes from prefix pool, in ['loss', 'acc']
    pop_topk_strategy: str
        strategy for getting new prefixes from prefix pool, in ['different_start_token', 'all']
    num_mutations: int
        number of mutations to apply to each prompt candidate
    lm: transformers.PreTrainedModel
        pre-loaded model (overrides checkpoint)
    max_n_datapoints: int
        maximum number of data points to use for training.
        if n_shots > 1, this many datapoints are created by recombining n_shots examples


    Returns
    -------
    best_prompts
        List of the best found prompts
    metadata_dict
        Dictionary of metadata from fitting
    """
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    if not prefix_before_input:
        tokenizer.truncation_side = 'left'
    tokenizer.eos_token = tokenizer.eos_token or 0
    tokenizer.pad_token = tokenizer.eos_token

    # load the model (unless already loaded)
    def load_lm(checkpoint, tokenizer, llm_float16):
        if llm_float16:
            if checkpoint == "EleutherAI/gpt-j-6B":
                lm = AutoModelForCausalLM.from_pretrained(
                    checkpoint, output_hidden_states=False, pad_token_id=tokenizer.eos_token_id,
                    revision="float16", torch_dtype=torch.float16, low_cpu_mem_usage=True
                )
            else:
                # (only certain models are pre-float16ed)
                print(f"trying to convert {checkpoint} to float16...")
                lm = transformers.AutoModelForCausalLM.from_pretrained(
                    checkpoint, torch_dtype=torch.float16
                )
                lm = lm.half()
        else:
            lm = AutoModelForCausalLM.from_pretrained(
                checkpoint, output_hidden_states=False, pad_token_id=tokenizer.eos_token_id
            )
        return lm
    if lm is None:
        lm = load_lm(checkpoint, tokenizer, llm_float16)
        
    loss_func = PrefixLoss(gamma=gamma, tokenizer=tokenizer)

    if model_cls == 'iprompt':
        model = iPrompt(
            loss_func=loss_func,
            model=lm,
            tokenizer=tokenizer,
            preprefix_str=preprefix,
            pop_size=pop_size,
            pop_criterion=pop_criterion,
            pop_topk_strategy=pop_topk_strategy,
            num_mutations=num_mutations,
            prefix_before_input=prefix_before_input,
            do_final_reranking=do_final_reranking,
            num_random_generations=num_random_generations,
            generation_repetition_penalty=generation_repetition_penalty,
            generation_temp=generation_temp,
            generation_top_p=generation_top_p,
            early_stopping_steps=early_stopping_steps,
            num_learned_tokens=num_learned_tokens,
            max_length=max_length,
            n_shots=n_shots,
            single_shot_loss=single_shot_loss,
            verbose=verbose,
            llm_float16=llm_float16,
            generation_checkpoint=generation_checkpoint,
            llm_candidate_regeneration_prompt_start=llm_candidate_regeneration_prompt_start,
            llm_candidate_regeneration_prompt_end=llm_candidate_regeneration_prompt_end,
        )
    else:
        pass
        """
        model = model_cls_dict[model_cls](
            args=args,
            loss_func=loss_func, model=lm, tokenizer=tokenizer, preprefix=preprefix
        )
        """
    
    iprompt_local = len(llm_api) == 0
    
    np.random.seed(seed)
    torch.manual_seed(seed)
    random.seed(seed)
    r = defaultdict(list)
    if iprompt_local:
        r = run_iprompt_local(
            r=r,
            input_strs=input_strings,
            output_strs=output_strings,
            model=model,
            tokenizer=tokenizer,
            save_dir=save_dir,
            lr=lr,
            batch_size=batch_size,
            max_length=max_length,
            mask_possible_answers=mask_possible_answers,
            n_epochs=n_epochs,
            n_shots=n_shots,
            single_shot_loss=single_shot_loss,
            accum_grad_over_epoch=accum_grad_over_epoch,
            max_n_datapoints=max_n_datapoints,
            max_n_steps=max_n_steps,
            epoch_save_interval=epoch_save_interval,
            verbose=verbose,
        )
    else:
        r = run_iprompt_api(
            r=r,
            input_strs=input_strings,
            output_strs=output_strings,
            model=model,
            tokenizer=tokenizer,
            save_dir=save_dir,
            lr=lr,
            batch_size=batch_size,
            max_length=max_length,
            mask_possible_answers=mask_possible_answers,
            n_epochs=n_epochs,
            n_shots=n_shots,
            single_shot_loss=single_shot_loss,
            accum_grad_over_epoch=accum_grad_over_epoch,
            max_n_datapoints=max_n_datapoints,
            max_n_steps=max_n_steps,
            epoch_save_interval=epoch_save_interval,
            verbose=verbose,
            llm_api=llm_api,
        )
    model = model.cpu()
    return r['prefixes'], r

    # r = eval_model(args=args, r=r, dset=Dataset.from_dict(dset_test[:128]), model=model, tokenizer=tokenizer)


# python api.py --task_name_list add_two --model_cls iprompt --num_learned_tokens 3 --max_dset_size 100 --max_n_datapoints 100 --early_stopping_steps 5 --max_digit 10 --train_split_frac 0.75 --single_shot_loss 1 --save_dir /home/chansingh/tmp/iprompt --checkpoint EleutherAI/gpt-j-6B --batch_size 64 --n_epochs 20
# python api.py --task_name_list add_two --model_cls iprompt --num_learned_tokens 3 --max_dset_size 5000 --max_n_datapoints 5000 --early_stopping_steps 25 --max_digit 10 --train_split_frac 0.75 --single_shot_loss 1 --save_dir /home/chansingh/tmp/iprompt --checkpoint EleutherAI/gpt-j-6B --batch_size 64 --float16 1
# python api.py --n_epochs 1 --max_n_steps 3 --max_n_datapoints 10
if __name__ == '__main__':
    # import data here (rather than at module level) so the package does not
    # require it; it must be imported before the parser references data.TASKS below
    import iprompt.data as data

    parser = argparse.ArgumentParser()

    parser.add_argument('--model_cls', type=str,
                        choices=model_cls_dict.keys(),
                        default='iprompt',
                        help='model type to use for training')
    parser.add_argument('--batch_size', type=int, default=500,
                        help='batch size for training')
    parser.add_argument('--seed', type=int, default=1,
                        help='random seed')
    parser.add_argument('--n_epochs', type=int, default=2,
                        help='number of epochs for training')
    parser.add_argument('--max_n_steps', type=int, default=10**10,
                        help='max number of steps for training')
    parser.add_argument('--max_n_datapoints', type=int, default=20,  # 10**10,
                        help='max number of datapoints for training')
    parser.add_argument('--train_split_frac', type=float,
                        default=None, help='fraction for train-test split if desired')
    parser.add_argument('--max_dset_size', type=int,
                        default=10**4, help='maximum allowable dataset size')
    parser.add_argument('--early_stopping_steps', type=int, default=-1,
                        help='if > 0, number of steps until stopping early after no improvement')
    parser.add_argument('--max_digit', type=int, default=10,
                        help='maximum value of each digit in summand')
    parser.add_argument('--template_num_init_string', type=int, default=0,
                        help='the number of the manually-specified prefix to be initialize with')
    parser.add_argument('--template_num_task_phrasing', type=int, default=0,
                        help='the number of the manual template for any given task (number of options varies with task)')
    parser.add_argument('--save_dir', type=str, default='results',
                        help='directory for saving')
    parser.add_argument('--epoch_save_interval', type=int, default=1,
                        help='interval to save results')
    parser.add_argument('--lr', type=float, default=1e-4,
                        help='learning rate')
    parser.add_argument('--gamma', type=float, default=0.0,
                        help='hparam: weight for language modeling loss')
    parser.add_argument('--task_name', type=str, default='add_two',
                        choices=(data.TASKS.keys() - {'SUFFIX'}),
                        help='name of task')
    parser.add_argument('--task_name_list', nargs="*", default=None,
                        help='names of tasks as list; alternative to passing task_name')
    parser.add_argument('--n_shots', type=int, default=1,
                        help='number of shots in the prompt')
    parser.add_argument('--autoprompt_init_strategy', type=str, default='the',
                        choices=('random', 'the'), help='initialization strategy for discrete tokens')
    parser.add_argument('--max_length', type=int, default=128,
                        help='maximum length for inputs')
    parser.add_argument('--single_shot_loss', type=int, default=0,
                        help='if n_shots > 1, load multiple shots but compute the loss on only the final shot')
    parser.add_argument('--mask_possible_answers', type=int, default=0,
                        help='only compute loss over possible answer tokens')
    parser.add_argument('--hotflip_num_candidates', type=int, default=10,
                        help='number of candidates to rerank, for hotflip')
    parser.add_argument('--accum_grad_over_epoch', type=int, default=0, choices=(0, 1),
                        help='should we clear gradients after a batch, or only at the end of the epoch?')
    parser.add_argument('--num_learned_tokens', type=int, default=1,
                        help='number of learned prefix tokens (for gumbel, hotflip, autoprompt, prompt-tuning)')
    parser.add_argument('--use_preprefix', type=int, default=1, choices=(0, 1),
                        help='whether to use a template pre-prefix')
    parser.add_argument('--iprompt_preprefix_str', type=str, default='',
                        help='Text like "Output the number that" or "Answer F/M if"...'
                        )
    parser.add_argument('--iprompt_pop_size', type=int, default=8,)
    parser.add_argument('--iprompt_num_mutations', type=int, default=4)
    parser.add_argument('--iprompt_num_random_generations',
                        type=int, default=4)
    parser.add_argument('--iprompt_generation_repetition_penalty', type=float, default=2.0,
                        help='repetition penalty for iprompt generations')
    parser.add_argument('--llm_float16', '--float16', '--parsimonious', type=int, default=0, choices=(0, 1),
                        help='if true, loads LLM in fp16 and at low-ram')
    parser.add_argument('--checkpoint', type=str, default="gpt2",
                        choices=(
                            ############################
                            "EleutherAI/gpt-neo-125M",
                            "EleutherAI/gpt-neo-1.3B",
                            "EleutherAI/gpt-neo-2.7B",
                            ############################
                            "EleutherAI/gpt-j-6B",
                            ############################
                            "EleutherAI/gpt-neox-20b",
                            ############################
                            "gpt2",        # 117M params
                            "gpt2-medium",  # 355M params
                            "gpt2-large",  # 774M params
                            "gpt2-xl",     # 1.5B params
                            ############################
                        ),
                        help='model checkpoint to use'
                        )

    args = parser.parse_args()
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    transformers.set_seed(args.seed)

    args.use_generic_query = 0

    if (args.mask_possible_answers) and (args.train_split_frac is not None):
        print("Warning: mask possible answers not supported for eval")

    # iterate over tasks
    if args.task_name_list is not None:
        logging.info('using task_name_list ' + str(args.task_name_list))
    else:
        args.task_name_list = [args.task_name]
    for task_idx, task_name in enumerate(args.task_name_list):
        print(f'*** Executing task {task_idx+1}/{len(args.task_name_list)}')
        # actually set the task
        args.task_name = task_name

        r = defaultdict(list)
        r.update(vars(args))
        logger = logging.getLogger()
        logging.basicConfig(level=logging.INFO)

        logger.info('loading data and model...')
        # set up saving
        save_dir_unique = datetime.now().strftime("%b_%d_%H_%M_") + \
            ''.join(random.choices(string.ascii_lowercase, k=12))
        save_dir = os.path.join(args.save_dir, save_dir_unique)
        logging.info('saving to ' + save_dir)
        args.save_dir_unique = save_dir

        # get data
        dset, _, _ = data.get_data(
            task_name=args.task_name, n_shots=args.n_shots, train_split_frac=args.train_split_frac, max_dset_size=args.max_dset_size,
            template_num_task_phrasing=args.template_num_task_phrasing, max_digit=args.max_digit
        )
        # pd.DataFrame.from_dict({
        #     'input_strings': dset['input'],
        #     'output_strings': [repr(x) for x in dset['output']],
        # }).to_csv('add_two.csv', index=False)

        prompts, meta = explain_dataset_iprompt(
            input_strings=dset['input'],
            output_strings=dset['output'],
            checkpoint=args.checkpoint,
            save_dir=args.save_dir,
            lr=args.lr,
            pop_size=args.iprompt_pop_size,
            num_mutations=args.iprompt_num_mutations,
            num_random_generations=args.iprompt_num_random_generations,
            generation_repetition_penalty=args.iprompt_generation_repetition_penalty,
            early_stopping_steps=args.early_stopping_steps,
            num_learned_tokens=args.num_learned_tokens,
            llm_float16=args.llm_float16,
            gamma=args.gamma,
            batch_size=args.batch_size,
            max_length=args.max_length,
            n_epochs=args.n_epochs,
            n_shots=args.n_shots,
            single_shot_loss=args.single_shot_loss,
            accum_grad_over_epoch=args.accum_grad_over_epoch,
            max_n_datapoints=args.max_n_datapoints,
            max_n_steps=args.max_n_steps,
            epoch_save_interval=args.epoch_save_interval,
            mask_possible_answers=args.mask_possible_answers,
            model_cls=args.model_cls,
        )
        print('prompts', prompts)
        print('\nmeta', meta)

Functions

def eval_model(r: Dict[str, List], dset: datasets.arrow_dataset.Dataset, model: PrefixModel, batch_size: int = 500, save_dir: str = 'results')

Evaluates a model based on the learned prefix(es).

Params

r: dict dictionary of things to save
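A minimal usage sketch (hypothetical: it assumes `model` is a fitted PrefixModel and `r` already holds the prefixes found by run_iprompt_local; the tiny test split below is made up for illustration):

from imodelsx.iprompt.api import eval_model
import datasets

# held-out examples in the same 'input' / 'output' / 'text' format used for training
test_dset = datasets.Dataset.from_dict({
    'input': ['Input: 2 5\nOutput:', 'Input: 1 3\nOutput:'],
    'output': [' 7', ' 4'],
    'text': ['Input: 2 5\nOutput: 7.', 'Input: 1 3\nOutput: 4.'],
})
r = eval_model(r=r, dset=test_dset, model=model, batch_size=64, save_dir='./results')
print(r['prefix_test_acc'][:5], r['prefix_test_loss'][:5])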

def eval_model_with_set_prefix(dataloader: torch.utils.data.dataloader.DataLoader, model: PrefixModel) ‑> Tuple[float, float]

Evaluates a model with its currently set prefix. May be called multiple times with different prefixes.

Params

dataloader: DataLoader eval dataloader with 'input' / 'output' / 'text' columns
model: PrefixModel model whose currently set prefix is evaluated

Returns: Tuple[float, float] average loss, accuracy per sample over eval dataset
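Usage sketch (hypothetical: assumes a PrefixModel whose prefix ids were already set, e.g. via model._set_prefix_ids, and a DataLoader over the training-format dataset):

from imodelsx.iprompt.api import eval_model_with_set_prefix

avg_loss, acc = eval_model_with_set_prefix(dataloader=eval_dataloader, model=model)
print(f"loss={avg_loss:.3f}, acc={acc:.1%}")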

def explain_dataset_iprompt(input_strings: List[str], output_strings: List[str], checkpoint: str = 'EleutherAI/gpt-j-6B', generation_checkpoint: str = '', num_learned_tokens=1, save_dir: str = './results', lr: float = 0.01, pop_size: int = 8, pop_criterion: str = 'loss', pop_topk_strategy: str = 'different_start_token', num_mutations: int = 4, prefix_before_input: bool = True, do_final_reranking: bool = False, num_random_generations: int = 4, generation_repetition_penalty: float = 2.0, generation_temp: float = 1.0, generation_top_p: float = 1.0, early_stopping_steps: int = -1, llm_float16=False, gamma: float = 0.0, batch_size: int = 64, max_length: int = 128, n_epochs: int = 100, n_shots: int = 1, preprefix: str = '', single_shot_loss: bool = True, accum_grad_over_epoch: bool = False, max_n_datapoints: int = 10000, max_n_steps: int = 10000, epoch_save_interval: int = 1, mask_possible_answers: bool = False, model_cls: str = 'iprompt', lm: transformers.modeling_utils.PreTrainedModel = None, llm_candidate_regeneration_prompt_start: str = 'Data:', llm_candidate_regeneration_prompt_end: str = 'Prompt:', verbose: int = 0, seed: int = 42, llm_api: str = '') ‑> Tuple[List[str], Dict]

Explain the relationship between the input strings and the output strings

Parameters

input_strings : List[str]
list of input strings (e.g. "2 + 2")
output_strings : List[str]
list of output strings (e.g. "4")
checkpoint : str
name of model checkpoint to prompt (e.g. EleutherAI/gpt-j-6B)
generation_checkpoint : str
name of model used to generate prompts, only if different from the checkpoint used for prompting. defaults to '' (same model for both).
prefix_before_input : bool
whether to prompt the LLM with the prefix before or after the input data
do_final_reranking : bool
optionally rerank top prefixes using a single batch. helps prevent noisy prefixes from being top at the end, especially when run over a small number of iterations or with small batch size.
generation_temp : float
temperature for sampling from LLM (defaults to 1.0)
generation_top_p : float
p for sampling from LLM, if using top-p sampling (defaults to 1.0, no sampling)
num_learned_tokens : int
number of tokens to learn in prompt
save_dir : str
directory to save results
lr : float
learning rate for prompt tuning
pop_size : int
number of prompt candidates to evaluate for each iteration of iprompt
pop_criterion : str
criterion for getting top prefixes from prefix pool, in ['loss', 'acc']
pop_topk_strategy : str
strategy for getting new prefixes from prefix pool, in ['different_start_token', 'all']
num_mutations : int
number of mutations to apply to each prompt candidate
lm : transformers.PreTrainedModel
pre-loaded model (overrides checkpoint)
max_n_datapoints : int
maximum number of data points to use for training; if n_shots > 1, this many datapoints are created by recombining n_shots examples

Returns

best_prompts
List of the best found prompts
metadata_dict
Dictionary of metadata from fitting
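A small end-to-end sketch (the toy addition data and the small 'gpt2' checkpoint are illustrative choices; a larger checkpoint such as EleutherAI/gpt-j-6B generally recovers more meaningful prompts):

from imodelsx.iprompt.api import explain_dataset_iprompt

# toy task: explain the mapping from two numbers to their sum
input_strings = [f'Input: {a} {b}\nOutput:' for a in range(5) for b in range(5)]
output_strings = [f' {a + b}' for a in range(5) for b in range(5)]

prompts, metadata = explain_dataset_iprompt(
    input_strings=input_strings,
    output_strings=output_strings,
    checkpoint='gpt2',        # small checkpoint, for illustration only
    num_learned_tokens=16,    # max length of the learned prompt
    n_epochs=1,
    max_n_steps=50,
    max_n_datapoints=100,
    batch_size=16,
)
print(prompts[:3])            # best candidate prompts, ranked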
def get_prompts_api(data: List[str], llm: Callable, prompt_template: Callable)
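Samples one formatted data string from data and asks the API model to propose a candidate prompt, returning it as a single-element list. A hedged usage sketch (the 'gpt-3.5-turbo' checkpoint and the toy data strings are assumptions for illustration):

import functools
from imodelsx.iprompt.llm import get_llm
from imodelsx.iprompt.api import get_prompts_api

# same template layout as used internally by run_iprompt_api
template = functools.partial(
    "{prompt_start}\n\n{data}\n\n{prompt_end}".format,
    prompt_start='Data:',
    prompt_end='Prompt:',
)
llm = get_llm(checkpoint='gpt-3.5-turbo', role='user')
candidates = get_prompts_api(
    data=['Input: 2 5\nOutput: 7.', 'Input: 1 3\nOutput: 4.'],
    llm=llm,
    prompt_template=template,
)
print(candidates)  # list containing a single LLM-proposed prompt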
def run_iprompt_api(r: Dict[str, List], input_strs: List[str], output_strs: List[str], model: PrefixModel, tokenizer: transformers.tokenization_utils.PreTrainedTokenizer, llm_api: str, save_dir: str = 'results', lr: float = 0.0001, batch_size: int = 64, max_length: int = 128, n_epochs: int = 100, n_shots: int = 1, single_shot_loss: bool = True, accum_grad_over_epoch: bool = False, max_n_datapoints: int = 10000, max_n_steps: int = 10000, epoch_save_interval: int = 1, mask_possible_answers: bool = False, verbose: int = 0)

Generates candidate prompts by querying an LLM API on batches of data, then reranks the candidates with the local prefix model; called by explain_dataset_iprompt when llm_api is non-empty.

Params

r: dict dictionary of things to save

Expand source code
def run_iprompt_api(
    r: Dict[str, List],
    input_strs: List[str],
    output_strs: List[str],
    model: PrefixModel,
    tokenizer: transformers.PreTrainedTokenizer,
    llm_api: str,
    save_dir: str = 'results',
    lr: float = 1e-4,
    batch_size: int = 64,
    max_length: int = 128,
    n_epochs: int = 100,
    n_shots: int = 1,
    single_shot_loss: bool = True,
    accum_grad_over_epoch: bool = False,
    max_n_datapoints: int = 10**4,
    max_n_steps: int = 10**4,
    epoch_save_interval: int = 1,
    mask_possible_answers: bool = False,
    verbose: int = 0,
):
    """
    Trains a model, either by optimizing continuous embeddings or finding an optimal discrete embedding.

    Params
    ------
    r: dict
        dictionary of things to save
    """


    # remove periods and newlines from the output so we actually use the tokens
    # for the reranking step in iPrompt
    output_strs = [s.rstrip().rstrip('.') for s in output_strs]

    r['train_start_time'] = time.time()
    model.train()

    logging.info("beginning iPrompt with n_shots = %d", n_shots)

    assert len(input_strs) == len(
        output_strs), "input and output must be same length to create input-output pairs"
    text_strs = list(map(lambda t: f'{t[0]}{t[1]}.', zip(input_strs, output_strs)))
    df = pd.DataFrame.from_dict({
        'input': input_strs,
        'output': output_strs,
        'text': text_strs,
    })
    if n_shots > 1:
        d2 = defaultdict(list)
        for i in range(max_n_datapoints):
            all_shots = df.sample(n=n_shots, replace=False)
            d2['text'].append('\n\n'.join(all_shots['text'].values))
            #
            last_input = all_shots.tail(n=1)['input'].values[0]
            d2['input'].append(
                ''.join(all_shots['text'].values[:-1]) + last_input)
            d2['last_input'].append(last_input)
            #
            last_output = all_shots.tail(n=1)['output'].values[0]
            d2['output'].append(last_output)
            #
        df = pd.DataFrame.from_dict(d2)
    
    # shuffle rows
    if max_n_datapoints < len(df):
        df = df.sample(n=max_n_datapoints, replace=False)
    dset = datasets.Dataset.from_pandas(df)
    dset.shuffle()
    print(f'iPrompt got {len(dset)} datapoints, now loading model...')

    model = model.to(device)
    dataloader = DataLoader(
        dset, batch_size=batch_size, shuffle=True, drop_last=False)

    prompt_template = "{prompt_start}\n\n{data}\n\n{prompt_end}"
    prompt_template = functools.partial(
        prompt_template.format,
        prompt_start=model.llm_candidate_regeneration_prompt_start,
        prompt_end=model.llm_candidate_regeneration_prompt_end,
    )
    
    prompts = []
    # "gpt-3.5-turbo", "text-curie-001"
    llm = get_llm( 
        checkpoint=llm_api, role="user")
    stopping_early = False
    total_n = 0
    total_n_steps = 0
    total_n_datapoints = 0
    for epoch in range(n_epochs):       
        print(f'Beginning epoch {epoch}')
        pbar = tqdm(enumerate(dataloader), total=len(dataloader))
        for idx, batch in pbar:
            total_n_steps += 1
            if (n_shots > 1) and (single_shot_loss):
                batch['input'] = batch['last_input']
            x_text, y_text = model.prepare_batch(batch=batch)

            tok = functools.partial(
                model.tokenizer, return_tensors='pt', padding='longest',
                truncation=True, max_length=max_length)
            text_tokenized = tok(batch['text']).to(device)
            text_detokenized = model.tokenizer.batch_decode(
                text_tokenized['input_ids'], 
                skip_special_tokens=True,
            )
            
            prompts.extend(
                get_prompts_api(
                    data=text_detokenized, 
                    llm=llm, 
                    prompt_template=prompt_template, 
                )
            )

            total_n += len(x_text)
            total_n_datapoints += len(x_text)
            if (total_n_datapoints > max_n_datapoints) or (total_n_steps > max_n_steps):
                stopping_early = True
                break

        if stopping_early:
            print(f"Ending epoch {epoch} early...")

        # save stuff
        for key, val in model.compute_metrics().items():
            r[key].append(val)

        # r['losses'].append(avg_loss)
        if epoch % epoch_save_interval == 0:
            os.makedirs(save_dir, exist_ok=True)
            pkl.dump(r, open(os.path.join(save_dir, 'results.pkl'), 'wb'))

        # Early stopping, check after epoch
        if stopping_early:
            print(
                f"Stopping early after {total_n_steps} steps and {total_n_datapoints} datapoints")
            break

    
    # 
    #   Evaluate model on prefixes
    # 

    # Compute loss only over possible answers to make task easier
    possible_answer_ids = []
    for batch in dataloader:
        y_text = [answer for answer in batch['output']]
        y_tokenized = tokenizer(y_text, return_tensors='pt', padding='longest')
        # only test on the single next token
        true_next_token_ids = y_tokenized['input_ids'][:, 0]
        possible_answer_ids.extend(true_next_token_ids.tolist())
    
    possible_answer_ids = torch.tensor(possible_answer_ids)
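    # build a boolean mask over the vocabulary that is True only at tokens appearing as answers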
    vocab_size = len(tokenizer.vocab)
    possible_answer_mask = (
        torch.arange(start=0, end=vocab_size)[:, None]
        ==
        possible_answer_ids[None, :]
    ).any(dim=1).to(device)
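    # score every generated prompt on a small evaluation slice of the training data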
    n_eval = 256
    eval_dset = datasets.Dataset.from_dict(dset[:n_eval])
    eval_dataloader = DataLoader(
        eval_dset, batch_size=batch_size, shuffle=True, drop_last=False)   
    all_prefixes = model.tokenizer(
        [f" {prompt.strip()}" for prompt in prompts], 
        truncation=False, 
        padding=False,
    )["input_ids"]
    all_losses, all_accuracies = model.test_prefixes(
        prefixes=all_prefixes,
        eval_dataloader=eval_dataloader,
        possible_answer_mask=possible_answer_mask
    )

    # 
    #   Store prefix info 
    # 
    df = pd.DataFrame(
        zip(*[all_prefixes, all_losses, all_accuracies]),
        columns=['prefix', 'loss', 'accuracy']
    )
    df = df.sort_values(by=['accuracy', 'loss'], ascending=[
                        False, True]).reset_index(drop=True)

    df['prefix_str'] = df['prefix'].map(
        functools.partial(model.tokenizer.decode, skip_special_tokens=True)
    )

    print('Final prefixes')
    print(df.head())
    r.update({
            "prefix_ids": df['prefix'].tolist(),
            "prefixes": df['prefix_str'].tolist(),
            "prefix_train_acc": df['accuracy'].tolist(),
            "prefix_train_loss": df['loss'].tolist(),
        })

    r['train_end_time'] = time.time()
    r['train_time_elapsed'] = r['train_end_time'] - r['train_start_time']

    pkl.dump(r, open(os.path.join(save_dir, 'results.pkl'), 'wb'))

    return r
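
A minimal usage sketch (not part of the module source). Here prefix_model and tok are placeholders for an already-constructed PrefixModel (e.g. an iPrompt instance) and its tokenizer, and the data and checkpoint name are purely illustrative.

from collections import defaultdict

# hypothetical call; prefix_model and tok must be built beforehand
input_strs = ['2 5 => ', '1 8 => ', '3 3 => ']
output_strs = ['7', '9', '6']
r = run_iprompt_api(
    r=defaultdict(list),        # results dict; values are appended per epoch
    input_strs=input_strs,
    output_strs=output_strs,
    model=prefix_model,         # assumed PrefixModel instance
    tokenizer=tok,              # assumed HuggingFace tokenizer
    llm_api='gpt-3.5-turbo',    # checkpoint string handed to get_llm
    n_shots=3,
    max_n_datapoints=100,
)
print(r['prefixes'][:5])        # candidate prompts ranked by training accuracy
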
def run_iprompt_local(r: Dict[str, List], input_strs: List[str], output_strs: List[str], model: PrefixModel, tokenizer: transformers.tokenization_utils.PreTrainedTokenizer, save_dir: str = 'results', lr: float = 0.0001, batch_size: int = 64, max_length: int = 128, n_epochs: int = 100, n_shots: int = 1, single_shot_loss: bool = True, accum_grad_over_epoch: bool = False, max_n_datapoints: int = 10000, max_n_steps: int = 10000, epoch_save_interval: int = 1, mask_possible_answers: bool = False, verbose: int = 0)

Trains a model, either by optimizing continuous embeddings or finding an optimal discrete embedding.

Params

r: dict
    dictionary of things to save
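
A hedged usage sketch (not from the source); prefix_model and tok stand in for an already-constructed PrefixModel from model_cls_dict (e.g. AutoPrompt or iPrompt) and its tokenizer.

from collections import defaultdict

# hypothetical call; prefix_model and tok must be built beforehand
r = run_iprompt_local(
    r=defaultdict(list),
    input_strs=['dog -> ', 'cat -> '],
    output_strs=['animal', 'animal'],
    model=prefix_model,           # assumed PrefixModel instance
    tokenizer=tok,                # assumed HuggingFace tokenizer
    n_epochs=10,
    mask_possible_answers=True,   # restrict the loss to tokens seen as answers
)
print(r['all_losses'][:5])        # per-step losses recorded during training

The returned dictionary also contains whatever model.serialize() adds at the end of training (prefixes and losses for AutoPrompt/iPrompt, embeddings for prompt tuning, etc.).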

Expand source code
def run_iprompt_local(
    r: Dict[str, List],
    input_strs: List[str],
    output_strs: List[str],
    model: PrefixModel,
    tokenizer: transformers.PreTrainedTokenizer,
    save_dir: str = 'results',
    lr: float = 1e-4,
    batch_size: int = 64,
    max_length: int = 128,
    n_epochs: int = 100,
    n_shots: int = 1,
    single_shot_loss: bool = True,
    accum_grad_over_epoch: bool = False,
    max_n_datapoints: int = 10**4,
    max_n_steps: int = 10**4,
    epoch_save_interval: int = 1,
    mask_possible_answers: bool = False,
    verbose: int = 0,
):
    """
    Trains a model, either by optimizing continuous embeddings or finding an optimal discrete embedding.

    Params
    ------
    r: dict
        dictionary of things to save
    """

    # remove periods and newlines from the output so we actually use the tokens
    # for the reranking step in iPrompt
    output_strs = [s.rstrip().rstrip('.') for s in output_strs]

    r['train_start_time'] = time.time()
    model.train()

    assert len(input_strs) == len(
        output_strs), "input and output must be same length to create input-output pairs"
    text_strs = list(map(lambda t: f'{t[0]}{t[1]}.', zip(input_strs, output_strs)))
    df = pd.DataFrame.from_dict({
        'input': input_strs,
        'output': output_strs,
        'text': text_strs,
    })
    if n_shots > 1:
        d2 = defaultdict(list)
        for i in range(max_n_datapoints):
            all_shots = df.sample(n=n_shots, replace=False)
            d2['text'].append('\n\n'.join(all_shots['text'].values))
            #
            last_input = all_shots.tail(n=1)['input'].values[0]
            d2['input'].append(
                ''.join(all_shots['text'].values[:-1]) + last_input)
            d2['last_input'].append(last_input)
            #
            last_output = all_shots.tail(n=1)['output'].values[0]
            d2['output'].append(last_output)
            #
        df = pd.DataFrame.from_dict(d2)
    # cap the number of rows at max_n_datapoints (sampling also shuffles the rows)
    if max_n_datapoints < len(df):
        df = df.sample(n=max_n_datapoints, replace=False)
    dset = datasets.Dataset.from_pandas(df)
    dset = dset.shuffle()
    print(f'iPrompt got {len(dset)} datapoints, now loading model...')

    model = model.to(device)
    dataloader = DataLoader(
        dset, batch_size=batch_size, shuffle=True, drop_last=False)

    # optimizer
    optim = torch.optim.AdamW(model.trainable_params, lr=lr)

    assert model.training

    # Compute loss only over possible answers to make task easier
    possible_answer_ids = []
    for batch in dataloader:
        y_text = [answer for answer in batch['output']]
        y_tokenized = tokenizer(y_text, return_tensors='pt', padding='longest')
        # only test on the single next token
        true_next_token_ids = y_tokenized['input_ids'][:, 0]
        possible_answer_ids.extend(true_next_token_ids.tolist())

    possible_answer_ids = torch.tensor(possible_answer_ids)
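    # report chance-level and majority-class accuracy as baselines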
    num_unique_answers = len(set(possible_answer_ids.tolist()))
    assert num_unique_answers > 0, "need at least one answer token to compute baselines"
    random_acc = 1 / num_unique_answers * 100.0
    majority_count = (
        possible_answer_ids[:, None] == possible_answer_ids[None, :]).sum(dim=1).max()
    majority_acc = majority_count * 100.0 / len(possible_answer_ids)
    print(
        f"Training with {num_unique_answers} possible answers / random acc {random_acc:.1f}% / majority acc {majority_acc:.1f}%")

    vocab_size = len(tokenizer.vocab)

    if mask_possible_answers:
        possible_answer_mask = (
            torch.arange(start=0, end=vocab_size)[:, None]
            ==
            possible_answer_ids[None, :]
        ).any(dim=1).to(device)
    else:
        possible_answer_mask = None

    stopping_early = False
    total_n_steps = 0
    total_n_datapoints = 0
    for epoch in range(n_epochs):
        model.pre_epoch()

        all_losses = []

        total_n = 0
        total_n_correct = 0
        print(f'Beginning epoch {epoch}')
        pbar = tqdm(enumerate(dataloader), total=len(dataloader))
        for idx, batch in pbar:
            total_n_steps += 1
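            # with multiple shots, compute the loss only on the final example of each group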
            if (n_shots > 1) and (single_shot_loss):
                batch['input'] = batch['last_input']
            x_text, y_text = model.prepare_batch(batch=batch)

            tok = functools.partial(
                model.tokenizer, return_tensors='pt', padding='longest',
                truncation=True, max_length=max_length)
            x_tokenized = tok(x_text).to(device)
            y_tokenized = tok(y_text).to(device)
            full_text_tokenized = tok(batch['text']).to(device)

            loss, n_correct = model.compute_loss_and_call_backward(
                x_tokenized=x_tokenized,
                y_tokenized=y_tokenized,
                possible_answer_mask=possible_answer_mask,
                full_text_tokenized=full_text_tokenized,
            )

            r["all_losses"].append(loss)
            r["all_n_correct"].append(n_correct)

            total_n += len(x_text)
            total_n_datapoints += len(x_text)
            total_n_correct += n_correct

            all_losses.append(loss)
            pbar.set_description(f"Loss = {loss:.3f}")

            if not accum_grad_over_epoch:
                # if hotflip, autoprompt, etc., grad will be zero
                optim.step()
                optim.zero_grad()

            # Early stopping, check after step
            model_check_early_stop = model.check_early_stop()
            if model_check_early_stop:
                print("model_check_early_stop returned true")
            if (total_n_datapoints > max_n_datapoints) or (total_n_steps > max_n_steps) or model_check_early_stop:
                stopping_early = True
                break

        if stopping_early:
            print(f"Ending epoch {epoch} early...")
        avg_loss = sum(all_losses) / len(all_losses)
        print(f"Epoch {epoch}. average loss = {avg_loss:.3f} / {total_n_correct} / {total_n} correct ({total_n_correct/total_n*100:.2f}%)")

        # save stuff
        for key, val in model.compute_metrics().items():
            r[key].append(val)

        # r['losses'].append(avg_loss)
        if epoch % epoch_save_interval == 0:
            os.makedirs(save_dir, exist_ok=True)
            pkl.dump(r, open(os.path.join(save_dir, 'results.pkl'), 'wb'))

        model.post_epoch(dataloader=dataloader,
                         possible_answer_mask=possible_answer_mask)

        if accum_grad_over_epoch:
            optim.step()
            optim.zero_grad()

        # Early stopping, check after epoch
        if stopping_early:
            print(
                f"Stopping early after {total_n_steps} steps and {total_n_datapoints} datapoints")
            break

    # Serialize model-specific stuff (prefixes & losses for autoprompt, embeddings for prompt tuning, etc.)
    n_eval = 256
    eval_dset = datasets.Dataset.from_dict(dset[:n_eval])
    eval_dataloader = DataLoader(
        eval_dset, batch_size=batch_size, shuffle=True, drop_last=False)
    r.update(model.serialize(eval_dataloader, possible_answer_mask))
    # r.update(model.serialize())

    # save whether prefixes fit the template
    """
    if "prefixes" in r:
        r["prefixes__check_answer_func"] = list(
            map(check_answer_func, r["prefixes"]))
    """

    r['train_end_time'] = time.time()
    r['train_time_elapsed'] = r['train_end_time'] - r['train_start_time']

    pkl.dump(r, open(os.path.join(save_dir, 'results.pkl'), 'wb'))

    return r
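
run_iprompt_api always restricts scoring to tokens that actually occur as answers, and run_iprompt_local does the same when mask_possible_answers is set. The snippet below is a standalone sketch of that mask construction; it is not part of the module, and gpt2 is used purely as an example tokenizer.

import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token    # gpt2 defines no pad token by default

answers = [' 4', ' 8', ' 15']
tok = tokenizer(answers, return_tensors='pt', padding='longest')
first_token_ids = tok['input_ids'][:, 0]     # only the first answer token is scored

vocab_size = len(tokenizer)
possible_answer_mask = (
    torch.arange(vocab_size)[:, None] == first_token_ids[None, :]
).any(dim=1)                                 # True only at observed answer tokens
print(int(possible_answer_mask.sum()))       # number of allowed vocabulary tokens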