Module imodelsx.qaemb.vllm_example
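Times batched LLM inference two ways: through imodelsx.llm's Hugging Face wrapper and through Ray Data's vLLM engine processor, printing setup and inference wall-clock times for each.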

Functions

def time_batch_inference_hf(prompts)
def time_batch_inference_hf(prompts):
    import time

    import imodelsx.llm

    t0 = time.time()
    llm = imodelsx.llm.get_llm(
        checkpoint='meta-llama/Meta-Llama-3-8B-Instruct',
    )
    t1 = time.time()
    print(f'LLM created in {t1 - t0:.2f} seconds')
    # disable imodelsx's response cache so the timing reflects real inference
    outputs = llm(prompts, use_cache=False)
    t2 = time.time()
    print(f'Inference completed in {t2 - t1:.2f} seconds')
    return outputs
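A minimal sketch of calling this function, assuming the imodelsx.llm wrapper accepts a list of prompt strings as above; the prompts are illustrative placeholders:

# hypothetical prompts for timing; repeated prompts share a common prefix
prompts = ['Answer yes or no: is the sky blue?'] * 64
time_batch_inference_hf(prompts)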
def time_batch_inference_vllm(prompts)
def time_batch_inference_vllm(prompts):
    import time

    import ray
    from ray.data.llm import build_llm_processor, vLLMEngineProcessorConfig

    ray.init(log_to_driver=False)
    ray.data.DataContext.get_current().enable_progress_bars = False
    # wrap each prompt as a row so Ray Data can stream batches to the engine
    ds = ray.data.from_items([{"text": prompt} for prompt in prompts])

    t0 = time.time()
    config = vLLMEngineProcessorConfig(
        model_source='meta-llama/Meta-Llama-3-8B-Instruct',
        engine_kwargs={
            "enable_chunked_prefill": True,
            "max_num_batched_tokens": 4096,
            "max_model_len": 8192,
            # tensor parallelism is an engine kwarg, not a top-level
            # config field; split the model across 4 GPUs
            "tensor_parallel_size": 4,
            # set "enable_prefix_caching": True to enable APC
        },
        concurrency=1,
        batch_size=64,
    )
    vllm_processor = build_llm_processor(
        config,
        preprocess=lambda row: dict(
            messages=[
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": row["text"]},
            ],
            sampling_params=dict(
                temperature=0,
                max_tokens=1,
            ),
        ),
        postprocess=lambda row: dict(
            answer=row["generated_text"],
            **row,  # This will return all the original columns in the dataset.
        ),
    )
    t1 = time.time()
    print(f'vLLM processor created in {t1 - t0:.2f} seconds')
    
    ds = vllm_processor(ds)
    # take() triggers the lazy pipeline; only enough batches to yield 10 rows
    # are guaranteed to run, so use take_all() to time the full prompt list
    outputs = ds.take(limit=10)
    t2 = time.time()
    print(f'Inference completed in {t2 - t1:.2f} seconds')
    ray.shutdown()
    return outputs
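A sketch of a driver for this benchmark; the prompt list is an illustrative placeholder, and the engine config above assumes 4 GPUs are available for tensor parallelism:

if __name__ == '__main__':
    # hypothetical prompts; any list of strings works
    prompts = [f'In one word, is {i} even or odd?' for i in range(256)]
    time_batch_inference_vllm(prompts)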