Module imodelsx.llm

Functions

def get_llm(checkpoint,
            seed=1,
            role: str = None,
            repeat_delay: float | None = None,
            CACHE_DIR='/home/chansingh/.CACHE_LLM')
def get_llm(
    checkpoint,
    seed=1,
    role: str = None,
    repeat_delay: Optional[float] = None,
    CACHE_DIR=LLM_CONFIG["CACHE_DIR"],
):
    """Get an LLM with a call function and caching capabilities"""
    if repeat_delay is not None:
        LLM_CONFIG["LLM_REPEAT_DELAY"] = repeat_delay

    if any(checkpoint.startswith(prefix) for prefix in ["gpt-3", "gpt-4", "o3", "o4", "gpt-5"]):
        return LLM_Chat(checkpoint, seed, role, CACHE_DIR)
    elif 'meta-llama' in checkpoint and 'Instruct' in checkpoint:
        if os.environ.get('HF_TOKEN') is None:
            raise ValueError(
                "You must set the HF_TOKEN environment variable to use this model.")
        return LLM_HF_Pipeline(checkpoint, CACHE_DIR)
    else:
        # warning: this sets torch.manual_seed(seed)
        return LLM_HF(checkpoint, seed=seed, CACHE_DIR=CACHE_DIR)
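
A minimal usage sketch (the checkpoint names and the cache directory are assumptions about your environment; the chat path additionally requires an Azure OpenAI deployment and credentials, see LLM_Chat below):

import os
import imodelsx.llm

cache = os.path.expanduser("~/.cache_llm")  # assumed cache location

# chat-completions model, routed to LLM_Chat (needs Azure OpenAI deployment + auth)
llm = imodelsx.llm.get_llm("gpt-4o", CACHE_DIR=cache)
print(llm("Who wrote Hamlet?"))

# local HuggingFace model, routed to LLM_HF (no API key; float16 weights assume a GPU)
llm_hf = imodelsx.llm.get_llm("gpt2", CACHE_DIR=cache)
print(llm_hf("The capital of France is", max_new_tokens=5))
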
def load_hf_model(checkpoint: str) -> transformers.modeling_utils.PreTrainedModel
def load_hf_model(checkpoint: str) -> transformers.PreTrainedModel:
    # shared kwargs (currently only applied in the gpt-j branch below)
    kwargs = {
        "pretrained_model_name_or_path": checkpoint,
        "output_hidden_states": False,
        # "pad_token_id": tokenizer.eos_token_id,
        "low_cpu_mem_usage": True,
    }
    if "google/flan" in checkpoint:
        return T5ForConditionalGeneration.from_pretrained(
            checkpoint, device_map="auto", torch_dtype=torch.float16
        )
    elif checkpoint == "EleutherAI/gpt-j-6B":
        # checkpoint is passed through kwargs (pretrained_model_name_or_path)
        return AutoModelForCausalLM.from_pretrained(
            revision="float16",
            torch_dtype=torch.float16,
            **kwargs,
        )
    elif "llama-2" in checkpoint.lower():
        return AutoModelForCausalLM.from_pretrained(
            checkpoint,
            torch_dtype=torch.float16,
            device_map="auto",
            token=os.environ.get('HF_TOKEN'),
            offload_folder="offload",
        )
    elif 'microsoft/phi' in checkpoint:
        return AutoModelForCausalLM.from_pretrained(
            checkpoint
        )
    elif checkpoint == "gpt-xl":
        return AutoModelForCausalLM.from_pretrained(checkpoint)
    else:
        return AutoModelForCausalLM.from_pretrained(
            checkpoint,
            device_map="auto",
            torch_dtype=torch.float16,
            token=os.environ.get('HF_TOKEN'),
        )
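
A quick sketch of calling load_hf_model directly; "gpt2" is just an example checkpoint that falls through to the final AutoModelForCausalLM branch (weights are cast to float16 with device_map="auto", so a GPU is assumed):

from imodelsx.llm import load_hf_model

model = load_hf_model("gpt2")
print(model.config.model_type)  # "gpt2"
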
def load_tokenizer(checkpoint: str) -> transformers.tokenization_utils.PreTrainedTokenizer
def load_tokenizer(checkpoint: str) -> transformers.PreTrainedTokenizer:
    if "facebook/opt" in checkpoint:
        # opt can't use the fast tokenizer
        tokenizer = AutoTokenizer.from_pretrained(
            checkpoint, use_fast=False, padding_side='left', token=os.environ.get('HF_TOKEN'))
    elif "PMC_LLAMA" in checkpoint:
        tokenizer = transformers.LlamaTokenizer.from_pretrained(
            "chaoyi-wu/PMC_LLAMA_7B", padding_side='left', token=os.environ.get('HF_TOKEN'))
    else:
        tokenizer = AutoTokenizer.from_pretrained(
            checkpoint, padding_side='left', use_fast=True, token=os.environ.get('HF_TOKEN'))

    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id
    return tokenizer
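
load_tokenizer pairs with load_hf_model above; a minimal greedy-decoding sketch under the same "gpt2"/GPU assumptions (note the left padding and eos-as-pad set up by load_tokenizer):

from imodelsx.llm import load_hf_model, load_tokenizer

tokenizer = load_tokenizer("gpt2")
model = load_hf_model("gpt2")
inputs = tokenizer(["The capital of France is"],
                   return_tensors="pt", padding=True).to(model.device)
outputs = model.generate(**inputs, max_new_tokens=5,
                         pad_token_id=tokenizer.pad_token_id)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
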
def repeatedly_call_with_delay(llm_call)
def repeatedly_call_with_delay(llm_call):
    def wrapper(*args, **kwargs):
        # Number of seconds to wait between calls (None disables retrying and re-raises)
        delay = LLM_CONFIG["LLM_REPEAT_DELAY"]
        response = None
        # note: the loop retries until a non-None response is returned
        while response is None:
            try:
                response = llm_call(*args, **kwargs)
            except Exception as e:
                msg = str(e)
                print(msg)
                # errors that retrying will not fix
                if "does not exist" in msg:
                    return None
                elif "maximum context length" in msg:
                    return None
                elif 'content management policy' in msg:
                    return None
                if delay is None:
                    raise
                else:
                    time.sleep(delay)
        return response

    return wrapper
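
The decorator is applied to the OpenAI-backed calls below (e.g. LLM_Chat.__call__); a self-contained sketch of the retry behavior with a hypothetical flaky function:

import random
from imodelsx.llm import LLM_CONFIG, repeatedly_call_with_delay

LLM_CONFIG["LLM_REPEAT_DELAY"] = 1  # retry every second on transient errors

@repeatedly_call_with_delay
def flaky_call(prompt):
    # hypothetical stand-in for an API call that sometimes raises
    if random.random() < 0.5:
        raise RuntimeError("rate limit")
    return prompt.upper()

print(flaky_call("hello"))  # eventually prints "HELLO"
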

Classes

class LLMEmbs (checkpoint)
class LLMEmbs:
    def __init__(self, checkpoint):
        self.tokenizer_ = load_tokenizer(checkpoint)
        self.model_ = AutoModel.from_pretrained(
            checkpoint, output_hidden_states=True,
            device_map="auto",
            torch_dtype=torch.float16,)

    def __call__(self, texts: List[str], layer_idx: int = 18, batch_size=16):
        '''Returns embeddings of the last token of each text from the given hidden layer
        '''
        embs = []
        for i in tqdm(range(0, len(texts), batch_size)):
            inputs = self.tokenizer_(
                texts[i:i + batch_size], return_tensors='pt', padding=True).to(self.model_.device)
            with torch.no_grad():
                hidden_states = self.model_(**inputs).hidden_states

            # hidden_states: layers x batch x tokens x features
            emb = hidden_states[layer_idx].detach().cpu().numpy()

            # get emb from the last token (real, not padding, since padding_side='left')
            emb = emb[:, -1, :]
            embs.append(deepcopy(emb))
        embs = np.concatenate(embs)
        return embs
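
A sketch of extracting last-token embeddings (the "gpt2" checkpoint and layer_idx are assumptions; layer_idx must be a valid hidden-state index for the chosen model, and the float16 weights assume a GPU):

from imodelsx.llm import LLMEmbs

embedder = LLMEmbs("gpt2")
embs = embedder(["first text", "second text"], layer_idx=6, batch_size=2)
print(embs.shape)  # (2, hidden_size)
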
class LLM_Chat (checkpoint, seed=1, role=None, CACHE_DIR='/home/chansingh/.CACHE_LLM')
class LLM_Chat:
    """Chat models take a different format: https://platform.openai.com/docs/guides/chat/introduction"""

    def __init__(self, checkpoint, seed=1, role=None, CACHE_DIR=LLM_CONFIG["CACHE_DIR"]):
        self.cache_dir = join(
            CACHE_DIR, "cache_openai", f'{checkpoint.replace("/", "_")}___{seed}'
        )
        self.checkpoint = checkpoint
        self.role = role
        from openai import AzureOpenAI
        from azure.identity import ChainedTokenCredential, AzureCliCredential, ManagedIdentityCredential, get_bearer_token_provider

        try:
            client_id = os.environ.get("AZURE_CLIENT_ID")
            scope = "https://cognitiveservices.azure.com/.default"
            credential = get_bearer_token_provider(ChainedTokenCredential(
                AzureCliCredential(), # first check local
                ManagedIdentityCredential(client_id=client_id)
            ), scope)
            if 'audio' in checkpoint:
                self.client = AzureOpenAI(
                    api_version="2025-01-01-preview",
                    azure_endpoint="https://neuroaiservice.cognitiveservices.azure.com/openai/deployments/gpt-4o-audio-preview/chat/completions?api-version=2025-01-01-preview",
                    azure_ad_token_provider=credential,
                    timeout=10,
                    max_retries=3,
                )
            elif 'gpt-5' in checkpoint:
                self.client = AzureOpenAI(
                    api_version="2025-01-01-preview",
                    azure_endpoint="https://dl-openai-3.openai.azure.com/",
                    azure_ad_token_provider=credential
                )
            else:
                self.client = AzureOpenAI(
                    api_version="2025-01-01-preview",
                    azure_endpoint="https://dl-openai-1.openai.azure.com/",
                    azure_ad_token_provider=credential
                )
        except Exception as e:
            print('failed to create client', e)
            print('You may need to edit this call in order to supply your own OpenAI / AzureOpenAI key and authentication.')
            traceback.print_exc()

    @repeatedly_call_with_delay
    def __call__(
        self,
        prompts_list: List[Dict[str, str]],
        max_completion_tokens=250,
        stop=None,
        functions: List[Dict] = None,
        return_str=True,
        verbose=True,
        temperature=0,
        frequency_penalty=0.25,
        use_cache=True,
        return_false_if_not_cached=False,
        reasoning_effort='high',
        seed=1,
    ):
        """
        prompts_list: list of dicts, each dict has keys 'role' and 'content'
            Example: [
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": "Who won the world series in 2020?"},
                {"role": "assistant",
                    "content": "The Los Angeles Dodgers won the World Series in 2020."},
                {"role": "user", "content": "Where was it played?"}
            ]
        prompts_list: str
            Alternatively, string which gets formatted into basic prompts_list:
            messages = [
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": <<<<<prompts_list>>>>},
            ]
        """
        if isinstance(prompts_list, str):
            role = self.role
            if role is None:
                role = "You are a helpful assistant."
            prompts_list = [
                {"role": "system", "content": role},
                {"role": "user", "content": prompts_list},
            ]

        assert isinstance(prompts_list, list), prompts_list
        # breakpoint()

        # cache
        os.makedirs(self.cache_dir, exist_ok=True)
        prompts_list_dict = {
            str(i): sorted(v.items()) for i, v in enumerate(prompts_list)
        }
        prompts_list_dict["checkpoint"] = self.checkpoint
        prompts_list_dict["temperature"] = temperature
        prompts_list_dict["functions"] = functions
        prompts_list_dict["max_completion_tokens"] = max_completion_tokens
        if seed != 1:
            prompts_list_dict["seed"] = seed
        if reasoning_effort != 'high':
            prompts_list_dict['reasoning_effort'] = reasoning_effort

        dict_as_str = json.dumps(prompts_list_dict, sort_keys=True)
        hash_str = hashlib.sha256(dict_as_str.encode()).hexdigest()
        cache_file = join(
            self.cache_dir,
            f"chat__{hash_str}.pkl",
        )
        if os.path.exists(cache_file) and use_cache:
            if verbose:
                print("cached!")
                # print(cache_file)
            # print(cache_file)
            response = pkl.load(open(cache_file, "rb"))
            if response is not None:
                return response
        if verbose:
            print("not cached")

        if return_false_if_not_cached:
            return False

        kwargs = dict(
            model=self.checkpoint,
            messages=prompts_list,
            max_completion_tokens=max_completion_tokens,
            temperature=temperature,
            top_p=1,
            frequency_penalty=frequency_penalty,  # maximum is 2
            presence_penalty=0,
            stop=stop,
            reasoning_effort=reasoning_effort,
            # logprobs=True,
            # stop=["101"]
        )
        # print('kwargs', kwargs)
        if functions is not None:
            kwargs["functions"] = functions
        if self.checkpoint.startswith('o') or self.checkpoint == 'gpt-5':
            # reasoning models (o-series, gpt-5) don't support these sampling params
            del kwargs['temperature']
            del kwargs['frequency_penalty']
            del kwargs['top_p']

        if 'gpt-5' not in self.checkpoint:
            del kwargs['reasoning_effort']

        response = self.client.chat.completions.create(
            **kwargs,
        )

        if return_str:
            response = response.choices[0].message.content

        if response is not None:
            # print('resp', response, 'cache_file', cache_file)
            try:
                # print(cache_file, 'cached!')
                pkl.dump(response, open(cache_file, "wb"))
            except Exception:
                print('failed to save cache!', cache_file)
                traceback.print_exc()

        return response
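
A usage sketch (the "gpt-4o" deployment name, the cache directory, and the Azure AD authentication are assumptions; the hard-coded endpoints in __init__ may need to be edited for your subscription):

import os
from imodelsx.llm import get_llm

llm = get_llm("gpt-4o", role="You are a terse assistant.",
              CACHE_DIR=os.path.expanduser("~/.cache_llm"))

# plain-string prompt (wrapped into a system + user message pair)
print(llm("Who won the world series in 2020?", max_completion_tokens=50))

# explicit chat format
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Where was it played?"},
]
print(llm(messages, temperature=0))
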

Subclasses

LLM_Chat_Audio

class LLM_Chat_Audio (checkpoint, seed=1, role=None, CACHE_DIR='/home/chansingh/.CACHE_LLM')
class LLM_Chat_Audio(LLM_Chat):

    def __call__(
        self,
        prompt_str: str,
        audio_str: str,
        return_str=True,
        verbose=True,
        use_cache=True,
        return_false_if_not_cached=False,
    ):

        # breakpoint()
        prompts_list = [{
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": prompt_str,
                    # "text": 'What food is mentioned in the recording?',
                },
                {
                    "type": "input_audio",
                    "input_audio": {
                        "data": audio_str,
                        "format": "wav"
                    }
                }
            ]
        }]

        # cache
        os.makedirs(self.cache_dir, exist_ok=True)
        prompts_list_dict = {
            str(i): sorted(v.items()) for i, v in enumerate(prompts_list)
        }
        dict_as_str = json.dumps(prompts_list_dict, sort_keys=True)
        hash_str = hashlib.sha256(dict_as_str.encode()).hexdigest()
        cache_file = join(
            self.cache_dir,
            f"audio__{hash_str}.pkl",
        )
        if os.path.exists(cache_file) and use_cache:
            if verbose:
                print("cached!")
                # print(cache_file)
            # print(cache_file)
            response = pkl.load(open(cache_file, "rb"))
            if response is not None:
                return response
        if verbose:
            print("not cached")

        if return_false_if_not_cached:
            return False

        kwargs = dict(
            model=self.checkpoint,
            messages=prompts_list,
            # temperature=0,
        )
        response = self.client.chat.completions.create(
            modalities=["text", "audio"],
            audio={"voice": "alloy", "format": "wav"},
            **kwargs,
        )

        if return_str:
            response = response.choices[0].message.audio.transcript

        if response is not None:
            # print('resp', response, 'cache_file', cache_file)
            try:
                # print(cache_file, 'cached!')
                pkl.dump(response, open(cache_file, "wb"))
            except Exception:
                print('failed to save cache!', cache_file)
                traceback.print_exc()

        return response
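
A sketch of querying the audio model (the "gpt-4o-audio-preview" deployment, the hard-coded Azure endpoint in LLM_Chat.__init__, and the local wav file are all assumptions; audio_str must be base64-encoded wav bytes):

import base64
import os
from imodelsx.llm import LLM_Chat_Audio

llm = LLM_Chat_Audio("gpt-4o-audio-preview",
                     CACHE_DIR=os.path.expanduser("~/.cache_llm"))
with open("recording.wav", "rb") as f:  # hypothetical local file
    audio_str = base64.b64encode(f.read()).decode("utf-8")
print(llm("What food is mentioned in the recording?", audio_str=audio_str))
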

Ancestors

LLM_Chat

class LLM_HF (checkpoint, seed, CACHE_DIR)
class LLM_HF:
    def __init__(self, checkpoint, seed, CACHE_DIR):
        self.tokenizer_ = load_tokenizer(checkpoint)
        self.model_ = load_hf_model(checkpoint)
        self.checkpoint = checkpoint
        if CACHE_DIR is not None:
            self.cache_dir = join(
                CACHE_DIR, "cache_hf", f'{checkpoint.replace("/", "_")}___{seed}'
            )
        else:
            self.cache_dir = None
        self.seed = seed

    def __call__(
        self,
        prompt: Union[str, List[str]],
        stop: str = None,
        max_new_tokens=20,
        do_sample=False,
        use_cache=True,
        verbose=False,
        return_next_token_prob_scores=False,
        target_token_strs: List[str] = None,
        return_top_target_token_str: bool = False,
        batch_size=1,
    ) -> Union[str, List[str]]:
        """Warning: stop is used posthoc but not during generation.
        Be careful, caching can take up a lot of memory....

        Example mistral-instruct prompt: "<s>[INST]'Input text: {example}\nQuestion: {question} Answer yes or no.[/INST]"


        Params
        ------
        return_next_token_prob_scores: bool
            If True, return the next-token probability distribution instead of generated text.
        target_token_strs: List[str]
            If not None (and return_next_token_prob_scores is True), return the probability of
            the next token being each of the target_token_strs.
            The output is then a list of dictionaries, List[Dict[str, float]].
        return_top_target_token_str: bool
            If True (and the two flags above are set), return only the top token among
            target_token_strs. This is a way to constrain the output (but only to 1 token).
            This setting caches, but the other two (which do not return strings) do not cache.
        """
        input_is_str = isinstance(prompt, str)
        with torch.no_grad():
            use_cache = use_cache and self.cache_dir is not None
            # cache
            if use_cache:
                os.makedirs(self.cache_dir, exist_ok=True)
                hash_str = hashlib.sha256(str(prompt).encode()).hexdigest()
                cache_file = join(
                    self.cache_dir, f"{hash_str}__num_tok={max_new_tokens}.pkl"
                )

                if os.path.exists(cache_file):
                    if verbose:
                        print("cached!")
                    try:
                        return pkl.load(open(cache_file, "rb"))
                    except Exception:
                        print('failed to load cache so rerunning...')
                if verbose:
                    print("not cached...")

            # if stop is not None:
            # raise ValueError("stop kwargs are not permitted.")
            inputs = self.tokenizer_(
                prompt, return_tensors="pt",
                return_attention_mask=True,
                padding=True,
                truncation=False,
            ).to(self.model_.device)

            if return_next_token_prob_scores or target_token_strs or return_top_target_token_str:
                outputs = self.model_.generate(
                    **inputs,
                    max_new_tokens=1,
                    pad_token_id=self.tokenizer_.pad_token_id,
                    output_logits=True,
                    return_dict_in_generate=True,
                )
                next_token_logits = outputs['logits'][0]
                next_token_probs = next_token_logits.softmax(
                    axis=-1).detach().cpu().numpy()

                if target_token_strs is None:
                    return next_token_probs

                target_token_ids = self._check_target_token_strs(
                    target_token_strs)
                if return_top_target_token_str:
                    selected_tokens = next_token_probs[:, np.array(
                        target_token_ids)].squeeze().argmax(axis=-1)
                    out_strs = [
                        target_token_strs[selected_tokens[i]]
                        for i in range(len(selected_tokens))
                    ]
                    if len(out_strs) == 1:
                        out_strs = out_strs[0]
                    if use_cache:
                        pkl.dump(out_strs, open(cache_file, "wb"))
                    return out_strs
                else:
                    out_dict_list = [
                        {target_token_strs[i]: next_token_probs[prompt_num, target_token_ids[i]]
                            for i in range(len(target_token_strs))
                         }
                        for prompt_num in range(len(prompt))
                    ]
                    return out_dict_list
            else:
                outputs = self.model_.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    do_sample=do_sample,
                    pad_token_id=self.tokenizer_.pad_token_id,
                )
                # top_p=0.92,
                # temperature=0,
                # top_k=0
            if input_is_str:
                out_str = self.tokenizer_.decode(
                    outputs[0], skip_special_tokens=True)
                # print('out_str', out_str)
                if 'mistral' in self.checkpoint and 'Instruct' in self.checkpoint:
                    out_str = out_str[len(prompt) - 2:]
                elif 'Meta-Llama-3' in self.checkpoint and 'Instruct' in self.checkpoint:
                    out_str = out_str[len(prompt) - 145:]
                else:
                    out_str = out_str[len(prompt):]

                if use_cache:
                    pkl.dump(out_str, open(cache_file, "wb"))
                return out_str
            else:
                out_strs = []
                for i in range(outputs.shape[0]):
                    out_tokens = outputs[i]
                    out_str = self.tokenizer_.decode(
                        out_tokens, skip_special_tokens=True)
                    if 'Ministral' in self.checkpoint and 'Instruct' in self.checkpoint:
                        out_str = out_str[len(prompt[i]) - 16:]
                    elif 'Qwen' in self.checkpoint:
                        out_str = out_str[len(prompt[i]) - 34:]
                    elif 'mistral' in self.checkpoint and 'Instruct' in self.checkpoint:
                        out_str = out_str[len(prompt[i]) - 2:]
                    elif 'Meta-Llama-3' in self.checkpoint and 'Instruct' in self.checkpoint:
                        out_str = out_str[len(prompt[i]) + 187:]
                    else:
                        out_str = out_str[len(prompt[i]):]
                    out_strs.append(out_str)
                if use_cache:
                    pkl.dump(out_strs, open(cache_file, "wb"))
                return out_strs

    def _check_target_token_strs(self, target_token_strs, override_token_with_first_token_id=False):
        if isinstance(target_token_strs, str):
            target_token_strs = [target_token_strs]

        target_token_ids = [self.tokenizer_(target_token_str, add_special_tokens=False)["input_ids"]
                            for target_token_str in target_token_strs]

        # Check that the target token is in the vocab
        if override_token_with_first_token_id:
            # Get first token id in target_token_str
            target_token_ids = [target_token_id[0]
                                for target_token_id in target_token_ids]
        else:
            for i in range(len(target_token_strs)):
                if len(target_token_ids[i]) > 1:
                    raise ValueError(
                        f"target_token_str {target_token_strs[i]} has multiple tokens: " +
                        str([self.tokenizer_.decode(target_token_id)
                            for target_token_id in target_token_ids[i]]))
        return target_token_ids
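
A local-model sketch ("gpt2" is an example checkpoint; float16 weights assume a GPU, and the constrained call assumes " yes" / " no" each map to a single token for the chosen tokenizer):

import os
from imodelsx.llm import get_llm

llm = get_llm("gpt2", CACHE_DIR=os.path.expanduser("~/.cache_llm"))
print(llm("The capital of France is", max_new_tokens=5))

# constrain the next token to one of two strings
print(llm(
    ["Is the sky blue? Answer yes or no. Answer:",
     "Is fire cold? Answer yes or no. Answer:"],
    target_token_strs=[" yes", " no"],
    return_next_token_prob_scores=True,
    return_top_target_token_str=True,
))
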
class LLM_HF_Pipeline (checkpoint, CACHE_DIR)
class LLM_HF_Pipeline:
    def __init__(self, checkpoint, CACHE_DIR):

        self.pipeline_ = transformers.pipeline(
            "text-generation",
            model=checkpoint,
            model_kwargs={"torch_dtype": torch.bfloat16},
            # 'device_map': "auto"},
            # model_kwargs={'torch_dtype': torch.float16},
            device_map="auto",
        )
        self.pipeline_.tokenizer.pad_token_id = self.pipeline_.tokenizer.eos_token_id
        self.pipeline_.tokenizer.padding_side = 'left'
        # self.pipeline_.model.generation_config.pad_token_id = self.pipeline_.tokenizer.pad_token_id
        self.cache_dir = CACHE_DIR

    def __call__(
        self,
        prompt: Union[str, List[str]],
        max_new_tokens=20,
        use_cache=True,
        verbose=False,
        batch_size=64,
    ):
        use_cache = use_cache and self.cache_dir is not None
        if use_cache:
            os.makedirs(self.cache_dir, exist_ok=True)
            hash_str = hashlib.sha256(str(prompt).encode()).hexdigest()
            cache_file = join(
                self.cache_dir, f"{hash_str}__num_tok={max_new_tokens}.pkl"
            )

            if os.path.exists(cache_file):
                if verbose:
                    print("cached!")
                try:
                    return pkl.load(open(cache_file, "rb"))
                except Exception:
                    print('failed to load cache so rerunning...')
            if verbose:
                print("not cached...")
        outputs = self.pipeline_(
            prompt,
            max_new_tokens=max_new_tokens,
            batch_size=batch_size,
            do_sample=False,
            pad_token_id=self.pipeline_.tokenizer.pad_token_id,
            top_p=None,
            temperature=None,
        )
        # print(outputs)
        if isinstance(prompt, str):
            texts = outputs[0]["generated_text"][len(prompt):]
        else:
            texts = [outputs[i][0]['generated_text'][len(prompt[i]):]
                     for i in range(len(outputs))]

        if use_cache:
            pkl.dump(texts, open(cache_file, "wb"))
        return texts
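
A sketch of the pipeline path (the gated Llama checkpoint is an example and requires HF_TOKEN plus model access; get_llm routes 'meta-llama/...Instruct' checkpoints here automatically):

import os
from imodelsx.llm import get_llm

assert os.environ.get("HF_TOKEN"), "set HF_TOKEN to use gated meta-llama models"
llm = get_llm("meta-llama/Meta-Llama-3-8B-Instruct",
              CACHE_DIR=os.path.expanduser("~/.cache_llm"))
print(llm(["What is 2 + 2?"], max_new_tokens=10, batch_size=1))
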
class LlamaTokenizer (*args, **kwargs)

Methods

def call(self, *args, **kwargs)
def call(self, *args, **kwargs):
    pass