Module imodelsx.sasc.api

Expand source code
from typing import List, Callable, Tuple, Dict
import imodelsx.sasc.m1_ngrams
import imodelsx.sasc.m2_summarize
import imodelsx.sasc.m3_generate
import imodelsx.llm
import numpy as np
import pprint
from collections import defaultdict


def explain_module_sasc(
    # get ngram module responses
    text_str_list: List[str],
    mod: Callable[[List[str]], List[float]],
    ngrams: int = 3,
    all_ngrams: bool = True,
    num_top_ngrams: int = 75,
    use_cache: bool = True,
    cache_filename: str = None,
    # generate explanation candidates
    llm_checkpoint: str = "text-davinci-003",
    llm_cache_dir: str = ".llm_cache",
    num_summaries: int = 3,
    num_top_ngrams_to_use: int = 30,
    num_top_ngrams_to_consider: int = 50,
    # generate synthetic strs
    num_synthetic_strs: int = 20,
    seed: int = 0,
    verbose: bool = True,
) -> Dict[str, List]:
    """

    Parameters
    ----------
    text_str_list: List[str]
        The list of text strings to use to extract ngrams
    mod: Callable[[List[str]], List[float]]
        The module to interpret
    ngrams: int
        The order of ngrams to use (3 is trigrams)
    all_ngrams: bool
        If True, use all ngrams up to order ngrams. If False, use only ngrams of exactly order ngrams
    num_top_ngrams: int
        The number of top ngrams to return
    use_cache: bool
        If True, cache the module's ngram responses
    cache_filename: str
        The filename to use for the module ngram cache
    llm_checkpoint: str
        The checkpoint to use for the llm
    llm_cache_dir: str
        The cache directory to use for the llm
    num_summaries: int
        The number of candidate explanations to generate
    num_top_ngrams_to_use: int
        The number of top ngrams to select
    num_top_ngrams_to_consider: int
        The number of top ngrams to consider selecting from
    num_synthetic_strs: int
        The number of synthetic strs to generate
    verbose: bool
        If True, print out progress
    seed: int
        The seed to use for the random number generator

    Returns
    -------
    explanation_dict: Dict[str, List]
        top_explanation_str: str
            The top explanation str
        top_explanation_score: float
            The top explanation score
        explanation_strs: List[str]
            The list of candidate explanation strs (this may have fewer entries than num_summaries if duplicate explanations are generated)
        explanation_scores: List[float]
            The list of corresponding candidate explanation scores
        ngrams_list: List[str]
            The list of top ngrams
        ngrams_scores: List[float]
            The list of top ngram scores
        strs_relevant: List[List[str]]
            The list of synthetically generated relevant strs
        strs_irrelevant: List[List[str]]
            The list of synthetically generated irrelevant strs
    """

    explanation_dict = defaultdict(list)

    # compute scores for each ngram
    (
        ngrams_list,
        ngrams_scores,
    ) = imodelsx.sasc.m1_ngrams.explain_ngrams(
        text_str_list=text_str_list,
        mod=mod,
        ngrams=ngrams,
        all_ngrams=all_ngrams,
        num_top_ngrams=num_top_ngrams,
        use_cache=use_cache,
        cache_filename=cache_filename,
    )
    explanation_dict["ngrams_list"] = ngrams_list
    explanation_dict["ngrams_scores"] = ngrams_scores

    # compute explanation candidates
    llm = imodelsx.llm.get_llm(llm_checkpoint, llm_cache_dir)
    (
        explanation_strs,
        _,
    ) = imodelsx.sasc.m2_summarize.summarize_ngrams(
        llm,
        ngrams_list,
        num_summaries=num_summaries,
        num_top_ngrams_to_use=num_top_ngrams_to_use,
        num_top_ngrams_to_consider=num_top_ngrams_to_consider,
        seed=seed,
    )
    explanation_dict["explanation_strs"] = explanation_strs

    # score explanation candidates on synthetic data
    for explanation_str in explanation_strs:
        strs_rel, strs_irrel = imodelsx.sasc.m3_generate.generate_synthetic_strs(
            llm,
            explanation_str=explanation_str,
            num_synthetic_strs=num_synthetic_strs,
            verbose=verbose,
        )
        explanation_dict["strs_relevant"].append(strs_rel)
        explanation_dict["strs_irrelevant"].append(strs_irrel)

        # score the explanation: mean module response on relevant strs minus mean response on irrelevant strs (higher is better)
        explanation_dict["explanation_scores"].append(
            np.mean(mod(strs_rel)) - np.mean(mod(strs_irrel))
        )

    # sort everything by scores
    sort_inds = np.argsort(explanation_dict["explanation_scores"])[::-1]
    for k in [
        "explanation_strs",
        "explanation_scores",
        "strs_relevant",
        "strs_irrelevant",
    ]:
        explanation_dict[k] = [explanation_dict[k][i] for i in sort_inds]
    for k in ["explanation_strs", "explanation_scores"]:
        explanation_dict["top_" + k[:-1]] = explanation_dict[k][0]

    return explanation_dict


if __name__ == "__main__":
    # an overly simple example of a module that responds to the length of a string
    mod = lambda str_list: np.array([len(s) for s in str_list])
    # in this dataset the longest strings happen to be animals, so we are searching for the explanation "animals"
    text_str_list = [
        "red",
        "blue",
        "x",
        "1",
        "2",
        "hippopotamus",
        "elephant",
        "rhinoceros",
    ]
    explanation_dict = explain_module_sasc(
        text_str_list,
        mod,
        ngrams=1,
        num_summaries=2,
        num_top_ngrams=3,
        num_top_ngrams_to_consider=3,
        num_synthetic_strs=2,
    )
    pprint.pprint(explanation_dict)

Functions

def explain_module_sasc(text_str_list: List[str], mod: Callable[[List[str]], List[float]], ngrams: int = 3, all_ngrams: bool = True, num_top_ngrams: int = 75, use_cache: bool = True, cache_filename: str = None, llm_checkpoint: str = 'text-davinci-003', llm_cache_dir: str = '.llm_cache', num_summaries: int = 3, num_top_ngrams_to_use: int = 30, num_top_ngrams_to_consider: int = 50, num_synthetic_strs: int = 20, seed: int = 0, verbose: bool = True) -> Dict[str, List]

Parameters

text_str_list : List[str]
The list of text strings to use to extract ngrams
mod : Callable[[List[str]], List[float]]
The module to interpret
ngrams : int
The order of ngrams to use (3 is trigrams)
all_ngrams : bool
If True, use all ngrams up to order ngrams. If False, use only ngrams of exactly order ngrams
num_top_ngrams : int
The number of top ngrams to return
use_cache : bool
If True, cache the module's ngram responses
cache_filename : str
The filename to use for the module ngram cache
llm_checkpoint : str
The checkpoint to use for the llm
llm_cache_dir : str
The cache directory to use for the llm
num_summaries : int
The number of candidate explanations to generate
num_top_ngrams_to_use : int
The number of top ngrams to select
num_top_ngrams_to_consider : int
The number of top ngrams to consider selecting from
num_synthetic_strs : int
The number of synthetic strs to generate
verbose : bool
If True, print out progress
seed : int
The seed to use for the random number generator

Returns

explanation_dict : Dict[str, List]
A dictionary with the following keys:
top_explanation_str : str
The top explanation str
top_explanation_score : float
The top explanation score
explanation_strs : List[str]
The list of candidate explanation strs (this may have fewer entries than num_summaries if duplicate explanations are generated)
explanation_scores : List[float]
The list of corresponding candidate explanation scores
ngrams_list : List[str]
The list of top ngrams
ngrams_scores : List[float]
The list of top ngram scores
strs_relevant : List[List[str]]
The list of synthetically generated relevant strs
strs_irrelevant : List[List[str]]
The list of synthetically generated irrelevant strs
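
For orientation, here is a minimal usage sketch showing how the returned dictionary might be consumed. The toy mod (a string-length scorer) and the text list mirror the __main__ example in the source above; they are illustrative only, and running the call requires access to the LLM backend named by llm_checkpoint.

import numpy as np
from imodelsx.sasc.api import explain_module_sasc

# toy module: responds with the length of each input string
mod = lambda str_list: np.array([len(s) for s in str_list])
text_str_list = ["red", "blue", "x", "1", "hippopotamus", "elephant", "rhinoceros"]

explanation_dict = explain_module_sasc(
    text_str_list,
    mod,
    ngrams=1,
    num_summaries=2,
    num_top_ngrams=3,
    num_top_ngrams_to_consider=3,
    num_synthetic_strs=2,
)

# keys are described under Returns above
print(explanation_dict["top_explanation_str"])    # best-scoring candidate explanation
print(explanation_dict["top_explanation_score"])  # its score
for expl, score in zip(
    explanation_dict["explanation_strs"], explanation_dict["explanation_scores"]
):
    print(f"{score:.3f}\t{expl}")

The score reported for each candidate explanation is the mean module response on its relevant synthetic strings minus the mean response on its irrelevant ones, so higher is better.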