Expand source code
def explain_module_sasc(
    # get ngram module responses
    text_str_list: List[str],
    mod: Callable[[List[str]], List[float]],
    ngrams: int = 3,
    all_ngrams: bool = True,
    num_top_ngrams: int = 75,
    use_cache: bool = True,
    cache_filename: str = None,
    # generate explanation candidates
    llm_checkpoint: str = "text-davinci-003",
    llm_cache_dir: str = ".llm_cache",
    num_summaries: int = 3,
    num_top_ngrams_to_use: int = 30,
    num_top_ngrams_to_consider: int = 50,
    # generate synthetic strs
    num_synthetic_strs: int = 20,
    seed: int = 0,
    verbose: bool = True,
) -> Dict[str, List]:
    """Explain a text module via the SASC pipeline: score ngrams, summarize
    the top ngrams into candidate explanations with an LLM, then score each
    candidate on LLM-generated synthetic strings.

    Parameters
    ----------
    text_str_list: List[str]
        The list of text strings to use to extract ngrams
    mod: Callable[[List[str]], List[float]]
        The module to interpret
    ngrams: int
        The order of ngrams to use (3 is trigrams)
    all_ngrams: bool
        If True, use all ngrams up to ngrams. If False, use only ngrams
    num_top_ngrams: int
        The number of top ngrams to return
    use_cache: bool
        If True, use the cache
    cache_filename: str
        The filename to use for the module ngram cache (None disables a named cache file)
    llm_checkpoint: str
        The checkpoint to use for the llm
    llm_cache_dir: str
        The cache directory to use for the llm
    num_summaries: int
        The number of candidate explanations to generate
    num_top_ngrams_to_use: int
        The number of top ngrams to select
    num_top_ngrams_to_consider: int
        The number of top ngrams to consider selecting from
    num_synthetic_strs: int
        The number of synthetic strs to generate
    seed: int
        The seed to use for the random number generator
    verbose: bool
        If True, print out progress

    Returns
    -------
    explanation_dict: Dict[str, List]
        top_explanation_str: str
            The top explanation str (present only if at least one candidate was generated)
        top_explanation_score: float
            The top explanation score (present only if at least one candidate was generated)
        explanation_strs: List[str]
            The list of candidate explanation strs (this may have less entries than num_summaries if duplicate explanations are generated)
        explanation_scores: List[float]
            The list of corresponding candidate explanation scores
        ngrams_list: List[str]
            The list of top ngrams
        ngrams_scores: List[float]
            The list of top ngram scores
        strs_relevant: List[List[str]]
            The list of synthetically generated relevant strs
        strs_irrelevant: List[List[str]]
            The list of synthetically generated irrelevant strs
    """
    explanation_dict = defaultdict(list)

    # Step 1: compute a score for each ngram via the module
    (
        ngrams_list,
        ngrams_scores,
    ) = imodelsx.sasc.m1_ngrams.explain_ngrams(
        text_str_list=text_str_list,
        mod=mod,
        ngrams=ngrams,
        all_ngrams=all_ngrams,
        num_top_ngrams=num_top_ngrams,
        use_cache=use_cache,
        cache_filename=cache_filename,
    )
    explanation_dict["ngrams_list"] = ngrams_list
    explanation_dict["ngrams_scores"] = ngrams_scores

    # Step 2: summarize the top ngrams into candidate explanations
    llm = imodelsx.llm.get_llm(llm_checkpoint, llm_cache_dir)
    (
        explanation_strs,
        _,
    ) = imodelsx.sasc.m2_summarize.summarize_ngrams(
        llm,
        ngrams_list,
        num_summaries=num_summaries,
        num_top_ngrams_to_use=num_top_ngrams_to_use,
        num_top_ngrams_to_consider=num_top_ngrams_to_consider,
        seed=seed,
    )
    explanation_dict["explanation_strs"] = explanation_strs

    # Step 3: score each candidate explanation on synthetic data
    # (higher score = module responds more to relevant than irrelevant strings)
    for explanation_str in explanation_strs:
        strs_rel, strs_irrel = imodelsx.sasc.m3_generate.generate_synthetic_strs(
            llm,
            explanation_str=explanation_str,
            num_synthetic_strs=num_synthetic_strs,
            verbose=verbose,
        )
        explanation_dict["strs_relevant"].append(strs_rel)
        explanation_dict["strs_irrelevant"].append(strs_irrel)
        explanation_dict["explanation_scores"].append(
            np.mean(mod(strs_rel)) - np.mean(mod(strs_irrel))
        )

    # sort candidates (and their synthetic data) by score, descending
    sort_inds = np.argsort(explanation_dict["explanation_scores"])[::-1]
    for k in [
        "explanation_strs",
        "explanation_scores",
        "strs_relevant",
        "strs_irrelevant",
    ]:
        explanation_dict[k] = [explanation_dict[k][i] for i in sort_inds]

    # expose the best candidate under top_* keys; guard against the edge case
    # where no candidates were generated (avoids IndexError on an empty list)
    if explanation_dict["explanation_strs"]:
        for k in ["explanation_strs", "explanation_scores"]:
            explanation_dict["top_" + k[:-1]] = explanation_dict[k][0]
    return explanation_dict