Module imodelsx.sasc.m3_generate

Expand source code
import re
from typing import Callable, List, Tuple
from imodelsx.llm import get_llm


def generate_synthetic_strs(
    llm: Callable[..., str],
    explanation_str: str,
    num_synthetic_strs: int = 20,
    template_num: int = 0,
    verbose: bool = True,
) -> Tuple[List[str], List[str]]:
    """Generate text_added and text_removed via call to an LLM.

    Params
    ------
    llm: Callable[..., str]
        The LLM to use.
        flan-t5-xxl/opt-iml-max-30b can only generate one sentence before stopping;
        EleutherAI/gpt-neox-20b can generate multiple sentences, but they are not faithful to the concept.
    explanation_str: str
        The explanation string to use
    num_synthetic_strs: int
        The number of synthetic strings to generate
    template_num: int
        The prompt template number to use
    verbose: bool
        Whether to print the prompt and the generated text

    Returns
    -------
    strs_added: List[str]
        Synthetic strings generated to contain the explanation concept
    strs_removed: List[str]
        Synthetic strings generated to not contain the explanation concept
    """

    templates = [
        """
Generate {num_synthetic_strs} sentences that {blank_or_do_not}contain the concept of "{concept}":

1. The""",
        """
Generate {num_synthetic_strs} phrases that are {blank_or_do_not}similar to the concept of "{concept}":

1.""",
    ]
    blank_or_do_not_templates = [
        ["", "do not "],
        ["", "not "],
    ]
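    # each filler pairs with the chosen template: the empty filler yields prompts
    # for strings that contain the concept (collected into strs_added), while the
    # negated filler yields prompts for strings that do not (collected into strs_removed)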
    template = templates[template_num]
    strs_added = []
    strs_removed = []
    for blank_or_do_not in blank_or_do_not_templates[template_num]:
        prompt = template.format(
            num_synthetic_strs=num_synthetic_strs,
            blank_or_do_not=blank_or_do_not,
            concept=explanation_str,
        )

        # note: this works well with OpenAI models,
        # but tends to stop after generating just one text with non-OpenAI models
        synthetic_text_numbered_str = llm(prompt, max_new_tokens=400, do_sample=True)
        if verbose:
            print("\n\n---------------\n")
            print(prompt)
            print("\n\n---------------\n")
            print(synthetic_text_numbered_str)
            print("\n\n---------------\n")

        # split the output on numbered-list markers like "1." or "12."
        synthetic_strs_split = re.split(r"\d+\.", synthetic_text_numbered_str)
        synthetic_strs = [s.strip() for s in synthetic_strs_split if s.strip()]
        # drop fragments too short to be real sentences/phrases
        synthetic_strs = [s for s in synthetic_strs if len(s) > 2]
        if verbose:
            print("synthetic_strs=", synthetic_strs)

        for s in synthetic_strs:
            if blank_or_do_not == "":
                strs_added.append(s)
            else:
                strs_removed.append(s)
    return strs_added, strs_removed


if __name__ == "__main__":
    # llm = get_llm(checkpoint='EleutherAI/gpt-neox-20b')
    llm = get_llm("text-davinci-003")
    strs_added, strs_removed = generate_synthetic_strs(
        llm,
        explanation_str="anger",
        num_synthetic_strs=20,
        template_num=1,
    )
    print(f"{strs_added=} {strs_removed=}")

Functions

def generate_synthetic_strs(llm: Callable[..., str], explanation_str: str, num_synthetic_strs: int = 20, template_num: int = 0, verbose: bool = True) -> Tuple[List[str], List[str]]

Generate text_added and text_removed via call to an LLM.

Params

llm : Callable[..., str]
The LLM to use. flan-t5-xxl/opt-iml-max-30b can only generate one sentence before stopping; EleutherAI/gpt-neox-20b can generate multiple sentences, but they are not faithful to the concept.
explanation_str : str
The explanation string to use
num_synthetic_strs : int
The number of synthetic strings to generate
template_num : int
The prompt template number to use
verbose : bool
Whether to print the prompt and the generated text

Returns

strs_added : List[str]
Synthetic strings generated to contain the explanation concept
strs_removed : List[str]
Synthetic strings generated to not contain the explanation concept
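
A minimal usage sketch: fake_llm below is a stand-in defined only for this example, so the call runs without API access (the real workflow passes the callable returned by get_llm, as in the module's __main__ block). Note that generate_synthetic_strs invokes the callable with max_new_tokens and do_sample keyword arguments, so a custom llm must accept them.

from imodelsx.sasc.m3_generate import generate_synthetic_strs

# stub LLM for illustration: ignores the prompt and returns a canned numbered list
def fake_llm(prompt, max_new_tokens=400, do_sample=True):
    return "1. He was furious.\n2. She yelled in rage.\n3. He slammed his fist on the table."

strs_added, strs_removed = generate_synthetic_strs(
    fake_llm,
    explanation_str="anger",
    num_synthetic_strs=3,
    template_num=1,
    verbose=False,
)
# because the stub ignores its prompt, both lists hold the same canned strings here;
# with a real LLM, strs_added contains the concept and strs_removed does not
print(strs_added)
print(strs_removed)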