def time_batch_inference_vllm(prompts):
    import time

    import ray
    from ray.data.llm import build_llm_processor, vLLMEngineProcessorConfig

    ray.init(log_to_driver=False)
    ray.data.DataContext.get_current().enable_progress_bars = False

    # Build a Ray Dataset with one row per prompt.
    ds = ray.data.from_items(
        [{"text": prompt} for prompt in prompts],
        # parallelism=1,
        # ray_remote_args={"num_cpus": 1, "num_gpus": 0},
    )

    t0 = time.time()
    config = vLLMEngineProcessorConfig(
        model_source="meta-llama/Meta-Llama-3-8B-Instruct",
        engine_kwargs={
            "enable_chunked_prefill": True,
            "max_num_batched_tokens": 4096,
            "max_model_len": 8192,
            # Tensor parallelism is a vLLM engine argument, so it is passed
            # through engine_kwargs rather than as a top-level config field.
            "tensor_parallel_size": 4,
        },
        concurrency=1,
        batch_size=64,
    )
    vllm_processor = build_llm_processor(
        config,
        preprocess=lambda row: dict(
            messages=[
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": row["text"]},
            ],
            sampling_params=dict(
                temperature=0,
                max_tokens=1,
            ),
        ),
        postprocess=lambda row: dict(
            answer=row["generated_text"],
            **row,  # Return all the original columns in the dataset.
        ),
    )
    t1 = time.time()
    print(f"vLLM processor created in {t1 - t0:.2f} seconds")
    # Applying the processor is lazy; take() triggers execution and
    # fetches up to 10 output rows.
    ds = vllm_processor(ds)
    outputs = ds.take(limit=10)
    t2 = time.time()
    print(f"Inference completed in {t2 - t1:.2f} seconds")

    ray.shutdown()
    return outputs
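
As a usage sketch (the prompt strings below are illustrative and not from the original), the function can be called with a plain list of prompts:

if __name__ == "__main__":
    # Hypothetical example prompts; replace with the real workload.
    sample_prompts = [
        "What is the capital of France?",
        "Summarize the plot of Hamlet in one sentence.",
        "Explain tensor parallelism in one paragraph.",
    ]
    time_batch_inference_vllm(sample_prompts)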