Module imodelsx.iprompt.api

Functions

def eval_model(r: Dict[str, List],
dset: datasets.arrow_dataset.Dataset,
model: PrefixModel,
batch_size: int = 500,
save_dir: str = 'results')
Expand source code
def eval_model(
    r: Dict[str, List],
    dset: datasets.Dataset,
    model: PrefixModel,
    batch_size: int = 500,
    save_dir: str = 'results',
):
    """
    Evaluates a model based on the learned prefix(es).

    If ``r["prefixes"]`` is non-empty, every entry of ``r["prefix_ids"]`` is
    set on the model in turn and evaluated; otherwise the model is evaluated
    once with whatever prefix it currently holds. Results are written into
    ``r`` and ``r`` is pickled to ``<save_dir>/results.pkl``.

    Params
    ------
    r: dict
        dictionary of things to save. Read for "prefixes" / "prefix_ids" and
        updated in place with test loss/accuracy and timing info. In the
        multi-prefix case "prefix_test_loss" and "prefix_test_acc" must support
        ``.append`` (e.g. ``r`` is a ``defaultdict(list)``).
    dset: datasets.Dataset
        evaluation dataset (iterated in order, no shuffling)
    model: PrefixModel
        model to evaluate; it is switched to eval mode
    batch_size: int
        evaluation batch size
    save_dir: str
        directory in which results.pkl is written (created if missing)

    Returns
    -------
    r: dict
        the same, updated dictionary
    """
    r["test_start_time"] = time.time()
    model.eval()
    dataloader = DataLoader(
        dset, batch_size=batch_size, shuffle=False, drop_last=False)

    if r["prefixes"]:
        # multiple prefixes (autoprompt or iprompt): evaluate each of them
        for prefix_ids in tqdm(r["prefix_ids"], desc="evaluating prefixes"):
            model._set_prefix_ids(new_ids=torch.tensor(prefix_ids).to(device))

            loss, acc = eval_model_with_set_prefix(dataloader, model)

            r["prefix_test_loss"].append(loss)
            r["prefix_test_acc"].append(acc)
        r["num_prefixes_used_for_test"] = len(r["prefixes"])
    else:
        # a single prefix (e.g. prompt tuning): run one eval loop
        loss, acc = eval_model_with_set_prefix(dataloader, model)
        r["prefix_test_loss"] = loss
        r["prefix_test_acc"] = acc
        r["num_prefixes_used_for_test"] = 1

    r["test_end_time"] = time.time()
    r["test_time_elapsed"] = r["test_end_time"] - r["test_start_time"]

    os.makedirs(save_dir, exist_ok=True)
    with open(os.path.join(save_dir, 'results.pkl'), 'wb') as f:
        pkl.dump(r, f)
    return r

Evaluates a model based on the learned prefix(es).

Params

r: dict — dictionary of things to save; updated in place with the test results

def eval_model_with_set_prefix(dataloader: torch.utils.data.dataloader.DataLoader,
model: PrefixModel) -> Tuple[float, float]
Expand source code
def eval_model_with_set_prefix(
    dataloader: DataLoader,
    model: PrefixModel,
) -> Tuple[float, float]:
    """
    Evaluates a model based on its currently set prefix. May be called multiple
    times with different prefixes.

    Params
    ------
    dataloader: DataLoader
        batches of evaluation data; each batch goes through
        ``model.prepare_batch`` to obtain input and target texts
    model: PrefixModel
        model whose current prefix is evaluated (``prefix_ids=None`` is passed
        to ``model._compute_loss_with_set_prefix``)

    Returns: Tuple[float, float]
        (sum of per-batch losses divided by the number of samples,
         number of correct predictions divided by the number of samples)
        over the eval dataset. Both are NaN if no samples were seen.
    """
    # the tokenizer call is loop-invariant, so build it once
    tok = functools.partial(
        model.tokenizer, return_tensors='pt', padding='longest')
    pbar = tqdm(dataloader, total=len(dataloader),
                desc='evaluating data', colour='red', leave=False)
    total_loss = 0.0
    total_n = 0
    total_n_correct = 0
    for batch in pbar:
        x_text, y_text = model.prepare_batch(batch=batch)

        x_tokenized = tok(x_text).to(device)
        y_tokenized = tok(y_text).to(device)

        with torch.no_grad():
            _, loss, n_correct = model._compute_loss_with_set_prefix(
                original_input_ids=x_tokenized.input_ids,
                next_token_ids=y_tokenized.input_ids,
                possible_answer_mask=None,  # TODO: implement eval verbalizer
                prefix_ids=None,
            )

        # NOTE(review): if `loss` is a per-batch mean, dividing the summed batch
        # losses by total_n below is not a true per-sample mean -- confirm.
        total_loss += loss.item()
        total_n += len(x_text)
        total_n_correct += n_correct

        if total_n:
            pbar.set_description(
                f"Acc = {total_n_correct}/{total_n} {(total_n_correct/total_n*100):.2f}%")

    if total_n == 0:
        # empty eval set: avoid ZeroDivisionError, report undefined metrics
        return float('nan'), float('nan')
    return (total_loss / total_n), (total_n_correct / total_n)

Evaluates a model based on set prefix. May be called multiple times with different prefixes

Params

dataloader: DataLoader — batches of evaluation data
model: PrefixModel — model whose currently set prefix is evaluated

Returns: Tuple[float, float] — average loss and accuracy per sample over the eval dataset

def explain_dataset_iprompt(input_strings: List[str],
output_strings: List[str],
checkpoint: str = 'EleutherAI/gpt-j-6B',
generation_checkpoint: str = '',
num_learned_tokens=1,
save_dir: str = './results',
lr: float = 0.01,
pop_size: int = 8,
pop_criterion: str = 'loss',
pop_topk_strategy: str = 'different_start_token',
num_mutations: int = 4,
prefix_before_input: bool = True,
do_final_reranking: bool = False,
num_random_generations: int = 4,
generation_repetition_penalty: float = 2.0,
generation_temp: float = 1.0,
generation_top_p: float = 1.0,
early_stopping_steps: int = -1,
llm_float16=False,
gamma: float = 0.0,
batch_size: int = 64,
max_length: int = 128,
n_epochs: int = 100,
n_shots: int = 1,
preprefix: str = '',
single_shot_loss: bool = True,
accum_grad_over_epoch: bool = False,
max_n_datapoints: int = 10000,
max_n_steps: int = 10000,
epoch_save_interval: int = 1,
mask_possible_answers: bool = False,
model_cls: str = 'iprompt',
lm: transformers.modeling_utils.PreTrainedModel = None,
llm_candidate_regeneration_prompt_start: str = 'Data:',
llm_candidate_regeneration_prompt_end: str = 'Prompt:',
verbose: int = 0,
seed: int = 42,
llm_api: str = '') -> Tuple[List[str], Dict]
Expand source code
def explain_dataset_iprompt(
    input_strings: List[str],
    output_strings: List[str],
    checkpoint: str = 'EleutherAI/gpt-j-6B',
    generation_checkpoint: str = '',
    num_learned_tokens=1,
    save_dir: str = './results',
    lr: float = 0.01,
    pop_size: int = 8,
    pop_criterion: str = 'loss',
    pop_topk_strategy: str = 'different_start_token',
    num_mutations: int = 4,
    prefix_before_input: bool = True,
    do_final_reranking: bool = False,
    num_random_generations: int = 4,
    generation_repetition_penalty: float = 2.0,
    generation_temp: float = 1.0,
    generation_top_p: float = 1.0,
    early_stopping_steps: int = -1,
    llm_float16=False,
    gamma: float = 0.0,
    batch_size: int = 64,
    max_length: int = 128,
    n_epochs: int = 100,
    n_shots: int = 1,
    preprefix: str = '',
    single_shot_loss: bool = True,
    accum_grad_over_epoch: bool = False,
    max_n_datapoints: int = 10**4,
    max_n_steps: int = 10**4,
    epoch_save_interval: int = 1,
    mask_possible_answers: bool = False,
    model_cls: str = 'iprompt',
    lm: transformers.PreTrainedModel = None,
    llm_candidate_regeneration_prompt_start: str = 'Data:',
    llm_candidate_regeneration_prompt_end: str = 'Prompt:',
    verbose: int = 0,  # verbosity level (0 for minimal)
    seed: int = 42,
    llm_api: str = "",
) -> Tuple[List[str], Dict]:
    """Explain the relationship between the input strings and the output strings

    Parameters
    ----------
    input_strings: List[str]
        list of input strings (e.g. "2 + 2")
    output_strings: List[str]
        list of output strings (e.g. "4")
    checkpoint: str
        name of model checkpoint to prompt (e.g. EleutherAI/gpt-j-6B); the
        tokenizer is always loaded from this checkpoint
    generation_checkpoint: str
        name of model to generate prompts, *only if different from the
        checkpoint used for prompting*. defaults to '' (same model for both).
    num_learned_tokens: int
        number of tokens to learn in prompt
    save_dir: str
        directory to save results
    lr: float
        learning rate for prompt tuning (local mode only)
    pop_size: int
        number of prompt candidates to evaluate for each iteration of iprompt
    pop_criterion: str
        criterion for getting top prefixes from prefix pool, in ['loss', 'acc']
    pop_topk_strategy: str
        strategy for getting new prefixes from prefix pool, in ['different_start_token', 'all']
    num_mutations: int
        number of mutations to apply to each prompt candidate
    prefix_before_input: bool
        whether to prompt the LLM with the prefix before or after the input data;
        when False the tokenizer truncates from the left
    do_final_reranking: bool
        optionally rerank top prefixes using a single batch. helps prevent
        noisy prefixes from being top at the end, especially when run over a
        small number of iterations or with small batch size.
    num_random_generations, generation_repetition_penalty, early_stopping_steps:
        forwarded unchanged to ``iPrompt``
    generation_temp: float
        temperature for sampling from LLM (defaults to 1.0)
    generation_top_p: float
        p for sampling from LLM, if using top-p sampling (defaults to 1.0, no sampling)
    llm_float16: bool
        load the prompted model in float16
    gamma: float
        forwarded to ``PrefixLoss``
    batch_size: int
        batch size for training / candidate generation
    max_length: int
        maximum number of tokens kept when tokenizing a batch
    n_epochs: int
        maximum number of passes over the data
    n_shots: int
        number of input/output examples concatenated into each datapoint
    preprefix: str
        text forwarded to ``iPrompt`` as ``preprefix_str``
    single_shot_loss: bool
        when n_shots > 1, use only the last example's input as the batch input
    accum_grad_over_epoch: bool
        step the optimizer once per epoch instead of per batch (local mode only)
    max_n_datapoints: int
        maximum number of data points to use for training
        if n_shots > 1, this many data_points are created by recombining n_shots number of examples
    max_n_steps: int
        maximum number of training steps
    epoch_save_interval: int
        results.pkl is written every this many epochs
    mask_possible_answers: bool
        restrict the loss to tokens that start some output string (local mode only)
    model_cls: str
        model class; only 'iprompt' is supported (anything else raises ValueError)
    lm: transformers.PreTrainedModel
        pre-loaded model (overrides checkpoint for the model, not the tokenizer)
    llm_candidate_regeneration_prompt_start, llm_candidate_regeneration_prompt_end: str
        text placed before / after the data when asking an LLM for new prompt candidates
    verbose: int
        verbosity level (0 for minimal)
    seed: int
        seed for numpy, torch and python's ``random``
    llm_api: str
        if non-empty, name of an API model used to propose prompts
        (``run_iprompt_api``); otherwise prompts are searched locally
        (``run_iprompt_local``)

    Returns
    -------
    best_prompts
        List of the best found prompts (``r['prefixes']``)
    metadata_dict
        Dictionary of metadata from fitting

    Raises
    ------
    ValueError
        if ``model_cls`` is not 'iprompt'
    """
    if model_cls != 'iprompt':
        # Only iPrompt is constructed below; any other value would leave `model`
        # undefined. Fail before loading a (potentially very large) model.
        raise ValueError(
            f"model_cls={model_cls!r} is not supported; only 'iprompt' is available")

    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    if not prefix_before_input:
        # presumably so the end of the sequence, where the prefix sits, survives truncation
        tokenizer.truncation_side = 'left'
    # NOTE(review): falls back to the integer 0 when the tokenizer has no EOS token
    tokenizer.eos_token = tokenizer.eos_token or 0
    tokenizer.pad_token = tokenizer.eos_token

    def load_lm(checkpoint, tokenizer, llm_float16):
        """Load a causal LM for `checkpoint`, optionally in float16."""
        if llm_float16:
            if checkpoint == "EleutherAI/gpt-j-6B":
                # GPT-J is loaded from its "float16" hub revision
                lm = AutoModelForCausalLM.from_pretrained(
                    checkpoint, output_hidden_states=False, pad_token_id=tokenizer.eos_token_id,
                    revision="float16", torch_dtype=torch.float16, low_cpu_mem_usage=True
                )
            else:
                # (only certain models are pre-float16ed)
                print(f"trying to convert {checkpoint} to float16...")
                lm = AutoModelForCausalLM.from_pretrained(
                    checkpoint, torch_dtype=torch.float16
                )
                lm = lm.half()
        else:
            lm = AutoModelForCausalLM.from_pretrained(
                checkpoint, output_hidden_states=False, pad_token_id=tokenizer.eos_token_id
            )
        return lm

    # load the model (unless already loaded)
    if lm is None:
        lm = load_lm(checkpoint, tokenizer, llm_float16)

    loss_func = PrefixLoss(gamma=gamma, tokenizer=tokenizer)

    model = iPrompt(
        loss_func=loss_func,
        model=lm,
        tokenizer=tokenizer,
        preprefix_str=preprefix,
        pop_size=pop_size,
        pop_criterion=pop_criterion,
        pop_topk_strategy=pop_topk_strategy,
        num_mutations=num_mutations,
        prefix_before_input=prefix_before_input,
        do_final_reranking=do_final_reranking,
        num_random_generations=num_random_generations,
        generation_repetition_penalty=generation_repetition_penalty,
        generation_temp=generation_temp,
        generation_top_p=generation_top_p,
        early_stopping_steps=early_stopping_steps,
        num_learned_tokens=num_learned_tokens,
        max_length=max_length,
        n_shots=n_shots,
        single_shot_loss=single_shot_loss,
        verbose=verbose,
        llm_float16=llm_float16,
        generation_checkpoint=generation_checkpoint,
        llm_candidate_regeneration_prompt_start=llm_candidate_regeneration_prompt_start,
        llm_candidate_regeneration_prompt_end=llm_candidate_regeneration_prompt_end,
    )

    np.random.seed(seed)
    torch.manual_seed(seed)
    random.seed(seed)
    r = defaultdict(list)

    # arguments shared by the local and the API search loops
    run_kwargs = dict(
        r=r,
        input_strs=input_strings,
        output_strs=output_strings,
        model=model,
        tokenizer=tokenizer,
        save_dir=save_dir,
        lr=lr,
        batch_size=batch_size,
        max_length=max_length,
        mask_possible_answers=mask_possible_answers,
        n_epochs=n_epochs,
        n_shots=n_shots,
        single_shot_loss=single_shot_loss,
        accum_grad_over_epoch=accum_grad_over_epoch,
        max_n_datapoints=max_n_datapoints,
        max_n_steps=max_n_steps,
        epoch_save_interval=epoch_save_interval,
        verbose=verbose,
    )
    if llm_api:
        r = run_iprompt_api(**run_kwargs, llm_api=llm_api)
    else:
        r = run_iprompt_local(**run_kwargs)

    # presumably moved off the accelerator to release its memory
    model = model.cpu()
    return r['prefixes'], r

Explain the relationship between the input strings and the output strings

Parameters

input_strings : List[str]
list of input strings (e.g. "2 + 2")
output_strings : List[str]
list of output strings (e.g. "4")
checkpoint : str
name of model checkpoint to prompt (e.g. EleutherAI/gpt-j-6B)
generation_checkpoint : str
name of the model used to generate prompts, only if different from the checkpoint used for prompting. Defaults to '' (the same model is used for both).
prefix_before_input : bool
whether to prompt the LLM with the prefix before or after the input data
do_final_reranking : bool
optionally rerank top prefixes using a single batch. helps prevent noisy prefixes from being top at the end, especially when run over a small number of iterations or with small batch size.
generation_temp : float
temperature for sampling from LLM (defaults to 1.0)
generation_top_p : float
p for sampling from LLM, if using top-p sampling (defaults to 1.0, no sampling)
num_learned_tokens : int
number of tokens to learn in prompt
save_dir : str
directory to save results
lr : float
learning rate for prompt tuning
pop_size : int
number of prompt candidates to evaluate for each iteration of iprompt
pop_criterion : str
criterion for getting top prefixes from prefix pool, in ['loss', 'acc']
pop_topk_strategy : str
strategy for getting new prefixes from prefix pool, in ['different_start_token', 'all']
num_mutations : int
number of mutations to apply to each prompt candidate
lm : transformers.PreTrainedModel
pre-loaded model (overrides checkpoint)
max_n_datapoints : int
maximum number of data points to use for training. If n_shots > 1, this many data points are created by recombining n_shots examples each.

Returns

best_prompts
List of the best found prompts
metadata_dict
Dictionary of metadata from fitting
def get_prompts_api(data: List[str], llm: Callable, prompt_template: str)
Expand source code
def get_prompts_api(
    data: List[str],
    llm: Callable,
    prompt_template: Callable[..., str],
    max_new_tokens: int = 24,
):
    """Ask an LLM for one candidate prompt, conditioned on a random data example.

    Parameters
    ----------
    data: List[str]
        candidate data strings; one is picked with ``random.choice``
    llm: Callable
        called as ``llm(prompt, max_new_tokens=...)``; its return value is
        used as the candidate prompt
    prompt_template: Callable[..., str]
        called as ``prompt_template(data=...)`` to build the LLM query, e.g. a
        ``functools.partial`` of ``str.format`` (it is called, not formatted)
    max_new_tokens: int
        generation budget passed to ``llm`` (defaults to 24)

    Returns
    -------
    List containing the single LLM answer.
    """
    data_str = random.choice(data)
    prompt = prompt_template(data=data_str).strip()
    return [llm(prompt, max_new_tokens=max_new_tokens)]
def run_iprompt_api(r: Dict[str, List],
input_strs: List[str],
output_strs: List[str],
model: PrefixModel,
tokenizer: transformers.tokenization_utils.PreTrainedTokenizer,
llm_api: str,
save_dir: str = 'results',
lr: float = 0.0001,
batch_size: int = 64,
max_length: int = 128,
n_epochs: int = 100,
n_shots: int = 1,
single_shot_loss: bool = True,
accum_grad_over_epoch: bool = False,
max_n_datapoints: int = 10000,
max_n_steps: int = 10000,
epoch_save_interval: int = 1,
mask_possible_answers: bool = False,
verbose: int = 0)
Expand source code
def run_iprompt_api(
    r: Dict[str, List],
    input_strs: List[str],
    output_strs: List[str],
    model: PrefixModel,
    tokenizer: transformers.PreTrainedTokenizer,
    llm_api: str,
    save_dir: str = 'results',
    lr: float = 1e-4,
    batch_size: int = 64,
    max_length: int = 128,
    n_epochs: int = 100,
    n_shots: int = 1,
    single_shot_loss: bool = True,
    accum_grad_over_epoch: bool = False,
    max_n_datapoints: int = 10**4,
    max_n_steps: int = 10**4,
    epoch_save_interval: int = 1,
    mask_possible_answers: bool = False,
    verbose: int = 0,
):
    """
    Generates candidate prompts with an API LLM and ranks them with the local model.

    For every batch, one (tokenized, truncated and detokenized) example is sent
    to the LLM named by ``llm_api`` together with the model's regeneration
    prompt start/end text, and the answer is kept as a candidate prompt. All
    candidates are then scored with ``model.test_prefixes`` on up to 256
    datapoints, with the loss restricted to tokens that start some output.

    Params
    ------
    r: dict
        dictionary of things to save; updated in place (per-epoch
        ``model.compute_metrics()`` values, "prefix_ids", "prefixes",
        "prefix_train_acc", "prefix_train_loss", timings) and pickled to
        ``<save_dir>/results.pkl``
    input_strs, output_strs: List[str]
        paired inputs and outputs (must have equal length); trailing whitespace
        and periods are stripped from the outputs
    model: PrefixModel
        model used to tokenize data and score candidate prompts
    tokenizer: transformers.PreTrainedTokenizer
        used to find the first token of each output (possible answers)
    llm_api: str
        checkpoint name passed to ``get_llm`` (e.g. "gpt-3.5-turbo")
    save_dir: str
        directory for results.pkl (created if missing)
    batch_size: int
        batch size for candidate generation and scoring
    max_length: int
        maximum number of tokens kept from each example shown to the LLM
    n_epochs: int
        maximum number of passes over the data
    n_shots: int
        number of input/output examples concatenated into each datapoint
    single_shot_loss: bool
        when n_shots > 1, use only the last example's input as the batch input
    max_n_datapoints, max_n_steps: int
        generation stops once more datapoints/steps than this have been seen;
        max_n_datapoints also caps (and, for n_shots > 1, sets) the dataset size
    epoch_save_interval: int
        results.pkl is written every this many epochs
    lr, accum_grad_over_epoch, mask_possible_answers, verbose:
        accepted for signature parity with ``run_iprompt_local``; unused here

    Returns
    -------
    r: dict
        the updated results dictionary; prefixes are ordered by accuracy
        (descending), ties broken by loss (ascending)
    """
    # remove periods and newlines from the output so we actually use the tokens
    # for the reranking step in iPrompt
    output_strs = [s.rstrip().rstrip('.') for s in output_strs]

    r['train_start_time'] = time.time()
    model.train()

    logging.info("beginning iPrompt with n_shots = %d", n_shots)

    assert len(input_strs) == len(
        output_strs), "input and output must be same length to create input-output pairs"
    text_strs = [f'{inp}{out}.' for inp, out in zip(input_strs, output_strs)]
    df = pd.DataFrame.from_dict({
        'input': input_strs,
        'output': output_strs,
        'text': text_strs,
    })
    if n_shots > 1:
        # build max_n_datapoints few-shot examples, each from n_shots random rows
        d2 = defaultdict(list)
        for _ in range(max_n_datapoints):
            all_shots = df.sample(n=n_shots, replace=False)
            d2['text'].append('\n\n'.join(all_shots['text'].values))
            last_input = all_shots.tail(n=1)['input'].values[0]
            # NOTE(review): shots are joined with '' here but with '\n\n' in
            # 'text' above -- confirm the missing separator is intended.
            d2['input'].append(
                ''.join(all_shots['text'].values[:-1]) + last_input)
            d2['last_input'].append(last_input)
            last_output = all_shots.tail(n=1)['output'].values[0]
            d2['output'].append(last_output)
        df = pd.DataFrame.from_dict(d2)

    # subsample (which also shuffles) when there is more data than allowed
    if max_n_datapoints < len(df):
        df = df.sample(n=max_n_datapoints, replace=False)
    dset = datasets.Dataset.from_pandas(df)
    # Dataset.shuffle returns a new dataset; keep it so the eval slice below is random
    dset = dset.shuffle()
    print(f'iPrompt got {len(dset)} datapoints, now loading model...')

    model = model.to(device)
    dataloader = DataLoader(
        dset, batch_size=batch_size, shuffle=True, drop_last=False)

    prompt_template = functools.partial(
        "{prompt_start}\n\n{data}\n\n{prompt_end}".format,
        prompt_start=model.llm_candidate_regeneration_prompt_start,
        prompt_end=model.llm_candidate_regeneration_prompt_end,
    )

    prompts = []
    # llm_api is a checkpoint name, e.g. "gpt-3.5-turbo" or "text-curie-001"
    llm = get_llm(checkpoint=llm_api, role="user")
    tok = functools.partial(
        model.tokenizer, return_tensors='pt', padding='longest',
        truncation=True, max_length=max_length)
    stopping_early = False
    total_n_steps = 0
    total_n_datapoints = 0
    for epoch in range(n_epochs):
        print(f'Beginning epoch {epoch}')
        pbar = tqdm(dataloader, total=len(dataloader))
        for batch in pbar:
            total_n_steps += 1
            if (n_shots > 1) and single_shot_loss:
                batch['input'] = batch['last_input']
            x_text, _ = model.prepare_batch(batch=batch)

            # round-trip through the tokenizer so the LLM sees each example
            # truncated to max_length tokens
            text_tokenized = tok(batch['text']).to(device)
            text_detokenized = model.tokenizer.batch_decode(
                text_tokenized['input_ids'],
                skip_special_tokens=True,
            )

            prompts.extend(
                get_prompts_api(
                    data=text_detokenized,
                    llm=llm,
                    prompt_template=prompt_template,
                )
            )

            total_n_datapoints += len(x_text)
            if (total_n_datapoints > max_n_datapoints) or (total_n_steps > max_n_steps):
                stopping_early = True
                break

        if stopping_early:
            print(f"Ending epoch {epoch} early...")

        # save stuff
        for key, val in model.compute_metrics().items():
            r[key].append(val)

        if epoch % epoch_save_interval == 0:
            os.makedirs(save_dir, exist_ok=True)
            with open(os.path.join(save_dir, 'results.pkl'), 'wb') as f:
                pkl.dump(r, f)

        # Early stopping, check after epoch
        if stopping_early:
            print(
                f"Stopping early after {total_n_steps} steps and {total_n_datapoints} datapoints")
            break

    #
    #   Evaluate model on prefixes
    #

    # Compute loss only over possible answers to make task easier
    possible_answer_ids = []
    for batch in dataloader:
        y_tokenized = tokenizer(
            list(batch['output']), return_tensors='pt', padding='longest')
        # only test on the single next token
        possible_answer_ids.extend(y_tokenized['input_ids'][:, 0].tolist())
    possible_answer_ids = torch.tensor(possible_answer_ids, dtype=torch.long)

    # Boolean mask over the vocabulary marking possible answer tokens. Scatter
    # instead of building a (vocab_size x n_datapoints) comparison matrix, which
    # can take hundreds of MB; ids outside the vocab are ignored as before.
    vocab_size = len(tokenizer.vocab)
    in_vocab = (possible_answer_ids >= 0) & (possible_answer_ids < vocab_size)
    possible_answer_mask = torch.zeros(vocab_size, dtype=torch.bool)
    possible_answer_mask[possible_answer_ids[in_vocab]] = True
    possible_answer_mask = possible_answer_mask.to(device)

    n_eval = 256
    eval_dset = datasets.Dataset.from_dict(dset[:n_eval])
    eval_dataloader = DataLoader(
        eval_dset, batch_size=batch_size, shuffle=True, drop_last=False)
    all_prefixes = model.tokenizer(
        [f" {prompt.strip()}" for prompt in prompts],
        truncation=False,
        padding=False,
    )["input_ids"]
    all_losses, all_accuracies = model.test_prefixes(
        prefixes=all_prefixes,
        eval_dataloader=eval_dataloader,
        possible_answer_mask=possible_answer_mask
    )

    #
    #   Store prefix info
    #
    df = pd.DataFrame(
        zip(all_prefixes, all_losses, all_accuracies),
        columns=['prefix', 'loss', 'accuracy']
    )
    # One sort on both keys: best accuracy first, ties broken by lowest loss.
    # Re-sorting on accuracy alone would use a non-stable sort and could
    # scramble the loss tie-break.
    df = df.sort_values(
        by=['accuracy', 'loss'], ascending=[False, True]).reset_index(drop=True)

    df['prefix_str'] = df['prefix'].map(
        functools.partial(model.tokenizer.decode, skip_special_tokens=True)
    )

    print('Final prefixes')
    print(df.head())
    r.update({
        "prefix_ids": df['prefix'].tolist(),
        "prefixes": df['prefix_str'].tolist(),
        "prefix_train_acc": df['accuracy'].tolist(),
        "prefix_train_loss": df['loss'].tolist(),
    })

    r['train_end_time'] = time.time()
    r['train_time_elapsed'] = r['train_end_time'] - r['train_start_time']

    os.makedirs(save_dir, exist_ok=True)
    with open(os.path.join(save_dir, 'results.pkl'), 'wb') as f:
        pkl.dump(r, f)

    return r

Generates candidate prompts with an API-hosted LLM and ranks them by accuracy (then loss) using the local model.

Params

r: dict — dictionary of things to save; updated in place and pickled to results.pkl

def run_iprompt_local(r: Dict[str, List],
input_strs: List[str],
output_strs: List[str],
model: PrefixModel,
tokenizer: transformers.tokenization_utils.PreTrainedTokenizer,
save_dir: str = 'results',
lr: float = 0.0001,
batch_size: int = 64,
max_length: int = 128,
n_epochs: int = 100,
n_shots: int = 1,
single_shot_loss: bool = True,
accum_grad_over_epoch: bool = False,
max_n_datapoints: int = 10000,
max_n_steps: int = 10000,
epoch_save_interval: int = 1,
mask_possible_answers: bool = False,
verbose: int = 0)
Expand source code
def run_iprompt_local(
    r: Dict[str, List],
    input_strs: List[str],
    output_strs: List[str],
    model: PrefixModel,
    tokenizer: transformers.PreTrainedTokenizer,
    save_dir: str = 'results',
    lr: float = 1e-4,
    batch_size: int = 64,
    max_length: int = 128,
    n_epochs: int = 100,
    n_shots: int = 1,
    single_shot_loss: bool = True,
    accum_grad_over_epoch: bool = False,
    max_n_datapoints: int = 10**4,
    max_n_steps: int = 10**4,
    epoch_save_interval: int = 1,
    mask_possible_answers: bool = False,
    verbose: int = 0,
):
    """
    Trains a model, either by optimizing continuous embeddings or finding an
    optimal discrete embedding.

    Params
    ------
    r: dict
        dictionary of things to save; updated in place with per-step
        "all_losses" / "all_n_correct", per-epoch ``model.compute_metrics()``
        values, ``model.serialize(...)`` output and timings, and pickled to
        ``<save_dir>/results.pkl``
    input_strs, output_strs: List[str]
        paired inputs and outputs (must have equal length); trailing whitespace
        and periods are stripped from the outputs
    model: PrefixModel
        model being trained; its ``trainable_params`` are optimized with AdamW
    tokenizer: transformers.PreTrainedTokenizer
        used to find the first token of each output (possible answers)
    save_dir: str
        directory for results.pkl (created if missing)
    lr: float
        AdamW learning rate
    batch_size: int
        batch size for training and final serialization
    max_length: int
        maximum number of tokens kept when tokenizing a batch
    n_epochs: int
        maximum number of passes over the data
    n_shots: int
        number of input/output examples concatenated into each datapoint
    single_shot_loss: bool
        when n_shots > 1, use only the last example's input as the batch input
    accum_grad_over_epoch: bool
        step the optimizer once per epoch instead of after every batch
    max_n_datapoints, max_n_steps: int
        training stops once more datapoints/steps than this have been seen;
        max_n_datapoints also caps (and, for n_shots > 1, sets) the dataset size
    epoch_save_interval: int
        results.pkl is written every this many epochs
    mask_possible_answers: bool
        restrict the loss to tokens that start some output string
    verbose: int
        unused here

    Returns
    -------
    r: dict
        the updated results dictionary
    """
    # remove periods and newlines from the output so we actually use the tokens
    # for the reranking step in iPrompt
    output_strs = [s.rstrip().rstrip('.') for s in output_strs]

    r['train_start_time'] = time.time()
    model.train()

    assert len(input_strs) == len(
        output_strs), "input and output must be same length to create input-output pairs"
    text_strs = [f'{inp}{out}.' for inp, out in zip(input_strs, output_strs)]
    df = pd.DataFrame.from_dict({
        'input': input_strs,
        'output': output_strs,
        'text': text_strs,
    })
    if n_shots > 1:
        # build max_n_datapoints few-shot examples, each from n_shots random rows
        d2 = defaultdict(list)
        for _ in range(max_n_datapoints):
            all_shots = df.sample(n=n_shots, replace=False)
            d2['text'].append('\n\n'.join(all_shots['text'].values))
            last_input = all_shots.tail(n=1)['input'].values[0]
            # NOTE(review): shots are joined with '' here but with '\n\n' in
            # 'text' above -- confirm the missing separator is intended.
            d2['input'].append(
                ''.join(all_shots['text'].values[:-1]) + last_input)
            d2['last_input'].append(last_input)
            last_output = all_shots.tail(n=1)['output'].values[0]
            d2['output'].append(last_output)
        df = pd.DataFrame.from_dict(d2)

    # subsample (which also shuffles) when there is more data than allowed
    if max_n_datapoints < len(df):
        df = df.sample(n=max_n_datapoints, replace=False)
    dset = datasets.Dataset.from_pandas(df)
    # Dataset.shuffle returns a new dataset; keep it so the eval slice below is random
    dset = dset.shuffle()
    print(f'iPrompt got {len(dset)} datapoints, now loading model...')

    model = model.to(device)
    dataloader = DataLoader(
        dset, batch_size=batch_size, shuffle=True, drop_last=False)

    # optimizer
    optim = torch.optim.AdamW(model.trainable_params, lr=lr)

    assert model.training

    # Compute loss only over possible answers to make task easier
    possible_answer_ids = []
    for batch in dataloader:
        y_tokenized = tokenizer(
            list(batch['output']), return_tensors='pt', padding='longest')
        # only test on the single next token
        possible_answer_ids.extend(y_tokenized['input_ids'][:, 0].tolist())
    possible_answer_ids = torch.tensor(possible_answer_ids, dtype=torch.long)

    # counts per distinct answer token, in O(n) memory rather than an n x n matrix
    unique_answer_ids, answer_counts = torch.unique(
        possible_answer_ids, return_counts=True)
    num_unique_answers = len(unique_answer_ids)
    # NOTE(review): the message says "multiple" but only > 0 is enforced
    assert num_unique_answers > 0, "need multiple answers for multiple choice"
    random_acc = 1 / num_unique_answers * 100.0
    majority_count = answer_counts.max().item()
    majority_acc = majority_count * 100.0 / len(possible_answer_ids)
    print(
        f"Training with {num_unique_answers} possible answers / random acc {random_acc:.1f}% / majority acc {majority_acc:.1f}%")

    if mask_possible_answers:
        # Boolean mask over the vocabulary marking possible answer tokens.
        # Scatter instead of a (vocab_size x n_datapoints) comparison matrix;
        # ids outside the vocab are ignored as before.
        vocab_size = len(tokenizer.vocab)
        in_vocab = (possible_answer_ids >= 0) & (possible_answer_ids < vocab_size)
        possible_answer_mask = torch.zeros(vocab_size, dtype=torch.bool)
        possible_answer_mask[possible_answer_ids[in_vocab]] = True
        possible_answer_mask = possible_answer_mask.to(device)
    else:
        possible_answer_mask = None

    # the tokenizer call is loop-invariant, so build it once
    tok = functools.partial(
        model.tokenizer, return_tensors='pt', padding='longest',
        truncation=True, max_length=max_length)

    stopping_early = False
    total_n_steps = 0
    total_n_datapoints = 0
    for epoch in range(n_epochs):
        model.pre_epoch()

        all_losses = []

        total_n = 0
        total_n_correct = 0
        print(f'Beginning epoch {epoch}')
        pbar = tqdm(dataloader, total=len(dataloader))
        for batch in pbar:
            total_n_steps += 1
            if (n_shots > 1) and single_shot_loss:
                batch['input'] = batch['last_input']
            x_text, y_text = model.prepare_batch(batch=batch)

            x_tokenized = tok(x_text).to(device)
            y_tokenized = tok(y_text).to(device)
            full_text_tokenized = tok(batch['text']).to(device)

            loss, n_correct = model.compute_loss_and_call_backward(
                x_tokenized=x_tokenized,
                y_tokenized=y_tokenized,
                possible_answer_mask=possible_answer_mask,
                full_text_tokenized=full_text_tokenized,
            )

            r["all_losses"].append(loss)
            r["all_n_correct"].append(n_correct)

            total_n += len(x_text)
            total_n_datapoints += len(x_text)
            total_n_correct += n_correct

            all_losses.append(loss)
            pbar.set_description(f"Loss = {loss:.3f}")

            if not accum_grad_over_epoch:
                # if hotflip, autoprompt, etc., grad will be zero
                optim.step()
                optim.zero_grad()

            # Early stopping, check after step
            model_check_early_stop = model.check_early_stop()
            if model_check_early_stop:
                print("model_check_early_stop returned true")
            if (total_n_datapoints > max_n_datapoints) or (total_n_steps > max_n_steps) or model_check_early_stop:
                stopping_early = True
                break

        if stopping_early:
            print(f"Ending epoch {epoch} early...")
        avg_loss = sum(all_losses) / len(all_losses)
        print(f"Epoch {epoch}. average loss = {avg_loss:.3f} / {total_n_correct} / {total_n} correct ({total_n_correct/total_n*100:.2f}%)")

        # save stuff
        for key, val in model.compute_metrics().items():
            r[key].append(val)

        if epoch % epoch_save_interval == 0:
            os.makedirs(save_dir, exist_ok=True)
            with open(os.path.join(save_dir, 'results.pkl'), 'wb') as f:
                pkl.dump(r, f)

        model.post_epoch(dataloader=dataloader,
                         possible_answer_mask=possible_answer_mask)

        if accum_grad_over_epoch:
            optim.step()
            optim.zero_grad()

        # Early stopping, check after epoch
        if stopping_early:
            print(
                f"Stopping early after {total_n_steps} steps and {total_n_datapoints} datapoints")
            break

    # Serialize model-specific stuff (prefixes & losses for autoprompt, embeddings for prompt tuning, etc.)
    n_eval = 256
    eval_dset = datasets.Dataset.from_dict(dset[:n_eval])
    eval_dataloader = DataLoader(
        eval_dset, batch_size=batch_size, shuffle=True, drop_last=False)
    r.update(model.serialize(eval_dataloader, possible_answer_mask))

    r['train_end_time'] = time.time()
    r['train_time_elapsed'] = r['train_end_time'] - r['train_start_time']

    os.makedirs(save_dir, exist_ok=True)
    with open(os.path.join(save_dir, 'results.pkl'), 'wb') as f:
        pkl.dump(r, f)

    return r

Trains a model, either by optimizing continuous embeddings or finding an optimal discrete embedding.

Params

r: dict — dictionary of things to save; updated in place and pickled to results.pkl