Sync from v0.13
111
examples/offline_inference/async_llm_streaming.py
Normal file
@@ -0,0 +1,111 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Simple example demonstrating streaming offline inference with AsyncLLM (V1 engine).
|
||||
|
||||
This script shows the core functionality of vLLM's AsyncLLM engine for streaming
|
||||
token-by-token output in offline inference scenarios. It demonstrates DELTA mode
|
||||
streaming where you receive new tokens as they are generated.
|
||||
|
||||
Usage:
|
||||
python examples/offline_inference/async_llm_streaming.py
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
|
||||
from vllm import SamplingParams
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
from vllm.sampling_params import RequestOutputKind
|
||||
from vllm.v1.engine.async_llm import AsyncLLM
|
||||
|
||||
|
||||
async def stream_response(engine: AsyncLLM, prompt: str, request_id: str) -> None:
|
||||
"""
|
||||
Stream response from AsyncLLM and display tokens as they arrive.
|
||||
|
||||
This function demonstrates the core streaming pattern:
|
||||
1. Create SamplingParams with DELTA output kind
|
||||
2. Call engine.generate() and iterate over the async generator
|
||||
3. Print new tokens as they arrive
|
||||
4. Handle the finished flag to know when generation is complete
|
||||
"""
|
||||
print(f"\n🚀 Prompt: {prompt!r}")
|
||||
print("💬 Response: ", end="", flush=True)
|
||||
|
||||
# Configure sampling parameters for streaming
|
||||
sampling_params = SamplingParams(
|
||||
max_tokens=100,
|
||||
temperature=0.8,
|
||||
top_p=0.95,
|
||||
seed=42, # For reproducible results
|
||||
output_kind=RequestOutputKind.DELTA, # Get only new tokens each iteration
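# Note: RequestOutputKind.CUMULATIVE would instead return the full text generated
# so far on every iteration; DELTA is the usual choice for token-by-token streaming.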
|
||||
)
|
||||
|
||||
try:
|
||||
# Stream tokens from AsyncLLM
|
||||
async for output in engine.generate(
|
||||
request_id=request_id, prompt=prompt, sampling_params=sampling_params
|
||||
):
|
||||
# Process each completion in the output
|
||||
for completion in output.outputs:
|
||||
# In DELTA mode, we get only new tokens generated since last iteration
|
||||
new_text = completion.text
|
||||
if new_text:
|
||||
print(new_text, end="", flush=True)
|
||||
|
||||
# Check if generation is finished
|
||||
if output.finished:
|
||||
print("\n✅ Generation complete!")
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
print(f"\n❌ Error during streaming: {e}")
|
||||
raise
|
||||
|
||||
|
||||
async def main():
|
||||
print("🔧 Initializing AsyncLLM...")
|
||||
|
||||
# Create AsyncLLM engine with simple configuration
|
||||
engine_args = AsyncEngineArgs(
|
||||
model="meta-llama/Llama-3.2-1B-Instruct",
|
||||
enforce_eager=True, # Faster startup for examples
|
||||
)
|
||||
engine = AsyncLLM.from_engine_args(engine_args)
|
||||
|
||||
try:
|
||||
# Example prompts to demonstrate streaming
|
||||
prompts = [
|
||||
"The future of artificial intelligence is",
|
||||
"In a galaxy far, far away",
|
||||
"The key to happiness is",
|
||||
]
|
||||
|
||||
print(f"🎯 Running {len(prompts)} streaming examples...")
|
||||
|
||||
# Process each prompt
|
||||
for i, prompt in enumerate(prompts, 1):
|
||||
print(f"\n{'=' * 60}")
|
||||
print(f"Example {i}/{len(prompts)}")
|
||||
print(f"{'=' * 60}")
|
||||
|
||||
request_id = f"stream-example-{i}"
|
||||
await stream_response(engine, prompt, request_id)
|
||||
|
||||
# Brief pause between examples
|
||||
if i < len(prompts):
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
print("\n🎉 All streaming examples completed!")
|
||||
|
||||
finally:
|
||||
# Always clean up the engine
|
||||
print("🔧 Shutting down engine...")
|
||||
engine.shutdown()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
asyncio.run(main())
|
||||
except KeyboardInterrupt:
|
||||
print("\n🛑 Interrupted by user")
|
||||
540
examples/offline_inference/audio_language.py
Executable file
@@ -0,0 +1,540 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
This example shows how to use vLLM for running offline inference
|
||||
with the correct prompt format on audio language models.
|
||||
|
||||
For most models, the prompt format should follow corresponding examples
|
||||
on HuggingFace model repository.
|
||||
"""
|
||||
|
||||
import os
|
||||
from dataclasses import asdict
|
||||
from typing import Any, NamedTuple
|
||||
|
||||
from huggingface_hub import snapshot_download
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from vllm import LLM, EngineArgs, SamplingParams
|
||||
from vllm.assets.audio import AudioAsset
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||
|
||||
audio_assets = [AudioAsset("mary_had_lamb"), AudioAsset("winning_call")]
|
||||
question_per_audio_count = {
|
||||
0: "What is 1+1?",
|
||||
1: "What is recited in the audio?",
|
||||
2: "What sport and what nursery rhyme are referenced?",
|
||||
}
|
||||
|
||||
|
||||
class ModelRequestData(NamedTuple):
|
||||
engine_args: EngineArgs
|
||||
prompt: str | None = None
|
||||
prompt_token_ids: dict[str, list[int]] | None = None
|
||||
multi_modal_data: dict[str, Any] | None = None
|
||||
stop_token_ids: list[int] | None = None
|
||||
lora_requests: list[LoRARequest] | None = None
|
||||
|
||||
|
||||
# NOTE: The default `max_num_seqs` and `max_model_len` may result in OOM on
|
||||
# lower-end GPUs.
|
||||
# Unless specified, these settings have been tested to work on a single L4.
|
||||
|
||||
|
||||
# AudioFlamingo3
|
||||
def run_audioflamingo3(question: str, audio_count: int) -> ModelRequestData:
|
||||
model_name = "nvidia/audio-flamingo-3-hf"
|
||||
engine_args = EngineArgs(
|
||||
model=model_name,
|
||||
max_model_len=4096,
|
||||
max_num_seqs=2,
|
||||
limit_mm_per_prompt={"audio": audio_count},
|
||||
enforce_eager=True,
|
||||
)
|
||||
|
||||
# AudioFlamingo3 uses <sound> token for audio
|
||||
audio_placeholder = "<sound>" * audio_count
|
||||
|
||||
prompt = (
|
||||
"<|im_start|>system\n"
|
||||
"You are a helpful assistant.<|im_end|>\n"
|
||||
"<|im_start|>user\n"
|
||||
f"{audio_placeholder}{question}<|im_end|>\n"
|
||||
"<|im_start|>assistant\n"
|
||||
)
|
||||
|
||||
return ModelRequestData(
|
||||
engine_args=engine_args,
|
||||
prompt=prompt,
|
||||
)
|
||||
|
||||
|
||||
# Gemma3N
|
||||
def run_gemma3n(question: str, audio_count: int) -> ModelRequestData:
|
||||
model_name = "google/gemma-3n-E2B-it"
|
||||
engine_args = EngineArgs(
|
||||
model=model_name,
|
||||
max_model_len=2048,
|
||||
max_num_batched_tokens=2048,
|
||||
max_num_seqs=2,
|
||||
limit_mm_per_prompt={"audio": audio_count},
|
||||
enforce_eager=True,
|
||||
)
|
||||
prompt = f"<start_of_turn>user\n<audio_soft_token>{question}"
|
||||
"<end_of_turn>\n<start_of_turn>model\n"
|
||||
return ModelRequestData(
|
||||
engine_args=engine_args,
|
||||
prompt=prompt,
|
||||
)
|
||||
|
||||
|
||||
# Granite Speech
|
||||
def run_granite_speech(question: str, audio_count: int) -> ModelRequestData:
|
||||
# NOTE - the settings in this example are somewhat different from what is
|
||||
# optimal for granite speech, and it is generally recommended to use beam
|
||||
# search. Check the model README for suggested settings.
|
||||
# https://huggingface.co/ibm-granite/granite-speech-3.3-8b
|
||||
model_name = "ibm-granite/granite-speech-3.3-8b"
|
||||
|
||||
engine_args = EngineArgs(
|
||||
model=model_name,
|
||||
trust_remote_code=True,
|
||||
max_model_len=2048,
|
||||
max_num_seqs=2,
|
||||
enable_lora=True,
|
||||
max_lora_rank=64,
|
||||
limit_mm_per_prompt={"audio": audio_count},
|
||||
)
|
||||
|
||||
# The model has an audio-specific lora directly in its model dir;
|
||||
# it should be enabled whenever you pass audio inputs to the model.
|
||||
speech_lora_path = model_name
|
||||
audio_placeholder = "<|audio|>" * audio_count
|
||||
prompts = f"<|start_of_role|>system<|end_of_role|>Knowledge Cutoff Date: April 2024.\nToday's Date: December 19, 2024.\nYou are Granite, developed by IBM. You are a helpful AI assistant<|end_of_text|>\n<|start_of_role|>user<|end_of_role|>{audio_placeholder}{question}<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>" # noqa: E501
|
||||
|
||||
return ModelRequestData(
|
||||
engine_args=engine_args,
|
||||
prompt=prompts,
|
||||
lora_requests=[LoRARequest("speech", 1, speech_lora_path)],
|
||||
)
|
||||
|
||||
|
||||
# MiDashengLM
|
||||
def run_midashenglm(question: str, audio_count: int):
|
||||
model_name = "mispeech/midashenglm-7b"
|
||||
|
||||
engine_args = EngineArgs(
|
||||
model=model_name,
|
||||
trust_remote_code=True,
|
||||
max_model_len=4096,
|
||||
max_num_seqs=5,
|
||||
limit_mm_per_prompt={"audio": audio_count},
|
||||
)
|
||||
|
||||
audio_in_prompt = "".join(
|
||||
["<|audio_bos|><|AUDIO|><|audio_eos|>" for idx in range(audio_count)]
|
||||
)
|
||||
|
||||
default_system = "You are a helpful language and speech assistant."
|
||||
|
||||
prompt = (
|
||||
f"<|im_start|>system\n{default_system}<|im_end|>\n"
|
||||
"<|im_start|>user\n"
|
||||
f"{audio_in_prompt}{question}<|im_end|>\n"
|
||||
"<|im_start|>assistant\n"
|
||||
)
|
||||
return ModelRequestData(
|
||||
engine_args=engine_args,
|
||||
prompt=prompt,
|
||||
)
|
||||
|
||||
|
||||
# MiniCPM-O
|
||||
def run_minicpmo(question: str, audio_count: int) -> ModelRequestData:
|
||||
model_name = "openbmb/MiniCPM-o-2_6"
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
|
||||
engine_args = EngineArgs(
|
||||
model=model_name,
|
||||
trust_remote_code=True,
|
||||
max_model_len=4096,
|
||||
max_num_seqs=2,
|
||||
limit_mm_per_prompt={"audio": audio_count},
|
||||
)
|
||||
|
||||
stop_tokens = ["<|im_end|>", "<|endoftext|>"]
|
||||
stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]
|
||||
|
||||
audio_placeholder = "(<audio>./</audio>)" * audio_count
|
||||
audio_chat_template = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n<|spk_bos|><|spk|><|spk_eos|><|tts_bos|>' }}{% endif %}" # noqa: E501
|
||||
messages = [{"role": "user", "content": f"{audio_placeholder}\n{question}"}]
|
||||
prompt = tokenizer.apply_chat_template(
|
||||
messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
chat_template=audio_chat_template,
|
||||
)
|
||||
|
||||
return ModelRequestData(
|
||||
engine_args=engine_args,
|
||||
prompt=prompt,
|
||||
stop_token_ids=stop_token_ids,
|
||||
)
|
||||
|
||||
|
||||
# Phi-4-multimodal-instruct
|
||||
def run_phi4mm(question: str, audio_count: int) -> ModelRequestData:
|
||||
"""
|
||||
Phi-4-multimodal-instruct supports both image and audio inputs. Here, we
|
||||
show how to process audio inputs.
|
||||
"""
|
||||
model_path = snapshot_download("microsoft/Phi-4-multimodal-instruct")
|
||||
# Since the vision-lora and speech-lora co-exist with the base model,
|
||||
# we have to manually specify the path of the lora weights.
|
||||
speech_lora_path = os.path.join(model_path, "speech-lora")
|
||||
placeholders = "".join([f"<|audio_{i + 1}|>" for i in range(audio_count)])
|
||||
|
||||
prompts = f"<|user|>{placeholders}{question}<|end|><|assistant|>"
|
||||
|
||||
engine_args = EngineArgs(
|
||||
model=model_path,
|
||||
trust_remote_code=True,
|
||||
max_model_len=12800,
|
||||
max_num_seqs=2,
|
||||
enable_lora=True,
|
||||
max_lora_rank=320,
|
||||
limit_mm_per_prompt={"audio": audio_count},
|
||||
)
|
||||
|
||||
return ModelRequestData(
|
||||
engine_args=engine_args,
|
||||
prompt=prompts,
|
||||
lora_requests=[LoRARequest("speech", 1, speech_lora_path)],
|
||||
)
|
||||
|
||||
|
||||
def run_phi4_multimodal(question: str, audio_count: int) -> ModelRequestData:
|
||||
"""
|
||||
Phi-4-multimodal-instruct supports both image and audio inputs. Here, we
|
||||
show how to process audio inputs.
|
||||
"""
|
||||
model_path = snapshot_download(
|
||||
"microsoft/Phi-4-multimodal-instruct", revision="refs/pr/70"
|
||||
)
|
||||
# Since the vision-lora and speech-lora co-exist with the base model,
|
||||
# we have to manually specify the path of the lora weights.
|
||||
speech_lora_path = os.path.join(model_path, "speech-lora")
|
||||
placeholders = "<|audio|>" * audio_count
|
||||
|
||||
prompts = f"<|user|>{placeholders}{question}<|end|><|assistant|>"
|
||||
|
||||
engine_args = EngineArgs(
|
||||
model=model_path,
|
||||
max_model_len=12800,
|
||||
max_num_seqs=2,
|
||||
enable_lora=True,
|
||||
max_lora_rank=320,
|
||||
limit_mm_per_prompt={"audio": audio_count},
|
||||
)
|
||||
|
||||
return ModelRequestData(
|
||||
engine_args=engine_args,
|
||||
prompt=prompts,
|
||||
lora_requests=[LoRARequest("speech", 1, speech_lora_path)],
|
||||
)
|
||||
|
||||
|
||||
# Qwen2-Audio
|
||||
def run_qwen2_audio(question: str, audio_count: int) -> ModelRequestData:
|
||||
model_name = "Qwen/Qwen2-Audio-7B-Instruct"
|
||||
|
||||
engine_args = EngineArgs(
|
||||
model=model_name,
|
||||
max_model_len=4096,
|
||||
max_num_seqs=5,
|
||||
limit_mm_per_prompt={"audio": audio_count},
|
||||
)
|
||||
|
||||
audio_in_prompt = "".join(
|
||||
[
|
||||
f"Audio {idx + 1}: <|audio_bos|><|AUDIO|><|audio_eos|>\n"
|
||||
for idx in range(audio_count)
|
||||
]
|
||||
)
|
||||
|
||||
prompt = (
|
||||
"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
|
||||
"<|im_start|>user\n"
|
||||
f"{audio_in_prompt}{question}<|im_end|>\n"
|
||||
"<|im_start|>assistant\n"
|
||||
)
|
||||
|
||||
return ModelRequestData(
|
||||
engine_args=engine_args,
|
||||
prompt=prompt,
|
||||
)
|
||||
|
||||
|
||||
# Qwen2.5-Omni
|
||||
def run_qwen2_5_omni(question: str, audio_count: int):
|
||||
model_name = "Qwen/Qwen2.5-Omni-7B"
|
||||
|
||||
engine_args = EngineArgs(
|
||||
model=model_name,
|
||||
max_model_len=4096,
|
||||
max_num_seqs=5,
|
||||
limit_mm_per_prompt={"audio": audio_count},
|
||||
)
|
||||
|
||||
audio_in_prompt = "".join(
|
||||
["<|audio_bos|><|AUDIO|><|audio_eos|>\n" for idx in range(audio_count)]
|
||||
)
|
||||
|
||||
default_system = (
|
||||
"You are Qwen, a virtual human developed by the Qwen Team, Alibaba "
|
||||
"Group, capable of perceiving auditory and visual inputs, as well as "
|
||||
"generating text and speech."
|
||||
)
|
||||
|
||||
prompt = (
|
||||
f"<|im_start|>system\n{default_system}<|im_end|>\n"
|
||||
"<|im_start|>user\n"
|
||||
f"{audio_in_prompt}{question}<|im_end|>\n"
|
||||
"<|im_start|>assistant\n"
|
||||
)
|
||||
return ModelRequestData(
|
||||
engine_args=engine_args,
|
||||
prompt=prompt,
|
||||
)
|
||||
|
||||
|
||||
# Ultravox 0.5-1B
|
||||
def run_ultravox(question: str, audio_count: int) -> ModelRequestData:
|
||||
model_name = "fixie-ai/ultravox-v0_5-llama-3_2-1b"
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
messages = [{"role": "user", "content": "<|audio|>\n" * audio_count + question}]
|
||||
prompt = tokenizer.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
|
||||
engine_args = EngineArgs(
|
||||
model=model_name,
|
||||
max_model_len=4096,
|
||||
max_num_seqs=5,
|
||||
trust_remote_code=True,
|
||||
limit_mm_per_prompt={"audio": audio_count},
|
||||
)
|
||||
|
||||
return ModelRequestData(
|
||||
engine_args=engine_args,
|
||||
prompt=prompt,
|
||||
)
|
||||
|
||||
|
||||
# Voxtral
|
||||
# Make sure to install mistral-common[audio].
|
||||
def run_voxtral(question: str, audio_count: int) -> ModelRequestData:
|
||||
from mistral_common.audio import Audio
|
||||
from mistral_common.protocol.instruct.chunk import (
|
||||
AudioChunk,
|
||||
RawAudio,
|
||||
TextChunk,
|
||||
)
|
||||
from mistral_common.protocol.instruct.messages import (
|
||||
UserMessage,
|
||||
)
|
||||
from mistral_common.protocol.instruct.request import ChatCompletionRequest
|
||||
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
|
||||
|
||||
model_name = "mistralai/Voxtral-Mini-3B-2507"
|
||||
tokenizer = MistralTokenizer.from_hf_hub(model_name)
|
||||
|
||||
engine_args = EngineArgs(
|
||||
model=model_name,
|
||||
max_model_len=8192,
|
||||
max_num_seqs=2,
|
||||
limit_mm_per_prompt={"audio": audio_count},
|
||||
config_format="mistral",
|
||||
load_format="mistral",
|
||||
tokenizer_mode="mistral",
|
||||
enforce_eager=True,
|
||||
enable_chunked_prefill=False,
|
||||
)
|
||||
|
||||
text_chunk = TextChunk(text=question)
|
||||
audios = [
|
||||
Audio.from_file(str(audio_assets[i].get_local_path()), strict=False)
|
||||
for i in range(audio_count)
|
||||
]
|
||||
audio_chunks = [
|
||||
AudioChunk(input_audio=RawAudio.from_audio(audio)) for audio in audios
|
||||
]
|
||||
|
||||
messages = [UserMessage(content=[*audio_chunks, text_chunk])]
|
||||
|
||||
req = ChatCompletionRequest(messages=messages, model=model_name)
|
||||
|
||||
tokens = tokenizer.encode_chat_completion(req)
|
||||
prompt_ids, audios = tokens.tokens, tokens.audios
|
||||
|
||||
audios_and_sr = [(au.audio_array, au.sampling_rate) for au in audios]
|
||||
|
||||
multi_modal_data = {"audio": audios_and_sr}
|
||||
|
||||
return ModelRequestData(
|
||||
engine_args=engine_args,
|
||||
prompt_token_ids=prompt_ids,
|
||||
multi_modal_data=multi_modal_data,
|
||||
)
|
||||
|
||||
|
||||
# Whisper
|
||||
def run_whisper(question: str, audio_count: int) -> ModelRequestData:
|
||||
assert audio_count == 1, "Whisper only supports a single audio input per prompt"
|
||||
model_name = "openai/whisper-large-v3-turbo"
|
||||
|
||||
prompt = "<|startoftranscript|>"
|
||||
|
||||
engine_args = EngineArgs(
|
||||
model=model_name,
|
||||
max_model_len=448,
|
||||
max_num_seqs=5,
|
||||
limit_mm_per_prompt={"audio": audio_count},
|
||||
)
|
||||
|
||||
return ModelRequestData(
|
||||
engine_args=engine_args,
|
||||
prompt=prompt,
|
||||
)
|
||||
|
||||
|
||||
model_example_map = {
|
||||
"audioflamingo3": run_audioflamingo3,
|
||||
"gemma3n": run_gemma3n,
|
||||
"granite_speech": run_granite_speech,
|
||||
"midashenglm": run_midashenglm,
|
||||
"minicpmo": run_minicpmo,
|
||||
"phi4_mm": run_phi4mm,
|
||||
"phi4_multimodal": run_phi4_multimodal,
|
||||
"qwen2_audio": run_qwen2_audio,
|
||||
"qwen2_5_omni": run_qwen2_5_omni,
|
||||
"ultravox": run_ultravox,
|
||||
"voxtral": run_voxtral,
|
||||
"whisper": run_whisper,
|
||||
}
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = FlexibleArgumentParser(
|
||||
description="Demo on using vLLM for offline inference with "
|
||||
"audio language models"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model-type",
|
||||
"-m",
|
||||
type=str,
|
||||
default="ultravox",
|
||||
choices=model_example_map.keys(),
|
||||
help='Huggingface "model_type".',
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-prompts", type=int, default=1, help="Number of prompts to run."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-audios",
|
||||
type=int,
|
||||
default=1,
|
||||
choices=[0, 1, 2],
|
||||
help="Number of audio items per prompt.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--seed",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Set the seed when initializing `vllm.LLM`.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tensor-parallel-size",
|
||||
"-tp",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Tensor parallel size to override the model's default setting. ",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main(args):
|
||||
model = args.model_type
|
||||
if model not in model_example_map:
|
||||
raise ValueError(f"Model type {model} is not supported.")
|
||||
|
||||
if args.tensor_parallel_size is not None and args.tensor_parallel_size < 1:
|
||||
raise ValueError(
|
||||
f"tensor_parallel_size must be a positive integer, "
|
||||
f"got {args.tensor_parallel_size}"
|
||||
)
|
||||
|
||||
audio_count = args.num_audios
|
||||
req_data = model_example_map[model](
|
||||
question_per_audio_count[audio_count], audio_count
|
||||
)
|
||||
|
||||
# Disable other modalities to save memory
|
||||
default_limits = {"image": 0, "video": 0, "audio": 0}
|
||||
req_data.engine_args.limit_mm_per_prompt = default_limits | dict(
|
||||
req_data.engine_args.limit_mm_per_prompt or {}
|
||||
)
|
||||
|
||||
engine_args = asdict(req_data.engine_args) | {"seed": args.seed}
|
||||
if args.tensor_parallel_size is not None:
|
||||
engine_args["tensor_parallel_size"] = args.tensor_parallel_size
|
||||
llm = LLM(**engine_args)
|
||||
|
||||
# We set temperature to 0.2 so that outputs can be different
|
||||
# even when all prompts are identical when running batch inference.
|
||||
sampling_params = SamplingParams(
|
||||
temperature=0.2, max_tokens=64, stop_token_ids=req_data.stop_token_ids
|
||||
)
|
||||
|
||||
mm_data = req_data.multi_modal_data
|
||||
if not mm_data:
|
||||
mm_data = {}
|
||||
if audio_count > 0:
|
||||
mm_data = {
|
||||
"audio": [
|
||||
asset.audio_and_sample_rate for asset in audio_assets[:audio_count]
|
||||
]
|
||||
}
|
||||
|
||||
assert args.num_prompts > 0
|
||||
inputs = {"multi_modal_data": mm_data}
|
||||
|
||||
if req_data.prompt:
|
||||
inputs["prompt"] = req_data.prompt
|
||||
else:
|
||||
inputs["prompt_token_ids"] = req_data.prompt_token_ids
|
||||
|
||||
if args.num_prompts > 1:
|
||||
# Batch inference
|
||||
inputs = [inputs] * args.num_prompts
|
||||
# Add LoRA request if applicable
|
||||
lora_request = (
|
||||
req_data.lora_requests * args.num_prompts if req_data.lora_requests else None
|
||||
)
|
||||
|
||||
outputs = llm.generate(
|
||||
inputs,
|
||||
sampling_params=sampling_params,
|
||||
lora_request=lora_request,
|
||||
)
|
||||
|
||||
for o in outputs:
|
||||
generated_text = o.outputs[0].text
|
||||
print(generated_text)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
main(args)
|
||||
103
examples/offline_inference/automatic_prefix_caching.py
Normal file
@@ -0,0 +1,103 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Demonstration script for Automatic Prefix Caching (APC) in vLLM.
|
||||
|
||||
Automatic Prefix Caching (APC) allows the vLLM engine to reuse cached
|
||||
KV (key-value) pairs from previous prompts if a new query shares the same
|
||||
prefix. This reduces redundant computation and improves inference speed.
|
||||
|
||||
To enable APC, set `enable_prefix_caching=True` when initializing the
|
||||
vLLM engine.
|
||||
|
||||
This script uses a long Markdown table as the shared prompt prefix and
|
||||
compares the generation time for two queries that share the same prefix
|
||||
but ask different questions.
|
||||
|
||||
Run:
|
||||
python examples/offline_inference/automatic_prefix_caching.py
|
||||
"""
|
||||
|
||||
import time
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
|
||||
# ruff: noqa: E501
|
||||
# A prompt containing a large markdown table. The table is randomly generated by GPT-4.
|
||||
LONG_PROMPT = (
|
||||
"You are a helpful assistant in recognizes the content of tables in markdown format. Here is a table as follows.\n# Table\n"
|
||||
+ """
|
||||
| ID | Name | Age | Occupation | Country | Email | Phone Number | Address |
|
||||
|-----|---------------|-----|---------------|---------------|------------------------|----------------|------------------------------|
|
||||
| 1 | John Doe | 29 | Engineer | USA | john.doe@example.com | 555-1234 | 123 Elm St, Springfield, IL |
|
||||
| 2 | Jane Smith | 34 | Doctor | Canada | jane.smith@example.com | 555-5678 | 456 Oak St, Toronto, ON |
|
||||
| 3 | Alice Johnson | 27 | Teacher | UK | alice.j@example.com | 555-8765 | 789 Pine St, London, UK |
|
||||
| 4 | Bob Brown | 45 | Artist | Australia | bob.b@example.com | 555-4321 | 321 Maple St, Sydney, NSW |
|
||||
| 5 | Carol White | 31 | Scientist | New Zealand | carol.w@example.com | 555-6789 | 654 Birch St, Wellington, NZ |
|
||||
| 6 | Dave Green | 28 | Lawyer | Ireland | dave.g@example.com | 555-3456 | 987 Cedar St, Dublin, IE |
|
||||
| 7 | Emma Black | 40 | Musician | USA | emma.b@example.com | 555-1111 | 246 Ash St, New York, NY |
|
||||
| 8 | Frank Blue | 37 | Chef | Canada | frank.b@example.com | 555-2222 | 135 Spruce St, Vancouver, BC |
|
||||
| 9 | Grace Yellow | 50 | Engineer | UK | grace.y@example.com | 555-3333 | 864 Fir St, Manchester, UK |
|
||||
| 10 | Henry Violet | 32 | Artist | Australia | henry.v@example.com | 555-4444 | 753 Willow St, Melbourne, VIC|
|
||||
| 11 | Irene Orange | 26 | Scientist | New Zealand | irene.o@example.com | 555-5555 | 912 Poplar St, Auckland, NZ |
|
||||
| 12 | Jack Indigo | 38 | Teacher | Ireland | jack.i@example.com | 555-6666 | 159 Elm St, Cork, IE |
|
||||
| 13 | Karen Red | 41 | Lawyer | USA | karen.r@example.com | 555-7777 | 357 Cedar St, Boston, MA |
|
||||
| 14 | Leo Brown | 30 | Chef | Canada | leo.b@example.com | 555-8888 | 246 Oak St, Calgary, AB |
|
||||
| 15 | Mia Green | 33 | Musician | UK | mia.g@example.com | 555-9999 | 975 Pine St, Edinburgh, UK |
|
||||
| 16 | Noah Yellow | 29 | Doctor | Australia | noah.y@example.com | 555-0000 | 864 Birch St, Brisbane, QLD |
|
||||
| 17 | Olivia Blue | 35 | Engineer | New Zealand | olivia.b@example.com | 555-1212 | 753 Maple St, Hamilton, NZ |
|
||||
| 18 | Peter Black | 42 | Artist | Ireland | peter.b@example.com | 555-3434 | 912 Fir St, Limerick, IE |
|
||||
| 19 | Quinn White | 28 | Scientist | USA | quinn.w@example.com | 555-5656 | 159 Willow St, Seattle, WA |
|
||||
| 20 | Rachel Red | 31 | Teacher | Canada | rachel.r@example.com | 555-7878 | 357 Poplar St, Ottawa, ON |
|
||||
| 21 | Steve Green | 44 | Lawyer | UK | steve.g@example.com | 555-9090 | 753 Elm St, Birmingham, UK |
|
||||
| 22 | Tina Blue | 36 | Musician | Australia | tina.b@example.com | 555-1213 | 864 Cedar St, Perth, WA |
|
||||
| 23 | Umar Black | 39 | Chef | New Zealand | umar.b@example.com | 555-3435 | 975 Spruce St, Christchurch, NZ|
|
||||
| 24 | Victor Yellow | 43 | Engineer | Ireland | victor.y@example.com | 555-5657 | 246 Willow St, Galway, IE |
|
||||
| 25 | Wendy Orange | 27 | Artist | USA | wendy.o@example.com | 555-7879 | 135 Elm St, Denver, CO |
|
||||
| 26 | Xavier Green | 34 | Scientist | Canada | xavier.g@example.com | 555-9091 | 357 Oak St, Montreal, QC |
|
||||
| 27 | Yara Red | 41 | Teacher | UK | yara.r@example.com | 555-1214 | 975 Pine St, Leeds, UK |
|
||||
| 28 | Zack Blue | 30 | Lawyer | Australia | zack.b@example.com | 555-3436 | 135 Birch St, Adelaide, SA |
|
||||
| 29 | Amy White | 33 | Musician | New Zealand | amy.w@example.com | 555-5658 | 159 Maple St, Wellington, NZ |
|
||||
| 30 | Ben Black | 38 | Chef | Ireland | ben.b@example.com | 555-7870 | 246 Fir St, Waterford, IE |
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def get_generation_time(llm, sampling_params, prompts):
|
||||
# time the generation
|
||||
start_time = time.time()
|
||||
output = llm.generate(prompts, sampling_params=sampling_params)
|
||||
end_time = time.time()
|
||||
# print the output and generation time
|
||||
print("-" * 30)
|
||||
print(f"Output: {output[0].outputs[0].text}")
|
||||
print(f"Generation time: {end_time - start_time} seconds.")
|
||||
print("-" * 30)
|
||||
|
||||
|
||||
def main():
|
||||
# set enable_prefix_caching=True to enable APC
|
||||
llm = LLM(model="lmsys/longchat-13b-16k", enable_prefix_caching=True)
|
||||
|
||||
sampling_params = SamplingParams(temperature=0, max_tokens=100)
|
||||
|
||||
# Querying the age of John Doe
|
||||
get_generation_time(
|
||||
llm,
|
||||
sampling_params,
|
||||
LONG_PROMPT
|
||||
+ "Question: what is the age of John Doe? Your answer: The age of John Doe is ",
|
||||
)
|
||||
|
||||
# Querying the age of Zack Blue
|
||||
# This query will be faster since vLLM avoids recomputing the KV cache of LONG_PROMPT.
|
||||
get_generation_time(
|
||||
llm,
|
||||
sampling_params,
|
||||
LONG_PROMPT
|
||||
+ "Question: what is the age of Zack Blue? Your answer: The age of Zack Blue is ",
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
80
examples/offline_inference/basic/README.md
Normal file
@@ -0,0 +1,80 @@
|
||||
# Basic
|
||||
|
||||
The `LLM` class provides the primary Python interface for doing offline inference, which is interacting with a model without using a separate model inference server.
|
||||
|
||||
## Usage
|
||||
|
||||
The first script in this example shows the most basic usage of vLLM. If you are new to Python and vLLM, you should start here.
|
||||
|
||||
```bash
|
||||
python examples/offline_inference/basic/basic.py
|
||||
```
|
||||
|
||||
The rest of the scripts include an [argument parser](https://docs.python.org/3/library/argparse.html), which you can use to pass any arguments that are compatible with [`LLM`](https://docs.vllm.ai/en/latest/api/offline_inference/llm.html). Try running the script with `--help` for a list of all available arguments.
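For example, to see the available arguments for any of them:

```bash
python examples/offline_inference/basic/classify.py --help
```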
|
||||
|
||||
```bash
|
||||
python examples/offline_inference/basic/classify.py
|
||||
```
|
||||
|
||||
```bash
|
||||
python examples/offline_inference/basic/embed.py
|
||||
```
|
||||
|
||||
```bash
|
||||
python examples/offline_inference/basic/score.py
|
||||
```
|
||||
|
||||
The chat and generate scripts also accept the [sampling parameters](https://docs.vllm.ai/en/latest/api/inference_params.html#sampling-parameters): `max_tokens`, `temperature`, `top_p` and `top_k`.
|
||||
|
||||
```bash
|
||||
python examples/offline_inference/basic/chat.py
|
||||
```
|
||||
|
||||
```bash
|
||||
python examples/offline_inference/basic/generate.py
|
||||
```
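For example (the flag values below are only illustrative):

```bash
python examples/offline_inference/basic/generate.py --max-tokens 64 --temperature 0.8 --top-p 0.95 --top-k 50
```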
|
||||
|
||||
## Features
|
||||
|
||||
In the scripts that support passing arguments, you can experiment with the following features.
|
||||
|
||||
### Default generation config
|
||||
|
||||
The `--generation-config` argument specifies where the generation config is loaded from when calling `LLM.get_default_sampling_params()`. If set to `auto`, the generation config is loaded from the model path. If set to a folder path, it is loaded from that folder. If it is not provided, vLLM defaults are used.
|
||||
|
||||
> If `max_new_tokens` is specified in the generation config, it sets a server-wide limit on the number of output tokens for all requests.
|
||||
|
||||
Try it yourself with the following argument:
|
||||
|
||||
```bash
|
||||
--generation-config auto
|
||||
```
|
||||
|
||||
### Quantization
|
||||
|
||||
#### GGUF
|
||||
|
||||
vLLM supports models that are quantized using GGUF.
|
||||
|
||||
Try one yourself by downloading a quantized GGUF model and using the following arguments:
|
||||
|
||||
```python
|
||||
from huggingface_hub import hf_hub_download
|
||||
repo_id = "bartowski/Phi-3-medium-4k-instruct-GGUF"
|
||||
filename = "Phi-3-medium-4k-instruct-IQ2_M.gguf"
|
||||
print(hf_hub_download(repo_id, filename=filename))
|
||||
```
|
||||
|
||||
```bash
|
||||
--model {local-path-printed-above} --tokenizer microsoft/Phi-3-medium-4k-instruct
|
||||
```
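Putting the two together (shown here with the generate script; any of the argument-accepting scripts above works the same way):

```bash
python examples/offline_inference/basic/generate.py --model {local-path-printed-above} --tokenizer microsoft/Phi-3-medium-4k-instruct
```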
|
||||
|
||||
### CPU offload
|
||||
|
||||
The `--cpu-offload-gb` argument can be seen as a virtual way to increase the GPU memory size. For example, if you have one 24 GB GPU and set this to 10, virtually you can think of it as a 34 GB GPU. Then you can load a 13B model with BF16 weight, which requires at least 26GB GPU memory. Note that this requires fast CPU-GPU interconnect, as part of the model is loaded from CPU memory to GPU memory on the fly in each model forward pass.
|
||||
|
||||
Try it yourself with the following arguments:
|
||||
|
||||
```bash
|
||||
--model meta-llama/Llama-2-13b-chat-hf --cpu-offload-gb 10
|
||||
```
|
||||
35
examples/offline_inference/basic/basic.py
Normal file
@@ -0,0 +1,35 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
|
||||
# Sample prompts.
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
"The president of the United States is",
|
||||
"The capital of France is",
|
||||
"The future of AI is",
|
||||
]
|
||||
# Create a sampling params object.
|
||||
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
|
||||
|
||||
|
||||
def main():
|
||||
# Create an LLM.
|
||||
llm = LLM(model="facebook/opt-125m")
|
||||
# Generate texts from the prompts.
|
||||
# The output is a list of RequestOutput objects
|
||||
# that contain the prompt, generated text, and other information.
|
||||
outputs = llm.generate(prompts, sampling_params)
|
||||
# Print the outputs.
|
||||
print("\nGenerated Outputs:\n" + "-" * 60)
|
||||
for output in outputs:
|
||||
prompt = output.prompt
|
||||
generated_text = output.outputs[0].text
|
||||
print(f"Prompt: {prompt!r}")
|
||||
print(f"Output: {generated_text!r}")
|
||||
print("-" * 60)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
96
examples/offline_inference/basic/chat.py
Normal file
@@ -0,0 +1,96 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from vllm import LLM, EngineArgs
|
||||
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||
|
||||
|
||||
def create_parser():
|
||||
parser = FlexibleArgumentParser()
|
||||
# Add engine args
|
||||
EngineArgs.add_cli_args(parser)
|
||||
parser.set_defaults(model="meta-llama/Llama-3.2-1B-Instruct")
|
||||
# Add sampling params
|
||||
sampling_group = parser.add_argument_group("Sampling parameters")
|
||||
sampling_group.add_argument("--max-tokens", type=int)
|
||||
sampling_group.add_argument("--temperature", type=float)
|
||||
sampling_group.add_argument("--top-p", type=float)
|
||||
sampling_group.add_argument("--top-k", type=int)
|
||||
# Add example params
|
||||
parser.add_argument("--chat-template-path", type=str)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def main(args: dict):
|
||||
# Pop arguments not used by LLM
|
||||
max_tokens = args.pop("max_tokens")
|
||||
temperature = args.pop("temperature")
|
||||
top_p = args.pop("top_p")
|
||||
top_k = args.pop("top_k")
|
||||
chat_template_path = args.pop("chat_template_path")
|
||||
|
||||
# Create an LLM
|
||||
llm = LLM(**args)
|
||||
|
||||
# Create sampling params object
|
||||
sampling_params = llm.get_default_sampling_params()
|
||||
if max_tokens is not None:
|
||||
sampling_params.max_tokens = max_tokens
|
||||
if temperature is not None:
|
||||
sampling_params.temperature = temperature
|
||||
if top_p is not None:
|
||||
sampling_params.top_p = top_p
|
||||
if top_k is not None:
|
||||
sampling_params.top_k = top_k
|
||||
|
||||
def print_outputs(outputs):
|
||||
print("\nGenerated Outputs:\n" + "-" * 80)
|
||||
for output in outputs:
|
||||
prompt = output.prompt
|
||||
generated_text = output.outputs[0].text
|
||||
print(f"Prompt: {prompt!r}\n")
|
||||
print(f"Generated text: {generated_text!r}")
|
||||
print("-" * 80)
|
||||
|
||||
print("=" * 80)
|
||||
|
||||
# In this script, we demonstrate how to pass input to the chat method:
|
||||
conversation = [
|
||||
{"role": "system", "content": "You are a helpful assistant"},
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hello! How can I assist you today?"},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Write an essay about the importance of higher education.",
|
||||
},
|
||||
]
|
||||
outputs = llm.chat(conversation, sampling_params, use_tqdm=False)
|
||||
print_outputs(outputs)
|
||||
|
||||
# You can run batch inference with llm.chat API
|
||||
conversations = [conversation for _ in range(10)]
|
||||
|
||||
# We turn on tqdm progress bar to verify it's indeed running batch inference
|
||||
outputs = llm.chat(conversations, sampling_params, use_tqdm=True)
|
||||
print_outputs(outputs)
|
||||
|
||||
# A chat template can be optionally supplied.
|
||||
# If not, the model will use its default chat template.
|
||||
if chat_template_path is not None:
|
||||
with open(chat_template_path) as f:
|
||||
chat_template = f.read()
|
||||
|
||||
outputs = llm.chat(
|
||||
conversations,
|
||||
sampling_params,
|
||||
use_tqdm=False,
|
||||
chat_template=chat_template,
|
||||
)
|
||||
print_outputs(outputs)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = create_parser()
|
||||
args: dict = vars(parser.parse_args())
|
||||
main(args)
|
||||
52
examples/offline_inference/basic/classify.py
Normal file
@@ -0,0 +1,52 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from argparse import Namespace
|
||||
|
||||
from vllm import LLM, EngineArgs
|
||||
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = FlexibleArgumentParser()
|
||||
parser = EngineArgs.add_cli_args(parser)
|
||||
# Set example specific arguments
|
||||
parser.set_defaults(
|
||||
model="jason9693/Qwen2.5-1.5B-apeach",
|
||||
runner="pooling",
|
||||
enforce_eager=True,
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main(args: Namespace):
|
||||
# Sample prompts.
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
"The president of the United States is",
|
||||
"The capital of France is",
|
||||
"The future of AI is",
|
||||
]
|
||||
|
||||
# Create an LLM.
|
||||
# You should pass runner="pooling" for classification models
|
||||
llm = LLM(**vars(args))
|
||||
|
||||
# Generate logits. The output is a list of ClassificationRequestOutputs.
|
||||
outputs = llm.classify(prompts)
|
||||
|
||||
# Print the outputs.
|
||||
print("\nGenerated Outputs:\n" + "-" * 60)
|
||||
for prompt, output in zip(prompts, outputs):
|
||||
probs = output.outputs.probs
|
||||
probs_trimmed = (str(probs[:16])[:-1] + ", ...]") if len(probs) > 16 else probs
|
||||
print(
|
||||
f"Prompt: {prompt!r} \n"
|
||||
f"Class Probabilities: {probs_trimmed} (size={len(probs)})"
|
||||
)
|
||||
print("-" * 60)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
main(args)
|
||||
59
examples/offline_inference/basic/embed.py
Normal file
@@ -0,0 +1,59 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from argparse import Namespace
|
||||
|
||||
from vllm import LLM, EngineArgs
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.config import AttentionConfig
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = FlexibleArgumentParser()
|
||||
parser = EngineArgs.add_cli_args(parser)
|
||||
# Set example specific arguments
|
||||
parser.set_defaults(
|
||||
model="intfloat/e5-small",
|
||||
runner="pooling",
|
||||
enforce_eager=True,
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main(args: Namespace):
|
||||
if current_platform.is_rocm():
|
||||
args.attention_config = AttentionConfig(
|
||||
backend=AttentionBackendEnum.FLEX_ATTENTION
|
||||
)
|
||||
|
||||
# Sample prompts.
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
"The president of the United States is",
|
||||
"The capital of France is",
|
||||
"The future of AI is",
|
||||
]
|
||||
|
||||
# Create an LLM.
|
||||
# You should pass runner="pooling" for embedding models
|
||||
llm = LLM(**vars(args))
|
||||
|
||||
# Generate embedding. The output is a list of EmbeddingRequestOutputs.
|
||||
outputs = llm.embed(prompts)
|
||||
|
||||
# Print the outputs.
|
||||
print("\nGenerated Outputs:\n" + "-" * 60)
|
||||
for prompt, output in zip(prompts, outputs):
|
||||
embeds = output.outputs.embedding
|
||||
embeds_trimmed = (
|
||||
(str(embeds[:16])[:-1] + ", ...]") if len(embeds) > 16 else embeds
|
||||
)
|
||||
print(f"Prompt: {prompt!r} \nEmbeddings: {embeds_trimmed} (size={len(embeds)})")
|
||||
print("-" * 60)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
main(args)
|
||||
65
examples/offline_inference/basic/generate.py
Normal file
@@ -0,0 +1,65 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from vllm import LLM, EngineArgs
|
||||
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||
|
||||
|
||||
def create_parser():
|
||||
parser = FlexibleArgumentParser()
|
||||
# Add engine args
|
||||
EngineArgs.add_cli_args(parser)
|
||||
parser.set_defaults(model="meta-llama/Llama-3.2-1B-Instruct")
|
||||
# Add sampling params
|
||||
sampling_group = parser.add_argument_group("Sampling parameters")
|
||||
sampling_group.add_argument("--max-tokens", type=int)
|
||||
sampling_group.add_argument("--temperature", type=float)
|
||||
sampling_group.add_argument("--top-p", type=float)
|
||||
sampling_group.add_argument("--top-k", type=int)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def main(args: dict):
|
||||
# Pop arguments not used by LLM
|
||||
max_tokens = args.pop("max_tokens")
|
||||
temperature = args.pop("temperature")
|
||||
top_p = args.pop("top_p")
|
||||
top_k = args.pop("top_k")
|
||||
|
||||
# Create an LLM
|
||||
llm = LLM(**args)
|
||||
|
||||
# Create a sampling params object
|
||||
sampling_params = llm.get_default_sampling_params()
|
||||
if max_tokens is not None:
|
||||
sampling_params.max_tokens = max_tokens
|
||||
if temperature is not None:
|
||||
sampling_params.temperature = temperature
|
||||
if top_p is not None:
|
||||
sampling_params.top_p = top_p
|
||||
if top_k is not None:
|
||||
sampling_params.top_k = top_k
|
||||
|
||||
# Generate texts from the prompts. The output is a list of RequestOutput
|
||||
# objects that contain the prompt, generated text, and other information.
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
"The president of the United States is",
|
||||
"The capital of France is",
|
||||
"The future of AI is",
|
||||
]
|
||||
outputs = llm.generate(prompts, sampling_params)
|
||||
# Print the outputs.
|
||||
print("-" * 50)
|
||||
for output in outputs:
|
||||
prompt = output.prompt
|
||||
generated_text = output.outputs[0].text
|
||||
print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
|
||||
print("-" * 50)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = create_parser()
|
||||
args: dict = vars(parser.parse_args())
|
||||
main(args)
|
||||
53
examples/offline_inference/basic/reward.py
Normal file
@@ -0,0 +1,53 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from argparse import Namespace
|
||||
|
||||
from vllm import LLM, EngineArgs
|
||||
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = FlexibleArgumentParser()
|
||||
parser = EngineArgs.add_cli_args(parser)
|
||||
# Set example specific arguments
|
||||
parser.set_defaults(
|
||||
model="internlm/internlm2-1_8b-reward",
|
||||
runner="pooling",
|
||||
enforce_eager=True,
|
||||
max_model_len=1024,
|
||||
trust_remote_code=True,
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main(args: Namespace):
|
||||
# Sample prompts.
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
"The president of the United States is",
|
||||
"The capital of France is",
|
||||
"The future of AI is",
|
||||
]
|
||||
|
||||
# Create an LLM.
|
||||
# You should pass runner="pooling" for reward models
|
||||
llm = LLM(**vars(args))
|
||||
|
||||
# Generate rewards. The output is a list of PoolingRequestOutput.
|
||||
outputs = llm.reward(prompts)
|
||||
|
||||
# Print the outputs.
|
||||
print("\nGenerated Outputs:\n" + "-" * 60)
|
||||
for prompt, output in zip(prompts, outputs):
|
||||
rewards = output.outputs.data
|
||||
rewards_trimmed = (
|
||||
(str(rewards[:16])[:-1] + ", ...]") if len(rewards) > 16 else rewards
|
||||
)
|
||||
print(f"Prompt: {prompt!r} \nReward: {rewards_trimmed} (size={len(rewards)})")
|
||||
print("-" * 60)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
main(args)
|
||||
55
examples/offline_inference/basic/score.py
Normal file
@@ -0,0 +1,55 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from argparse import Namespace
|
||||
|
||||
from vllm import LLM, EngineArgs
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.config import AttentionConfig
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = FlexibleArgumentParser()
|
||||
parser = EngineArgs.add_cli_args(parser)
|
||||
# Set example specific arguments
|
||||
parser.set_defaults(
|
||||
model="BAAI/bge-reranker-v2-m3",
|
||||
runner="pooling",
|
||||
enforce_eager=True,
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main(args: Namespace):
|
||||
if current_platform.is_rocm():
|
||||
args.attention_config = AttentionConfig(
|
||||
backend=AttentionBackendEnum.FLEX_ATTENTION
|
||||
)
|
||||
|
||||
# Sample prompts.
|
||||
text_1 = "What is the capital of France?"
|
||||
texts_2 = [
|
||||
"The capital of Brazil is Brasilia.",
|
||||
"The capital of France is Paris.",
|
||||
]
|
||||
|
||||
# Create an LLM.
|
||||
# You should pass runner="pooling" for cross-encoder models
|
||||
llm = LLM(**vars(args))
|
||||
|
||||
# Generate scores. The output is a list of ScoringRequestOutputs.
|
||||
outputs = llm.score(text_1, texts_2)
|
||||
|
||||
# Print the outputs.
|
||||
print("\nGenerated Outputs:\n" + "-" * 60)
|
||||
for text_2, output in zip(texts_2, outputs):
|
||||
score = output.outputs.score
|
||||
print(f"Pair: {[text_1, text_2]!r} \nScore: {score}")
|
||||
print("-" * 60)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
main(args)
|
||||
93
examples/offline_inference/batch_llm_inference.py
Normal file
@@ -0,0 +1,93 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
This example shows how to use Ray Data for data parallel batch inference.
|
||||
|
||||
Ray Data is a data processing framework that can process very large datasets
|
||||
with first-class support for vLLM.
|
||||
|
||||
Ray Data provides functionality for:
|
||||
* Reading and writing to most popular file formats and cloud object storage.
|
||||
* Streaming execution, so you can run inference on datasets that far exceed
|
||||
the aggregate RAM of the cluster.
|
||||
* Scale up the workload without code changes.
|
||||
* Automatic sharding, load-balancing, and autoscaling across a Ray cluster,
|
||||
with built-in fault-tolerance and retry semantics.
|
||||
* Continuous batching that keeps vLLM replicas saturated and maximizes GPU
|
||||
utilization.
|
||||
* Compatible with tensor/pipeline parallel inference.
|
||||
|
||||
Learn more about Ray Data's LLM integration:
|
||||
https://docs.ray.io/en/latest/data/working-with-llms.html
|
||||
"""
|
||||
|
||||
import ray
|
||||
from packaging.version import Version
|
||||
from ray.data.llm import build_llm_processor, vLLMEngineProcessorConfig
|
||||
|
||||
assert Version(ray.__version__) >= Version("2.44.1"), (
|
||||
"Ray version must be at least 2.44.1"
|
||||
)
|
||||
|
||||
# Uncomment to reduce clutter in stdout
|
||||
# ray.init(log_to_driver=False)
|
||||
# ray.data.DataContext.get_current().enable_progress_bars = False
|
||||
|
||||
# Read one text file from S3. Ray Data supports reading multiple files
|
||||
# from cloud storage (such as JSONL, Parquet, CSV, binary format).
|
||||
ds = ray.data.read_text("s3://anonymous@air-example-data/prompts.txt")
|
||||
print(ds.schema())
|
||||
|
||||
size = ds.count()
|
||||
print(f"Size of dataset: {size} prompts")
|
||||
|
||||
# Configure vLLM engine.
|
||||
config = vLLMEngineProcessorConfig(
|
||||
model_source="unsloth/Llama-3.1-8B-Instruct",
|
||||
engine_kwargs={
|
||||
"enable_chunked_prefill": True,
|
||||
"max_num_batched_tokens": 4096,
|
||||
"max_model_len": 16384,
|
||||
},
|
||||
concurrency=1, # set the number of parallel vLLM replicas
|
||||
batch_size=64,
|
||||
)
|
||||
|
||||
# Create a Processor object, which will be used to
|
||||
# do batch inference on the dataset
|
||||
vllm_processor = build_llm_processor(
|
||||
config,
|
||||
preprocess=lambda row: dict(
|
||||
messages=[
|
||||
{"role": "system", "content": "You are a bot that responds with haikus."},
|
||||
{"role": "user", "content": row["text"]},
|
||||
],
|
||||
sampling_params=dict(
|
||||
temperature=0.3,
|
||||
max_tokens=250,
|
||||
),
|
||||
),
|
||||
postprocess=lambda row: dict(
|
||||
answer=row["generated_text"],
|
||||
**row, # This will return all the original columns in the dataset.
|
||||
),
|
||||
)
|
||||
|
||||
ds = vllm_processor(ds)
|
||||
|
||||
# Peek first 10 results.
|
||||
# NOTE: This is for local testing and debugging. For production use cases,
# write the full results out as shown below.
|
||||
outputs = ds.take(limit=10)
|
||||
|
||||
for output in outputs:
|
||||
prompt = output["prompt"]
|
||||
generated_text = output["generated_text"]
|
||||
print(f"Prompt: {prompt!r}")
|
||||
print(f"Generated text: {generated_text!r}")
|
||||
|
||||
# Write inference output data out as Parquet files to S3.
|
||||
# Multiple files would be written to the output destination,
|
||||
# and each task would write one or more files separately.
|
||||
#
|
||||
# ds.write_parquet("s3://<your-output-bucket>")
|
||||
147
examples/offline_inference/chat_with_tools.py
Normal file
@@ -0,0 +1,147 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# ruff: noqa
|
||||
import json
|
||||
import random
|
||||
import string
|
||||
|
||||
from vllm import LLM
|
||||
from vllm.sampling_params import SamplingParams
|
||||
|
||||
# This script is an offline demo for function calling
|
||||
#
|
||||
# If you want to run a server/client setup, please follow this code:
|
||||
#
|
||||
# - Server:
|
||||
#
|
||||
# ```bash
|
||||
# vllm serve mistralai/Mistral-7B-Instruct-v0.3 --tokenizer-mode mistral --load-format mistral --config-format mistral
|
||||
# ```
|
||||
#
|
||||
# - Client:
|
||||
#
|
||||
# ```bash
|
||||
# curl --location 'http://<your-node-url>:8000/v1/chat/completions' \
|
||||
# --header 'Content-Type: application/json' \
|
||||
# --header 'Authorization: Bearer token' \
|
||||
# --data '{
|
||||
# "model": "mistralai/Mistral-7B-Instruct-v0.3"
|
||||
# "messages": [
|
||||
# {
|
||||
# "role": "user",
|
||||
# "content": [
|
||||
# {"type" : "text", "text": "Describe this image in detail please."},
|
||||
# {"type": "image_url", "image_url": {"url": "https://s3.amazonaws.com/cms.ipressroom.com/338/files/201808/5b894ee1a138352221103195_A680%7Ejogging-edit/A680%7Ejogging-edit_hero.jpg"}},
|
||||
# {"type" : "text", "text": "and this one as well. Answer in French."},
|
||||
# {"type": "image_url", "image_url": {"url": "https://www.wolframcloud.com/obj/resourcesystem/images/a0e/a0ee3983-46c6-4c92-b85d-059044639928/6af8cfb971db031b.png"}}
|
||||
# ]
|
||||
# }
|
||||
# ]
|
||||
# }'
|
||||
# ```
|
||||
#
|
||||
# Usage:
|
||||
# python examples/offline_inference/chat_with_tools.py
|
||||
|
||||
model_name = "mistralai/Mistral-7B-Instruct-v0.3"
|
||||
# or switch to "mistralai/Mistral-Nemo-Instruct-2407"
|
||||
# or "mistralai/Mistral-Large-Instruct-2407"
|
||||
# or any other mistral model with function calling ability
|
||||
|
||||
sampling_params = SamplingParams(max_tokens=8192, temperature=0.0)
|
||||
llm = LLM(
|
||||
model=model_name,
|
||||
tokenizer_mode="mistral",
|
||||
config_format="mistral",
|
||||
load_format="mistral",
|
||||
)
|
||||
|
||||
|
||||
def generate_random_id(length=9):
|
||||
characters = string.ascii_letters + string.digits
|
||||
random_id = "".join(random.choice(characters) for _ in range(length))
|
||||
return random_id
|
||||
|
||||
|
||||
# simulate an API that can be called
|
||||
def get_current_weather(city: str, state: str, unit: str):
|
||||
return (
|
||||
f"The weather in {city}, {state} is 85 degrees {unit}. It is "
|
||||
"partly cloudly, with highs in the 90's."
|
||||
)
|
||||
|
||||
|
||||
tool_functions = {"get_current_weather": get_current_weather}
|
||||
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_current_weather",
|
||||
"description": "Get the current weather in a given location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {
|
||||
"type": "string",
|
||||
"description": "The city to find the weather for, e.g. 'San Francisco'",
|
||||
},
|
||||
"state": {
|
||||
"type": "string",
|
||||
"description": "the two-letter abbreviation for the state that the city is"
|
||||
" in, e.g. 'CA' which would mean 'California'",
|
||||
},
|
||||
"unit": {
|
||||
"type": "string",
|
||||
"description": "The unit to fetch the temperature in",
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
},
|
||||
},
|
||||
"required": ["city", "state", "unit"],
|
||||
},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Can you tell me what the temperate will be in Dallas, in fahrenheit?",
|
||||
}
|
||||
]
|
||||
|
||||
outputs = llm.chat(messages, sampling_params=sampling_params, tools=tools)
|
||||
output = outputs[0].outputs[0].text.strip()
|
||||
|
||||
# append the assistant message
|
||||
messages.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": output,
|
||||
}
|
||||
)
|
||||
|
||||
# Now parse the model's output and execute it, simulating an API call with the
# function defined above
|
||||
tool_calls = json.loads(output)
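# For this prompt, the model is expected to emit a JSON list of tool calls such as
# [{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]
# (illustrative only; the exact arguments depend on the model's response).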
|
||||
tool_answers = [
|
||||
tool_functions[call["name"]](**call["arguments"]) for call in tool_calls
|
||||
]
|
||||
|
||||
# append the answer as a tool message and let the LLM give you an answer
|
||||
messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"content": "\n\n".join(tool_answers),
|
||||
"tool_call_id": generate_random_id(),
|
||||
}
|
||||
)
|
||||
|
||||
outputs = llm.chat(messages, sampling_params, tools=tools)
|
||||
|
||||
print(outputs[0].outputs[0].text.strip())
|
||||
# yields
|
||||
# 'The weather in Dallas, TX is 85 degrees Fahrenheit. '
|
||||
# 'It is partly cloudy, with highs in the 90's.'
|
||||
68
examples/offline_inference/context_extension.py
Normal file
@@ -0,0 +1,68 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
This script demonstrates how to extend the context length
|
||||
of a Qwen model using the YARN method (rope_parameters)
|
||||
and run a simple chat example.
|
||||
|
||||
Usage:
|
||||
python examples/offline_inference/context_extension.py
|
||||
"""
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
|
||||
|
||||
def create_llm():
|
||||
rope_theta = 1000000
|
||||
original_max_position_embeddings = 32768
|
||||
factor = 4.0
|
||||
|
||||
# Use yarn to extend context
|
||||
hf_overrides = {
|
||||
"rope_parameters": {
|
||||
"rope_theta": rope_theta,
|
||||
"rope_type": "yarn",
|
||||
"factor": factor,
|
||||
"original_max_position_embeddings": original_max_position_embeddings,
|
||||
},
|
||||
"max_model_len": int(original_max_position_embeddings * factor),
|
||||
}
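# With factor=4.0 this extends max_model_len from 32768 to 131072 tokens.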
|
||||
|
||||
llm = LLM(model="Qwen/Qwen3-0.6B", hf_overrides=hf_overrides)
|
||||
return llm
|
||||
|
||||
|
||||
def run_llm_chat(llm):
|
||||
sampling_params = SamplingParams(
|
||||
temperature=0.8,
|
||||
top_p=0.95,
|
||||
max_tokens=128,
|
||||
)
|
||||
|
||||
conversation = [
|
||||
{"role": "system", "content": "You are a helpful assistant"},
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hello! How can I assist you today?"},
|
||||
]
|
||||
outputs = llm.chat(conversation, sampling_params, use_tqdm=False)
|
||||
return outputs
|
||||
|
||||
|
||||
def print_outputs(outputs):
|
||||
print("\nGenerated Outputs:\n" + "-" * 80)
|
||||
for output in outputs:
|
||||
prompt = output.prompt
|
||||
generated_text = output.outputs[0].text
|
||||
print(f"Prompt: {prompt!r}\n")
|
||||
print(f"Generated text: {generated_text!r}")
|
||||
print("-" * 80)
|
||||
|
||||
|
||||
def main():
|
||||
llm = create_llm()
|
||||
outputs = run_llm_chat(llm)
|
||||
print_outputs(outputs)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
268
examples/offline_inference/data_parallel.py
Normal file
268
examples/offline_inference/data_parallel.py
Normal file
@@ -0,0 +1,268 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Usage:
|
||||
Single node:
|
||||
python examples/offline_inference/data_parallel.py \
|
||||
--model="ibm-research/PowerMoE-3b" \
|
||||
--dp-size=2 \
|
||||
--tp-size=2
|
||||
|
||||
Multi-node:
|
||||
Node 0 (assume the node has ip of 10.99.48.128):
|
||||
python examples/offline_inference/data_parallel.py \
|
||||
--model="ibm-research/PowerMoE-3b" \
|
||||
--dp-size=2 \
|
||||
--tp-size=2 \
|
||||
--node-size=2 \
|
||||
--node-rank=0 \
|
||||
--master-addr=10.99.48.128 \
|
||||
--master-port=13345
|
||||
Node 1:
|
||||
python examples/offline_inference/data_parallel.py \
|
||||
--model="ibm-research/PowerMoE-3b" \
|
||||
--dp-size=2 \
|
||||
--tp-size=2 \
|
||||
--node-size=2 \
|
||||
--node-rank=1 \
|
||||
--master-addr=10.99.48.128 \
|
||||
--master-port=13345
|
||||
"""
|
||||
|
||||
import os
|
||||
from time import sleep
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.network_utils import get_open_port
|
||||
|
||||
|
||||
def parse_args():
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(description="Data Parallel Inference")
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
type=str,
|
||||
default="ibm-research/PowerMoE-3b",
|
||||
help="Model name or path",
|
||||
)
|
||||
parser.add_argument("--dp-size", type=int, default=2, help="Data parallel size")
|
||||
parser.add_argument("--tp-size", type=int, default=2, help="Tensor parallel size")
|
||||
parser.add_argument(
|
||||
"--node-size", type=int, default=1, help="Total number of nodes"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--node-rank", type=int, default=0, help="Rank of the current node"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--master-addr", type=str, default="", help="Master node IP address"
|
||||
)
|
||||
parser.add_argument("--master-port", type=int, default=0, help="Master node port")
|
||||
parser.add_argument(
|
||||
"--enforce-eager", action="store_true", help="Enforce eager mode execution."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--trust-remote-code", action="store_true", help="Trust remote code."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-num-seqs",
|
||||
type=int,
|
||||
default=64,
|
||||
help=("Maximum number of sequences to be processed in a single iteration."),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-model-len",
|
||||
type=int,
|
||||
help=("Maximum number of tokens to be processed in a single iteration."),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--timeout",
|
||||
type=int,
|
||||
default=300,
|
||||
help=("Number of seconds before unresponsive process is killed."),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gpu-memory-utilization",
|
||||
type=float,
|
||||
default=0.8,
|
||||
help=("Fraction of GPU memory vLLM is allowed to allocate (0.0, 1.0]."),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--enable-dbo",
|
||||
action="store_true",
|
||||
help=("Enable microbatched execution"),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--compilation-config",
|
||||
type=int,
|
||||
help=("Compilation optimization (O) mode 0-3."),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--quantization",
|
||||
type=str,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--disable-expert-parallel",
|
||||
dest="enable_expert_parallel",
|
||||
action="store_false",
|
||||
help="Disable expert parallel (default: enabled).",
|
||||
)
|
||||
parser.set_defaults(enable_expert_parallel=True)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main(
|
||||
model,
|
||||
dp_size,
|
||||
local_dp_rank,
|
||||
global_dp_rank,
|
||||
dp_master_ip,
|
||||
dp_master_port,
|
||||
GPUs_per_dp_rank,
|
||||
enforce_eager,
|
||||
enable_expert_parallel,
|
||||
trust_remote_code,
|
||||
max_num_seqs,
|
||||
max_model_len,
|
||||
compilation_config,
|
||||
gpu_memory_utilization,
|
||||
enable_dbo,
|
||||
quantization,
|
||||
):
|
||||
os.environ["VLLM_DP_RANK"] = str(global_dp_rank)
|
||||
os.environ["VLLM_DP_RANK_LOCAL"] = str(local_dp_rank)
|
||||
os.environ["VLLM_DP_SIZE"] = str(dp_size)
|
||||
os.environ["VLLM_DP_MASTER_IP"] = dp_master_ip
|
||||
os.environ["VLLM_DP_MASTER_PORT"] = str(dp_master_port)
|
||||
|
||||
# CUDA_VISIBLE_DEVICES for each DP rank is set automatically inside the
|
||||
# engine processes.
|
||||
|
||||
# Sample prompts.
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
"The president of the United States is",
|
||||
"The capital of France is",
|
||||
"The future of AI is",
|
||||
] * 100
|
||||
|
||||
# with DP, each rank should process different prompts.
|
||||
# usually all the DP ranks process a full dataset,
|
||||
# and each rank processes a different part of the dataset.
|
||||
floor = len(prompts) // dp_size
|
||||
remainder = len(prompts) % dp_size
|
||||
|
||||
# Distribute prompts into even groups.
|
||||
def start(rank):
|
||||
return rank * floor + min(rank, remainder)
|
||||
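# For example, with 8 prompts and dp_size=3: floor=2, remainder=2, so the
# ranks receive slices [0:3], [3:6], and [6:8] (3, 3, and 2 prompts).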
|
||||
prompts = prompts[start(global_dp_rank) : start(global_dp_rank + 1)]
|
||||
if len(prompts) == 0:
|
||||
# if any rank has no prompts to process,
|
||||
# we need to set a placeholder prompt
|
||||
prompts = ["Placeholder"]
|
||||
print(f"DP rank {global_dp_rank} needs to process {len(prompts)} prompts")
|
||||
|
||||
# Create a sampling params object.
|
||||
# since we are doing data parallel, every rank can have different
|
||||
# sampling params. here we set different max_tokens for different
|
||||
# ranks for demonstration.
|
||||
sampling_params = SamplingParams(
|
||||
temperature=0.8, top_p=0.95, max_tokens=[16, 20][global_dp_rank % 2]
|
||||
)
|
||||
|
||||
# Create an LLM.
|
||||
llm = LLM(
|
||||
model=model,
|
||||
tensor_parallel_size=GPUs_per_dp_rank,
|
||||
enforce_eager=enforce_eager,
|
||||
enable_expert_parallel=enable_expert_parallel,
|
||||
trust_remote_code=trust_remote_code,
|
||||
max_num_seqs=max_num_seqs,
|
||||
max_model_len=max_model_len,
|
||||
gpu_memory_utilization=gpu_memory_utilization,
|
||||
enable_dbo=enable_dbo,
|
||||
quantization=quantization,
|
||||
compilation_config=compilation_config,
|
||||
)
|
||||
outputs = llm.generate(prompts, sampling_params)
|
||||
# Print the outputs.
|
||||
for i, output in enumerate(outputs):
|
||||
if i >= 5:
|
||||
# print only 5 outputs
|
||||
break
|
||||
prompt = output.prompt
|
||||
generated_text = output.outputs[0].text
|
||||
print(
|
||||
f"DP rank {global_dp_rank}, Prompt: {prompt!r}, "
|
||||
f"Generated text: {generated_text!r}"
|
||||
)
|
||||
|
||||
# Give engines time to pause their processing loops before exiting.
|
||||
sleep(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
|
||||
dp_size = args.dp_size
|
||||
tp_size = args.tp_size
|
||||
node_size = args.node_size
|
||||
node_rank = args.node_rank
|
||||
|
||||
if node_size == 1:
|
||||
dp_master_ip = "127.0.0.1"
|
||||
dp_master_port = get_open_port()
|
||||
else:
|
||||
dp_master_ip = args.master_addr
|
||||
dp_master_port = args.master_port
|
||||
|
||||
assert dp_size % node_size == 0, "dp_size should be divisible by node_size"
|
||||
dp_per_node = dp_size // node_size
|
||||
|
||||
from multiprocessing import Process
|
||||
|
||||
if current_platform.is_rocm():
|
||||
from multiprocessing import set_start_method
|
||||
|
||||
set_start_method("spawn", force=True)
|
||||
|
||||
procs = []
|
||||
for local_dp_rank, global_dp_rank in enumerate(
|
||||
range(node_rank * dp_per_node, (node_rank + 1) * dp_per_node)
|
||||
):
|
||||
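# For example, with dp_size=4 and node_size=2, node 1 launches global DP
# ranks 2 and 3 as local ranks 0 and 1.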
proc = Process(
|
||||
target=main,
|
||||
args=(
|
||||
args.model,
|
||||
dp_size,
|
||||
local_dp_rank,
|
||||
global_dp_rank,
|
||||
dp_master_ip,
|
||||
dp_master_port,
|
||||
tp_size,
|
||||
args.enforce_eager,
|
||||
args.enable_expert_parallel,
|
||||
args.trust_remote_code,
|
||||
args.max_num_seqs,
|
||||
args.max_model_len,
|
||||
args.compilation_config,
|
||||
args.gpu_memory_utilization,
|
||||
args.enable_dbo,
|
||||
args.quantization,
|
||||
),
|
||||
)
|
||||
proc.start()
|
||||
procs.append(proc)
|
||||
exit_code = 0
|
||||
for proc in procs:
|
||||
proc.join(timeout=args.timeout)
|
||||
if proc.exitcode is None:
|
||||
print(f"Killing process {proc.pid} that didn't stop within 5 minutes.")
|
||||
proc.kill()
|
||||
exit_code = 1
|
||||
elif proc.exitcode:
|
||||
exit_code = proc.exitcode
|
||||
|
||||
exit(exit_code)
|
||||
@@ -0,0 +1,10 @@
# Disaggregated Prefill V1

This example contains scripts that demonstrate disaggregated prefill in the offline inference setting of vLLM.

## Files

- `run.sh` - A helper script that runs `prefill_example.py` and `decode_example.py` sequentially (see the usage sketch below).
    - Make sure you are in the `examples/offline_inference/disaggregated-prefill-v1` directory before running `run.sh`.
- `prefill_example.py` - A script which performs prefill only, saving the KV state to the `local_storage` directory and the prompts to `output.txt`.
- `decode_example.py` - A script which performs decode only, loading the KV state from the `local_storage` directory and the prompts from `output.txt`.
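
For example, a typical end-to-end run looks like this (starting from the repository root):

```bash
cd examples/offline_inference/disaggregated-prefill-v1
bash run.sh
```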
@@ -0,0 +1,51 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.config import KVTransferConfig
|
||||
|
||||
|
||||
def read_prompts():
|
||||
"""Read prompts from output.txt"""
|
||||
prompts = []
|
||||
try:
|
||||
with open("output.txt") as f:
|
||||
for line in f:
|
||||
prompts.append(line.strip())
|
||||
print(f"Loaded {len(prompts)} prompts from output.txt")
|
||||
return prompts
|
||||
except FileNotFoundError:
|
||||
print("Error: output.txt file not found")
|
||||
exit(-1)
|
||||
|
||||
|
||||
def main():
|
||||
prompts = read_prompts()
|
||||
sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10)
|
||||
|
||||
llm = LLM(
|
||||
model="meta-llama/Llama-3.2-1B-Instruct",
|
||||
enforce_eager=True,
|
||||
gpu_memory_utilization=0.8,
|
||||
max_num_batched_tokens=64,
|
||||
max_num_seqs=16,
|
||||
kv_transfer_config=KVTransferConfig(
|
||||
kv_connector="ExampleConnector",
|
||||
kv_role="kv_both",
|
||||
kv_connector_extra_config={"shared_storage_path": "local_storage"},
|
||||
),
|
||||
) # , max_model_len=2048, max_num_batched_tokens=2048)
|
||||
|
||||
# 1ST generation (prefill instance)
|
||||
outputs = llm.generate(prompts, sampling_params)
|
||||
|
||||
print("-" * 30)
|
||||
for output in outputs:
|
||||
prompt = output.prompt
|
||||
generated_text = output.outputs[0].text
|
||||
print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
|
||||
print("-" * 30)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,58 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.config import KVTransferConfig
|
||||
|
||||
|
||||
def read_prompts():
|
||||
context = "Hi " * 1000
|
||||
context2 = "Hey " * 500
|
||||
return [
|
||||
context + "Hello, my name is",
|
||||
context + "The capital of France is",
|
||||
context2 + "Your name is",
|
||||
context2 + "The capital of China is",
|
||||
]
|
||||
|
||||
|
||||
def main():
|
||||
prompts = read_prompts()
|
||||
|
||||
sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=1)
|
||||
|
||||
llm = LLM(
|
||||
model="meta-llama/Llama-3.2-1B-Instruct",
|
||||
enforce_eager=True,
|
||||
gpu_memory_utilization=0.8,
|
||||
kv_transfer_config=KVTransferConfig(
|
||||
kv_connector="ExampleConnector",
|
||||
kv_role="kv_both",
|
||||
kv_connector_extra_config={"shared_storage_path": "local_storage"},
|
||||
),
|
||||
) # , max_model_len=2048, max_num_batched_tokens=2048)
|
||||
|
||||
# 1ST generation (prefill instance)
|
||||
outputs = llm.generate(
|
||||
prompts,
|
||||
sampling_params,
|
||||
)
|
||||
|
||||
new_prompts = []
|
||||
print("-" * 30)
|
||||
for output in outputs:
|
||||
prompt = output.prompt
|
||||
generated_text = output.outputs[0].text
|
||||
new_prompts.append(prompt + generated_text)
|
||||
print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
|
||||
print("-" * 30)
|
||||
|
||||
# Write new_prompts to output.txt
|
||||
with open("output.txt", "w") as f:
|
||||
for prompt in new_prompts:
|
||||
f.write(prompt + "\n")
|
||||
print(f"Saved {len(new_prompts)} prompts to output.txt")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
11
examples/offline_inference/disaggregated-prefill-v1/run.sh
Normal file
11
examples/offline_inference/disaggregated-prefill-v1/run.sh
Normal file
@@ -0,0 +1,11 @@
|
||||
rm -rf local_storage/
|
||||
|
||||
if [ -f "output.txt" ]; then
|
||||
rm output.txt
|
||||
fi
|
||||
|
||||
# The directory of current script
|
||||
SCRIPT_DIR=$(dirname "$(readlink -f "$0")")
|
||||
|
||||
VLLM_ENABLE_V1_MULTIPROCESSING=0 CUDA_VISIBLE_DEVICES=0 python3 "$SCRIPT_DIR/prefill_example.py"
|
||||
VLLM_ENABLE_V1_MULTIPROCESSING=0 CUDA_VISIBLE_DEVICES=0 python3 "$SCRIPT_DIR/decode_example.py"
|
||||
127
examples/offline_inference/disaggregated_prefill.py
Normal file
127
examples/offline_inference/disaggregated_prefill.py
Normal file
@@ -0,0 +1,127 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
This file demonstrates the example usage of disaggregated prefilling
|
||||
We will launch 2 vllm instances (GPU 0 for prefill and GPU 1 for decode),
|
||||
and then transfer the KV cache between them.
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
from multiprocessing import Event, Process
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.config import KVTransferConfig
|
||||
|
||||
|
||||
def run_prefill(prefill_done):
|
||||
# We use GPU 0 for prefill node.
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
|
||||
|
||||
# The prefill node receives two requests, while the decode node receives
|
||||
# three requests. So the decode node will only receive the KV Cache for
|
||||
# requests 1 and 3. The decode node will use the KV Cache of requests 1
|
||||
# and 3 and do prefilling on request 2.
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
"Hi, your name is",
|
||||
# The decode node will actually "prefill" this request.
|
||||
"Tell me a very long story",
|
||||
]
|
||||
sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=1)
|
||||
|
||||
# Using P2pNcclConnector to transmit KV caches between vLLM instances.
|
||||
# This instance is the prefill node (kv_producer, rank 0).
|
||||
# The number of parallel instances for KV cache transfer is set to 2,
|
||||
# as required for P2pNcclConnector.
|
||||
ktc = KVTransferConfig(
|
||||
kv_connector="P2pNcclConnector",
|
||||
kv_role="kv_producer",
|
||||
kv_rank=0,
|
||||
kv_parallel_size=2,
|
||||
)
|
||||
|
||||
# Set GPU memory utilization to 0.8 for an A6000 GPU with 48GB
|
||||
# memory. You may need to adjust the value to fit your GPU.
|
||||
llm = LLM(
|
||||
model="meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
kv_transfer_config=ktc,
|
||||
max_model_len=2000,
|
||||
gpu_memory_utilization=0.8,
|
||||
)
|
||||
|
||||
llm.generate(prompts, sampling_params)
|
||||
print("Prefill node is finished.")
|
||||
prefill_done.set()
|
||||
|
||||
# To keep the prefill node running in case the decode node is not done;
|
||||
# otherwise, the script might exit prematurely, causing incomplete decoding.
|
||||
try:
|
||||
while True:
|
||||
time.sleep(1)
|
||||
except KeyboardInterrupt:
|
||||
print("Script stopped by user.")
|
||||
|
||||
|
||||
def run_decode(prefill_done):
|
||||
# We use GPU 1 for decode node.
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
|
||||
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
"Hi, your name is",
|
||||
"Tell me a very long story",
|
||||
]
|
||||
sampling_params = SamplingParams(temperature=0, top_p=0.95)
|
||||
|
||||
# Using P2pNcclConnector to transmit KV caches between vLLM instances.
|
||||
# This instance is the decode node (kv_consumer, rank 1).
|
||||
# The number of parallel instances for KV cache transfer is set to 2,
|
||||
# as required for P2pNcclConnector.
|
||||
ktc = KVTransferConfig(
|
||||
kv_connector="P2pNcclConnector",
|
||||
kv_role="kv_consumer",
|
||||
kv_rank=1,
|
||||
kv_parallel_size=2,
|
||||
)
|
||||
|
||||
# Set GPU memory utilization to 0.8 for an A6000 GPU with 48GB
|
||||
# memory. You may need to adjust the value to fit your GPU.
|
||||
llm = LLM(
|
||||
model="meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
kv_transfer_config=ktc,
|
||||
max_model_len=2000,
|
||||
gpu_memory_utilization=0.8,
|
||||
)
|
||||
|
||||
# Wait for the producer to start the pipe
|
||||
print("Waiting for prefill node to finish...")
|
||||
prefill_done.wait()
|
||||
|
||||
# At this point when the prefill_done is set, the kv-cache should have been
|
||||
# transferred to this decode node, so we can start decoding.
|
||||
outputs = llm.generate(prompts, sampling_params)
|
||||
for output in outputs:
|
||||
prompt = output.prompt
|
||||
generated_text = output.outputs[0].text
|
||||
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
||||
|
||||
|
||||
def main():
|
||||
prefill_done = Event()
|
||||
prefill_process = Process(target=run_prefill, args=(prefill_done,))
|
||||
decode_process = Process(target=run_decode, args=(prefill_done,))
|
||||
|
||||
# Start prefill node
|
||||
prefill_process.start()
|
||||
|
||||
# Start decode node
|
||||
decode_process.start()
|
||||
|
||||
# Terminate the prefill node when decode is finished
|
||||
decode_process.join()
|
||||
prefill_process.terminate()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
133
examples/offline_inference/encoder_decoder_multimodal.py
Normal file
133
examples/offline_inference/encoder_decoder_multimodal.py
Normal file
@@ -0,0 +1,133 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
This example shows how to use vLLM for running offline inference with
|
||||
the explicit/implicit prompt format on enc-dec LMMs for text generation.
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import asdict
|
||||
from typing import NamedTuple
|
||||
|
||||
from vllm import LLM, EngineArgs, PromptType, SamplingParams
|
||||
from vllm.assets.audio import AudioAsset
|
||||
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||
|
||||
|
||||
class ModelRequestData(NamedTuple):
|
||||
engine_args: EngineArgs
|
||||
prompts: Sequence[PromptType]
|
||||
|
||||
|
||||
def run_whisper():
|
||||
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
|
||||
|
||||
engine_args = EngineArgs(
|
||||
model="openai/whisper-large-v3-turbo",
|
||||
max_model_len=448,
|
||||
max_num_seqs=16,
|
||||
limit_mm_per_prompt={"audio": 1},
|
||||
dtype="half",
|
||||
)
|
||||
|
||||
prompts = [
|
||||
{ # Test implicit prompt
|
||||
"prompt": "<|startoftranscript|>",
|
||||
"multi_modal_data": {
|
||||
"audio": AudioAsset("mary_had_lamb").audio_and_sample_rate,
|
||||
},
|
||||
},
|
||||
{ # Test explicit encoder/decoder prompt
|
||||
"encoder_prompt": {
|
||||
"prompt": "",
|
||||
"multi_modal_data": {
|
||||
"audio": AudioAsset("winning_call").audio_and_sample_rate,
|
||||
},
|
||||
},
|
||||
"decoder_prompt": "<|startoftranscript|>",
|
||||
},
|
||||
]
|
||||
|
||||
return ModelRequestData(
|
||||
engine_args=engine_args,
|
||||
prompts=prompts,
|
||||
)
|
||||
|
||||
|
||||
model_example_map = {
|
||||
"whisper": run_whisper,
|
||||
}
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = FlexibleArgumentParser(
|
||||
description="Demo on using vLLM for offline inference with "
|
||||
"vision language models for text generation"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model-type",
|
||||
"-m",
|
||||
type=str,
|
||||
default="whisper",
|
||||
choices=model_example_map.keys(),
|
||||
help='Huggingface "model_type".',
|
||||
)
|
||||
parser.add_argument(
|
||||
"--seed",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Set the seed when initializing `vllm.LLM`.",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main(args):
|
||||
model = args.model_type
|
||||
if model not in model_example_map:
|
||||
raise ValueError(f"Model type {model} is not supported.")
|
||||
|
||||
req_data = model_example_map[model]()
|
||||
|
||||
# Disable other modalities to save memory
|
||||
default_limits = {"image": 0, "video": 0, "audio": 0}
|
||||
req_data.engine_args.limit_mm_per_prompt = default_limits | dict(
|
||||
req_data.engine_args.limit_mm_per_prompt or {}
|
||||
)
|
||||
|
||||
engine_args = asdict(req_data.engine_args) | {"seed": args.seed}
|
||||
llm = LLM(**engine_args)
|
||||
|
||||
prompts = req_data.prompts
|
||||
|
||||
# Create a sampling params object.
|
||||
sampling_params = SamplingParams(
|
||||
temperature=0,
|
||||
top_p=1.0,
|
||||
max_tokens=64,
|
||||
skip_special_tokens=False,
|
||||
)
|
||||
|
||||
start = time.time()
|
||||
|
||||
# Generate output tokens from the prompts. The output is a list of
|
||||
# RequestOutput objects that contain the prompt, generated
|
||||
# text, and other information.
|
||||
outputs = llm.generate(prompts, sampling_params)
|
||||
|
||||
# Print the outputs.
|
||||
for output in outputs:
|
||||
prompt = output.prompt
|
||||
generated_text = output.outputs[0].text
|
||||
print(f"Decoder prompt: {prompt!r}, Generated text: {generated_text!r}")
|
||||
|
||||
duration = time.time() - start
|
||||
|
||||
print("Duration:", duration)
|
||||
print("RPS:", len(prompts) / duration)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
main(args)
|
||||
@@ -0,0 +1,30 @@
# KV Load Failure Recovery Test

This example builds upon the `disaggregated-prefill-v1` example in `examples/offline_inference`.

It demonstrates vLLM's ability to recover from KV load failures in both synchronous and asynchronous loading modes. The goal is to verify that vLLM correctly identifies invalid KV blocks, reschedules the affected requests, and ensures successful and consistent output.

## Files

- `prefill_example.py` – performs the prefill stage and saves KV data (same as in `disaggregated-prefill-v1`).
- `decode_example.py` – performs the decode stage. Accepts:
    - `--simulate-failure`: simulates KV load failure using a custom connector.
    - `--async-load`: enables asynchronous KV loading mode.
- `load_recovery_example_connector.py` – defines `LoadRecoveryExampleConnector`, a subclass of `ExampleConnector`, that simulates missing or corrupted external KV blocks by failing to load blocks for the first decode request.
- `run.sh` – orchestrates the test: runs the prefill stage, then three decode stages:
    1. Normal decode (baseline).
    2. Decode with simulated sync KV load failure.
    3. Decode with simulated async KV load failure.

Finally, it compares the output of the baseline with the recovered outputs to verify correctness.

## How It Works

- The test dynamically loads `LoadRecoveryExampleConnector` via `KVTransferConfig.kv_connector_module_path`, enabling controlled simulation of load failures without modifying the original connector (a configuration sketch follows this list).
- The decode stages that simulate failure are expected to trigger recovery logic in vLLM, resulting in the same output as the baseline decode.
- If recovery fails, the script prints a unified diff of the output mismatch and exits with an error.
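
The snippet below is a minimal sketch of how that dynamic loading is wired up; it mirrors the `KVTransferConfig` used in `decode_example.py` and is illustrative rather than a complete script.

```python
from vllm.config import KVTransferConfig

# Illustrative sketch; decode_example.py contains the full configuration,
# including the async_load flag.
ktc = KVTransferConfig(
    kv_connector="LoadRecoveryExampleConnector",
    kv_role="kv_both",
    kv_connector_extra_config={"shared_storage_path": "local_storage"},
    # Module (importable from the working directory) defining the custom connector.
    kv_connector_module_path="load_recovery_example_connector",
)
```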

## Usage

```bash
./run.sh
```
@@ -0,0 +1,85 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import argparse
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.config import KVTransferConfig
|
||||
|
||||
|
||||
def read_prompts():
|
||||
"""Read prompts from prefill_output.txt"""
|
||||
prompts = []
|
||||
try:
|
||||
with open("prefill_output.txt") as f:
|
||||
for line in f:
|
||||
prompts.append(line.strip())
|
||||
print(f"Loaded {len(prompts)} prompts from prefill_output.txt")
|
||||
return prompts
|
||||
except FileNotFoundError:
|
||||
print("Error: prefill_output.txt file not found")
|
||||
exit(-1)
|
||||
|
||||
|
||||
def main():
|
||||
prompts = read_prompts()
|
||||
sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10)
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--simulate-failure", action="store_true", help="Simulate KV load failure."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--async-load", action="store_true", help="Simulate async KV load"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.simulate_failure:
|
||||
ktc = KVTransferConfig(
|
||||
kv_connector="LoadRecoveryExampleConnector",
|
||||
kv_role="kv_both",
|
||||
kv_connector_extra_config={
|
||||
"shared_storage_path": "local_storage",
|
||||
"async_load": args.async_load,
|
||||
},
|
||||
kv_connector_module_path="load_recovery_example_connector",
|
||||
)
|
||||
out_file = (
|
||||
"async_decode_recovered_output.txt"
|
||||
if args.async_load
|
||||
else "sync_decode_recovered_output.txt"
|
||||
)
|
||||
else:
|
||||
ktc = KVTransferConfig(
|
||||
kv_connector="ExampleConnector",
|
||||
kv_role="kv_both",
|
||||
kv_connector_extra_config={
|
||||
"shared_storage_path": "local_storage",
|
||||
},
|
||||
)
|
||||
out_file = "decode_output.txt"
|
||||
|
||||
llm = LLM(
|
||||
model="meta-llama/Llama-3.2-1B-Instruct",
|
||||
enforce_eager=True,
|
||||
gpu_memory_utilization=0.8,
|
||||
max_num_batched_tokens=64,
|
||||
max_num_seqs=16,
|
||||
kv_transfer_config=ktc,
|
||||
)
|
||||
|
||||
outputs = llm.generate(prompts, sampling_params)
|
||||
|
||||
sep_str = "-" * 30
|
||||
with open(out_file, "w", encoding="utf-8") as f:
|
||||
for output in outputs:
|
||||
prompt = output.prompt
|
||||
generated_text = output.outputs[0].text
|
||||
out_str = f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}"
|
||||
print(out_str)
|
||||
print(sep_str)
|
||||
f.write(out_str)
|
||||
f.write(sep_str)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,145 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# ruff: noqa: E501
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||
KVConnectorMetadata,
|
||||
KVConnectorRole,
|
||||
)
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.example_connector import (
|
||||
ExampleConnector,
|
||||
ExampleConnectorMetadata,
|
||||
)
|
||||
from vllm.forward_context import ForwardContext
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||
from vllm.v1.request import Request
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
|
||||
logger = logging.getLogger()
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
|
||||
@dataclass
|
||||
class LoadRecoveryExampleConnectorMetadata(ExampleConnectorMetadata):
|
||||
req_to_block_ids: dict[str, set[int]] = field(default_factory=dict)
|
||||
|
||||
@classmethod
|
||||
def from_base(cls, base: ExampleConnectorMetadata):
|
||||
return cls(requests=base.requests)
|
||||
|
||||
|
||||
class LoadRecoveryExampleConnector(ExampleConnector):
|
||||
def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
|
||||
super().__init__(vllm_config=vllm_config, role=role)
|
||||
self._async_load = vllm_config.kv_transfer_config.get_from_extra_config(
|
||||
"async_load", False
|
||||
)
|
||||
self._invalid_block_ids: set = None
|
||||
self._seen_requests: set = set()
|
||||
self._req_to_block_ids: dict[str, list[int]] = dict()
|
||||
|
||||
def bind_connector_metadata(self, connector_metadata: KVConnectorMetadata) -> None:
|
||||
assert isinstance(connector_metadata, LoadRecoveryExampleConnectorMetadata)
|
||||
index, failed_request = next(
|
||||
(
|
||||
(i, x)
|
||||
for i, x in enumerate(connector_metadata.requests)
|
||||
if not x.is_store
|
||||
),
|
||||
(None, None),
|
||||
)
|
||||
if index is not None:
|
||||
del connector_metadata.requests[index]
|
||||
self._invalid_block_ids = set(
|
||||
(
|
||||
failed_request.slot_mapping[:: self._block_size] // self._block_size
|
||||
).tolist()
|
||||
)
|
||||
logger.info(
|
||||
"Simulating failure to load all KV blocks for the "
|
||||
"first load request. Total blocks: %d",
|
||||
len(self._invalid_block_ids),
|
||||
)
|
||||
super().bind_connector_metadata(connector_metadata)
|
||||
|
||||
def clear_connector_metadata(self) -> None:
|
||||
self._invalid_block_ids = None
|
||||
super().clear_connector_metadata()
|
||||
|
||||
def start_load_kv(self, forward_context: ForwardContext, **kwargs) -> None:
|
||||
if self._async_load and forward_context.attn_metadata is None:
|
||||
# Bypass sanity check in super().start_load_kv
|
||||
forward_context.attn_metadata = "None"
|
||||
|
||||
super().start_load_kv(forward_context, **kwargs)
|
||||
|
||||
def get_finished(
|
||||
self, finished_req_ids: set[str]
|
||||
) -> tuple[set[str] | None, set[str] | None]:
|
||||
if self._async_load:
|
||||
meta = self._get_connector_metadata()
|
||||
assert isinstance(meta, LoadRecoveryExampleConnectorMetadata)
|
||||
if meta.req_to_block_ids:
|
||||
return None, set(meta.req_to_block_ids)
|
||||
|
||||
return None, None
|
||||
|
||||
def get_block_ids_with_load_errors(self) -> set[int]:
|
||||
return self._invalid_block_ids
|
||||
|
||||
def get_num_new_matched_tokens(
|
||||
self,
|
||||
request: Request,
|
||||
num_computed_tokens: int,
|
||||
) -> tuple[int, bool]:
|
||||
if request.request_id in self._seen_requests:
|
||||
return 0, False
|
||||
|
||||
self._seen_requests.add(request.request_id)
|
||||
|
||||
num_tokens, _ = super().get_num_new_matched_tokens(request, num_computed_tokens)
|
||||
return num_tokens, self._async_load and num_tokens > 0
|
||||
|
||||
def update_state_after_alloc(
|
||||
self, request: Request, blocks: KVCacheBlocks, num_external_tokens: int
|
||||
):
|
||||
"""
|
||||
Update KVConnector state after block allocation.
|
||||
|
||||
If blocks were allocated, add to _requests_need_load,
|
||||
such that we load the KVs in the next forward pass.
|
||||
"""
|
||||
super().update_state_after_alloc(request, blocks, num_external_tokens)
|
||||
|
||||
if num_external_tokens > 0:
|
||||
self._req_to_block_ids[request.request_id] = blocks.get_block_ids()[0]
|
||||
|
||||
def build_connector_meta(
|
||||
self,
|
||||
scheduler_output: "SchedulerOutput",
|
||||
) -> KVConnectorMetadata:
|
||||
if not self._async_load:
|
||||
base = super().build_connector_meta(scheduler_output)
|
||||
meta = LoadRecoveryExampleConnectorMetadata.from_base(base)
|
||||
else:
|
||||
meta = LoadRecoveryExampleConnectorMetadata()
|
||||
if self._requests_need_load:
|
||||
for req_id, request in self._requests_need_load.items():
|
||||
meta.add_request(
|
||||
token_ids=request.prompt_token_ids,
|
||||
block_ids=self._req_to_block_ids[req_id],
|
||||
block_size=self._block_size,
|
||||
is_store=False,
|
||||
mm_hashes=[],
|
||||
)
|
||||
# Clear state
|
||||
self._requests_need_load.clear()
|
||||
meta.req_to_block_ids = self._req_to_block_ids
|
||||
self._req_to_block_ids = dict()
|
||||
return meta
|
||||
@@ -0,0 +1,58 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.config import KVTransferConfig
|
||||
|
||||
|
||||
def read_prompts():
|
||||
context = "Hi " * 1000
|
||||
context2 = "Hey " * 500
|
||||
return [
|
||||
context + "Hello, my name is",
|
||||
context + "The capital of France is",
|
||||
context2 + "Your name is",
|
||||
context2 + "The capital of China is",
|
||||
]
|
||||
|
||||
|
||||
def main():
|
||||
prompts = read_prompts()
|
||||
|
||||
sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=1)
|
||||
|
||||
llm = LLM(
|
||||
model="meta-llama/Llama-3.2-1B-Instruct",
|
||||
enforce_eager=True,
|
||||
gpu_memory_utilization=0.8,
|
||||
kv_transfer_config=KVTransferConfig(
|
||||
kv_connector="ExampleConnector",
|
||||
kv_role="kv_both",
|
||||
kv_connector_extra_config={"shared_storage_path": "local_storage"},
|
||||
),
|
||||
) # , max_model_len=2048, max_num_batched_tokens=2048)
|
||||
|
||||
# 1ST generation (prefill instance)
|
||||
outputs = llm.generate(
|
||||
prompts,
|
||||
sampling_params,
|
||||
)
|
||||
|
||||
new_prompts = []
|
||||
print("-" * 30)
|
||||
for output in outputs:
|
||||
prompt = output.prompt
|
||||
generated_text = output.outputs[0].text
|
||||
new_prompts.append(prompt + generated_text)
|
||||
print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
|
||||
print("-" * 30)
|
||||
|
||||
# Write new_prompts to prefill_output.txt
|
||||
with open("prefill_output.txt", "w") as f:
|
||||
for prompt in new_prompts:
|
||||
f.write(prompt + "\n")
|
||||
print(f"Saved {len(new_prompts)} prompts to prefill_output.txt")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
33
examples/offline_inference/kv_load_failure_recovery/run.sh
Executable file
33
examples/offline_inference/kv_load_failure_recovery/run.sh
Executable file
@@ -0,0 +1,33 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Constants
|
||||
SHARED_STORAGE_DIR="local_storage"
|
||||
PREFILL_OUTPUT="prefill_output.txt"
|
||||
DECODE_OUTPUT="decode_output.txt"
|
||||
SYNC_DECODE_RECOVERED_OUTPUT="sync_decode_recovered_output.txt"
|
||||
ASYNC_DECODE_RECOVERED_OUTPUT="async_decode_recovered_output.txt"
|
||||
|
||||
# Cleanup
|
||||
rm -rf "$SHARED_STORAGE_DIR"
|
||||
rm -f "$PREFILL_OUTPUT" "$DECODE_OUTPUT" "$SYNC_DECODE_RECOVERED_OUTPUT" "$ASYNC_DECODE_RECOVERED_OUTPUT"
|
||||
|
||||
# Run inference examples
|
||||
VLLM_ENABLE_V1_MULTIPROCESSING=0 CUDA_VISIBLE_DEVICES=0 python3 prefill_example.py
|
||||
VLLM_ENABLE_V1_MULTIPROCESSING=0 CUDA_VISIBLE_DEVICES=0 python3 decode_example.py
|
||||
VLLM_ENABLE_V1_MULTIPROCESSING=0 CUDA_VISIBLE_DEVICES=0 python3 decode_example.py --simulate-failure
|
||||
VLLM_ENABLE_V1_MULTIPROCESSING=0 CUDA_VISIBLE_DEVICES=0 python3 decode_example.py --simulate-failure --async-load
|
||||
|
||||
# Compare outputs
|
||||
if ! cmp -s "$DECODE_OUTPUT" "$SYNC_DECODE_RECOVERED_OUTPUT"; then
|
||||
echo "❌ Outputs differ: sync recovery failed."
|
||||
diff -u "$DECODE_OUTPUT" "$SYNC_DECODE_RECOVERED_OUTPUT"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if ! cmp -s "$DECODE_OUTPUT" "$ASYNC_DECODE_RECOVERED_OUTPUT"; then
|
||||
echo "❌ Outputs differ: async recovery failed."
|
||||
diff -u "$DECODE_OUTPUT" "$ASYNC_DECODE_RECOVERED_OUTPUT"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "✅ Outputs match: recovery successful."
|
||||
74
examples/offline_inference/llm_engine_example.py
Normal file
74
examples/offline_inference/llm_engine_example.py
Normal file
@@ -0,0 +1,74 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
This file demonstrates using the `LLMEngine`
|
||||
for processing prompts with various sampling parameters.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
|
||||
from vllm import EngineArgs, LLMEngine, RequestOutput, SamplingParams
|
||||
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||
|
||||
|
||||
def create_test_prompts() -> list[tuple[str, SamplingParams]]:
|
||||
"""Create a list of test prompts with their sampling parameters."""
|
||||
return [
|
||||
(
|
||||
"A robot may not injure a human being",
|
||||
SamplingParams(temperature=0.0, logprobs=1, prompt_logprobs=1),
|
||||
),
|
||||
(
|
||||
"To be or not to be,",
|
||||
SamplingParams(temperature=0.8, top_k=5, presence_penalty=0.2),
|
||||
),
|
||||
(
|
||||
"What is the meaning of life?",
|
||||
SamplingParams(n=2, temperature=0.8, top_p=0.95, frequency_penalty=0.1),
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def process_requests(engine: LLMEngine, test_prompts: list[tuple[str, SamplingParams]]):
|
||||
"""Continuously process a list of prompts and handle the outputs."""
|
||||
request_id = 0
|
||||
|
||||
print("-" * 50)
|
||||
while test_prompts or engine.has_unfinished_requests():
|
||||
if test_prompts:
|
||||
prompt, sampling_params = test_prompts.pop(0)
|
||||
engine.add_request(str(request_id), prompt, sampling_params)
|
||||
request_id += 1
|
||||
|
||||
request_outputs: list[RequestOutput] = engine.step()
|
||||
|
||||
for request_output in request_outputs:
|
||||
if request_output.finished:
|
||||
print(request_output)
|
||||
print("-" * 50)
|
||||
|
||||
|
||||
def initialize_engine(args: argparse.Namespace) -> LLMEngine:
|
||||
"""Initialize the LLMEngine from the command line arguments."""
|
||||
engine_args = EngineArgs.from_cli_args(args)
|
||||
return LLMEngine.from_engine_args(engine_args)
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = FlexibleArgumentParser(
|
||||
description="Demo on using the LLMEngine class directly"
|
||||
)
|
||||
parser = EngineArgs.add_cli_args(parser)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main(args: argparse.Namespace):
|
||||
"""Main function that sets up and runs the prompt processing."""
|
||||
engine = initialize_engine(args)
|
||||
test_prompts = create_test_prompts()
|
||||
process_requests(engine, test_prompts)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
main(args)
|
||||
98
examples/offline_inference/llm_engine_reset_kv.py
Normal file
98
examples/offline_inference/llm_engine_reset_kv.py
Normal file
@@ -0,0 +1,98 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
This file demonstrates preempting requests when using the `LLMEngine`
|
||||
for processing prompts with various sampling parameters.
|
||||
"""
|
||||
|
||||
import argparse
import os
|
||||
|
||||
from vllm import EngineArgs, LLMEngine, RequestOutput, SamplingParams
|
||||
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||
|
||||
|
||||
def create_test_prompts() -> list[tuple[str, SamplingParams]]:
|
||||
"""Create a list of test prompts with their sampling parameters."""
|
||||
return [
|
||||
(
|
||||
"A robot may not injure a human being " * 50,
|
||||
SamplingParams(
|
||||
temperature=0.0, logprobs=1, prompt_logprobs=1, max_tokens=16
|
||||
),
|
||||
),
|
||||
(
|
||||
"A robot may not injure a human being " * 50,
|
||||
SamplingParams(
|
||||
temperature=0.0, logprobs=1, prompt_logprobs=1, max_tokens=16
|
||||
),
|
||||
),
|
||||
(
|
||||
"To be or not to be,",
|
||||
SamplingParams(
|
||||
temperature=0.8, top_k=5, presence_penalty=0.2, max_tokens=128
|
||||
),
|
||||
),
|
||||
(
|
||||
"What is the meaning of life?",
|
||||
SamplingParams(
|
||||
n=2, temperature=0.8, top_p=0.95, frequency_penalty=0.1, max_tokens=128
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def process_requests(engine: LLMEngine, test_prompts: list[tuple[str, SamplingParams]]):
|
||||
"""Continuously process a list of prompts and handle the outputs."""
|
||||
request_id = 0
|
||||
|
||||
print("-" * 50)
|
||||
step_id = 0
|
||||
while test_prompts or engine.has_unfinished_requests():
|
||||
print("-" * 50)
|
||||
|
||||
print(f"Step {step_id} (pid={os.getpid()})")
|
||||
|
||||
if test_prompts:
|
||||
prompt, sampling_params = test_prompts.pop(0)
|
||||
engine.add_request(str(request_id), prompt, sampling_params)
|
||||
request_id += 1
|
||||
|
||||
if step_id == 10:
|
||||
print(f"Resetting prefix cache at {step_id}")
|
||||
engine.reset_prefix_cache(reset_running_requests=True)
|
||||
|
||||
request_outputs: list[RequestOutput] = engine.step()
|
||||
|
||||
for request_output in request_outputs:
|
||||
if request_output.finished:
|
||||
print("-" * 50)
|
||||
print(request_output)
|
||||
print("-" * 50)
|
||||
step_id += 1
|
||||
|
||||
|
||||
def initialize_engine(args: argparse.Namespace) -> LLMEngine:
|
||||
"""Initialize the LLMEngine from the command line arguments."""
|
||||
engine_args = EngineArgs.from_cli_args(args)
|
||||
return LLMEngine.from_engine_args(engine_args)
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = FlexibleArgumentParser(
|
||||
description="Demo on using the LLMEngine class directly"
|
||||
)
|
||||
parser = EngineArgs.add_cli_args(parser)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main(args: argparse.Namespace):
|
||||
"""Main function that sets up and runs the prompt processing."""
|
||||
engine = initialize_engine(args)
|
||||
test_prompts = create_test_prompts()
|
||||
process_requests(engine, test_prompts)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
main(args)
|
||||
94
examples/offline_inference/load_sharded_state.py
Normal file
94
examples/offline_inference/load_sharded_state.py
Normal file
@@ -0,0 +1,94 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Validates the loading of a model saved with the sharded_state format.
|
||||
This script demonstrates how to load a model that was previously saved
|
||||
using save_sharded_state.py and validates it by running inference.
|
||||
Example usage:
|
||||
(First you need to save a model in the sharded_state format)
|
||||
|
||||
python save_sharded_state.py \
|
||||
--model /path/to/load \
|
||||
--quantization deepspeedfp \
|
||||
--tensor-parallel-size 8 \
|
||||
--output /path/to/save/sharded/model
|
||||
|
||||
python load_sharded_state.py \
|
||||
--model /path/to/saved/sharded/model \
|
||||
--load-format sharded_state \
|
||||
--quantization deepspeedfp \
|
||||
--tensor-parallel-size 8 \
|
||||
--prompt "Hello, my name is" \
|
||||
--max-tokens 50
|
||||
"""
|
||||
|
||||
import dataclasses
|
||||
|
||||
from vllm import LLM, EngineArgs, SamplingParams
|
||||
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = FlexibleArgumentParser()
|
||||
# Add engine arguments
|
||||
EngineArgs.add_cli_args(parser)
|
||||
|
||||
# Override default load_format for clarity
|
||||
parser.set_defaults(load_format="sharded_state")
|
||||
|
||||
# Add validation arguments
|
||||
parser.add_argument(
|
||||
"--prompt", type=str, default="Hello, world!", help="Prompt for validation"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-tokens",
|
||||
type=int,
|
||||
default=100,
|
||||
help="Maximum number of tokens to generate",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--temperature", type=float, default=0.7, help="Sampling temperature"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--top-p", type=float, default=1.0, help="Top-p sampling parameter"
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
engine_args = EngineArgs.from_cli_args(args)
|
||||
|
||||
print(
|
||||
f"Loading model from {engine_args.model} using format {engine_args.load_format}"
|
||||
)
|
||||
print(f"Tensor parallel size: {engine_args.tensor_parallel_size}")
|
||||
|
||||
# Load the model using engine args
|
||||
llm = LLM(**dataclasses.asdict(engine_args))
|
||||
|
||||
# Prepare sampling parameters
|
||||
sampling_params = SamplingParams(
|
||||
temperature=args.temperature,
|
||||
top_p=args.top_p,
|
||||
max_tokens=args.max_tokens,
|
||||
)
|
||||
|
||||
print("\nRunning inference:")
|
||||
print(f"Prompt: {args.prompt}")
|
||||
|
||||
# Generate completion
|
||||
outputs = llm.generate(args.prompt, sampling_params)
|
||||
|
||||
# Display generated text
|
||||
print("\nGenerated outputs:")
|
||||
for output in outputs:
|
||||
generated_text = output.outputs[0].text
|
||||
print("-" * 50)
|
||||
print(f"Full output: {args.prompt}{generated_text}")
|
||||
print("-" * 50)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
142
examples/offline_inference/logits_processor/custom.py
Normal file
142
examples/offline_inference/logits_processor/custom.py
Normal file
@@ -0,0 +1,142 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
"""This example demonstrates instantiating vLLM with a custom logits processor
|
||||
class object.
|
||||
|
||||
For a basic example of implementing a custom logits processor, see
|
||||
the `DummyLogitsProcessor` implementation in `vllm/test_utils.py`.
|
||||
|
||||
For testing purposes, a dummy logits processor is employed which, if
|
||||
`target_token` is passed as a keyword argument to `SamplingParams.extra_args`,
|
||||
will mask out all tokens except `target_token`.
|
||||
|
||||
A batch is constructed with `temperature=0.0` and 50% of requests specifying
|
||||
`target_token`, and for these requests - and *only* these requests - we
|
||||
expect the `target_token` to be decoded in each step, yielding an output
|
||||
similar to that shown below:
|
||||
|
||||
Generated Outputs:
|
||||
------------------------------------------------------------
|
||||
Prompt: 'Hello, my name is'
|
||||
Output: " ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '"
|
||||
------------------------------------------------------------
|
||||
Prompt: 'The president of the United States is'
|
||||
Output: " not a racist. He is a racist.\nHe's a racist because he"
|
||||
------------------------------------------------------------
|
||||
Prompt: 'The capital of France is'
|
||||
Output: ' also also also also also also also also also also also also also
|
||||
also also also'
|
||||
------------------------------------------------------------
|
||||
Prompt: 'The future of AI is'
|
||||
Output: ' in the hands of the people.\n\nThe future of AI is in the'
|
||||
------------------------------------------------------------
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.v1.sample.logits_processor import (
|
||||
BatchUpdate,
|
||||
LogitsProcessor,
|
||||
)
|
||||
from vllm.v1.sample.logits_processor.builtin import process_dict_updates
|
||||
|
||||
|
||||
# Hypothetical custom logits processor
|
||||
class DummyLogitsProcessor(LogitsProcessor):
|
||||
"""Fake logit processor to support unit testing and examples"""
|
||||
|
||||
@classmethod
|
||||
def validate_params(cls, params: SamplingParams):
|
||||
target_token: Any | None = params.extra_args and params.extra_args.get(
|
||||
"target_token"
|
||||
)
|
||||
if target_token is not None and not isinstance(target_token, int):
|
||||
raise ValueError(
|
||||
f"target_token value {target_token} {type(target_token)} is not int"
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self, vllm_config: VllmConfig, device: torch.device, is_pin_memory: bool
|
||||
):
|
||||
self.req_info: dict[int, int] = {}
|
||||
|
||||
def is_argmax_invariant(self) -> bool:
|
||||
return False
|
||||
|
||||
def update_state(self, batch_update: BatchUpdate | None):
|
||||
def extract_extra_arg(params: SamplingParams) -> int | None:
|
||||
self.validate_params(params)
|
||||
return params.extra_args and params.extra_args.get("target_token")
|
||||
|
||||
process_dict_updates(
|
||||
self.req_info,
|
||||
batch_update,
|
||||
# This function returns the LP's per-request state based on the
|
||||
# request details, or None if this LP does not apply to the
|
||||
# request.
|
||||
lambda params, _, __: extract_extra_arg(params),
|
||||
)
|
||||
|
||||
def apply(self, logits: torch.Tensor) -> torch.Tensor:
|
||||
if not self.req_info:
|
||||
return logits
|
||||
|
||||
# Save target values before modification
|
||||
cols = torch.tensor(
|
||||
list(self.req_info.values()), dtype=torch.long, device=logits.device
|
||||
)
|
||||
rows = torch.tensor(
|
||||
list(self.req_info.keys()), dtype=torch.long, device=logits.device
|
||||
)
|
||||
values_to_keep = logits[rows, cols].clone()
|
||||
|
||||
# Mask all but target tokens
|
||||
logits[rows] = float("-inf")
|
||||
logits[rows, cols] = values_to_keep
|
||||
|
||||
return logits
|
||||
|
||||
|
||||
# Sample prompts.
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
"The president of the United States is",
|
||||
"The capital of France is",
|
||||
"The future of AI is",
|
||||
]
|
||||
# Create a mixture of requests which do and don't utilize the dummy logitproc
|
||||
sampling_params_list = [
|
||||
SamplingParams(temperature=0.0, extra_args={"target_token": 128}),
|
||||
SamplingParams(temperature=0.0),
|
||||
SamplingParams(temperature=0.0, extra_args={"target_token": 67}),
|
||||
SamplingParams(temperature=0.0),
|
||||
]
|
||||
|
||||
|
||||
def main():
|
||||
# Create an LLM.
|
||||
llm = LLM(
|
||||
model="facebook/opt-125m",
|
||||
logits_processors=[DummyLogitsProcessor],
|
||||
)
|
||||
# Generate texts from the prompts.
|
||||
# The output is a list of RequestOutput objects
|
||||
# that contain the prompt, generated text, and other information.
|
||||
outputs = llm.generate(prompts, sampling_params_list)
|
||||
# Print the outputs.
|
||||
print("\nGenerated Outputs:\n" + "-" * 60)
|
||||
for output in outputs:
|
||||
prompt = output.prompt
|
||||
generated_text = output.outputs[0].text
|
||||
print(f"Prompt: {prompt!r}")
|
||||
print(f"Output: {generated_text!r}")
|
||||
print("-" * 60)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
152
examples/offline_inference/logits_processor/custom_req.py
Normal file
152
examples/offline_inference/logits_processor/custom_req.py
Normal file
@@ -0,0 +1,152 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
"""This example demonstrates wrapping a request-level logits processor to be
|
||||
compatible with vLLM's batch-level logits processing
|
||||
|
||||
For demo purposes, a dummy logits processor is employed which, if
|
||||
`target_token` is passed as a keyword argument to `SamplingParams.extra_args`,
|
||||
will mask out all tokens except `target_token`. This logits processor can be
|
||||
applied to a vector of logits associated with a single decode step for a single
|
||||
request. The logits processor cannot be applied to a request which does not
|
||||
pass in a `target_token` custom argument.
|
||||
|
||||
The request-level dummy logits processor is wrapped to create a batch-level
|
||||
logits processor, which can apply the logits processor to output logits from
|
||||
all requests in the persistent batch in a given decode step. For requests which
|
||||
do not provide a `target_token` argument, the corresponding row of `logits`
|
||||
will not be modified.
|
||||
|
||||
A batch is constructed with `temperature=0.0` and 50% of requests specifying
|
||||
`target_token`, and for these requests - and *only* these requests - we
|
||||
expect the `target_token` to be decoded in each step, yielding an output
|
||||
similar to that shown below:
|
||||
|
||||
Generated Outputs:
|
||||
------------------------------------------------------------
|
||||
Prompt: 'Hello, my name is'
|
||||
Output: " ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '"
|
||||
------------------------------------------------------------
|
||||
Prompt: 'The president of the United States is'
|
||||
Output: " not a racist. He is a racist.\nHe's a racist because he"
|
||||
------------------------------------------------------------
|
||||
Prompt: 'The capital of France is'
|
||||
Output: ' also also also also also also also also also also also also also
|
||||
also also also'
|
||||
------------------------------------------------------------
|
||||
Prompt: 'The future of AI is'
|
||||
Output: ' in the hands of the people.\n\nThe future of AI is in the'
|
||||
------------------------------------------------------------
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.sample.logits_processor import (
|
||||
AdapterLogitsProcessor,
|
||||
RequestLogitsProcessor,
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class DummyPerReqLogitsProcessor:
|
||||
"""The request-level logits processor masks out all logits except the
|
||||
token id identified by `target_token`"""
|
||||
|
||||
def __init__(self, target_token: int) -> None:
|
||||
"""Specify `target_token`"""
|
||||
self.target_token = target_token
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
output_ids: list[int],
|
||||
logits: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
val_to_keep = logits[self.target_token].item()
|
||||
logits[:] = float("-inf")
|
||||
logits[self.target_token] = val_to_keep
|
||||
return logits
|
||||
|
||||
|
||||
class WrappedPerReqLogitsProcessor(AdapterLogitsProcessor):
|
||||
"""Example of wrapping a fake request-level logit processor to create a
|
||||
batch-level logits processor"""
|
||||
|
||||
@classmethod
|
||||
def validate_params(cls, params: SamplingParams):
|
||||
target_token: Any | None = params.extra_args and params.extra_args.get(
|
||||
"target_token"
|
||||
)
|
||||
if target_token is not None and not isinstance(target_token, int):
|
||||
raise ValueError(f"target_token value {target_token} is not int")
|
||||
|
||||
def is_argmax_invariant(self) -> bool:
|
||||
return False
|
||||
|
||||
def new_req_logits_processor(
|
||||
self,
|
||||
params: SamplingParams,
|
||||
) -> RequestLogitsProcessor | None:
|
||||
"""This method returns a new request-level logits processor, customized
|
||||
to the `target_token` value associated with a particular request.
|
||||
|
||||
Returns None if the logits processor should not be applied to the
|
||||
particular request. To use the logits processor the request must have
|
||||
a "target_token" custom argument with an integer value.
|
||||
|
||||
Args:
|
||||
params: per-request sampling params
|
||||
|
||||
Returns:
|
||||
`Callable` request logits processor, or None
|
||||
"""
|
||||
target_token: Any | None = params.extra_args and params.extra_args.get(
|
||||
"target_token"
|
||||
)
|
||||
if target_token is None:
|
||||
return None
|
||||
return DummyPerReqLogitsProcessor(target_token)
|
||||
|
||||
|
||||
# Sample prompts.
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
"The president of the United States is",
|
||||
"The capital of France is",
|
||||
"The future of AI is",
|
||||
]
|
||||
# Create a mixture of requests which do and don't utilize the dummy logitproc
|
||||
sampling_params_list = [
|
||||
SamplingParams(temperature=0.0, extra_args={"target_token": 128}),
|
||||
SamplingParams(temperature=0.0),
|
||||
SamplingParams(temperature=0.0, extra_args={"target_token": 67}),
|
||||
SamplingParams(temperature=0.0),
|
||||
]
|
||||
|
||||
|
||||
def main():
|
||||
# Create an LLM.
|
||||
llm = LLM(
|
||||
model="facebook/opt-125m",
|
||||
logits_processors=[WrappedPerReqLogitsProcessor],
|
||||
)
|
||||
# Generate texts from the prompts.
|
||||
# The output is a list of RequestOutput objects
|
||||
# that contain the prompt, generated text, and other information.
|
||||
outputs = llm.generate(prompts, sampling_params_list)
|
||||
# Print the outputs.
|
||||
print("\nGenerated Outputs:\n" + "-" * 60)
|
||||
for output in outputs:
|
||||
prompt = output.prompt
|
||||
generated_text = output.outputs[0].text
|
||||
print(f"Prompt: {prompt!r}")
|
||||
print(f"Output: {generated_text!r}")
|
||||
print("-" * 60)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
164
examples/offline_inference/logits_processor/custom_req_init.py
Normal file
@@ -0,0 +1,164 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
"""This example demonstrates a special case of wrapping a request-level logits
|
||||
processor, namely the case where it is necessary to utilize engine config or
|
||||
environment info passed to the constructor. The subclass must override the
|
||||
wrapper base class `__init__()` method to access the engine config, the device
|
||||
identifier, or the flag which indicates whether pinned memory is available.
|
||||
|
||||
For demo purposes, a request-level dummy logits processor is employed which
|
||||
causes the same token (`target_token`) to be decoded in each step. The
|
||||
request-level dummy logits processor is wrapped to create a batch-level logits
|
||||
processor, which can apply the logits processor to output logits from all
|
||||
requests in the persistent batch in a given decode step.
|
||||
|
||||
The wrapped dummy logits processor below models a scenario where we must
|
||||
disable the logits processor on non-"cuda" platforms. The wrapper base class
|
||||
`__init__()` is overridden in order to check this condition and set a flag.
|
||||
|
||||
A batch is constructed with `temperature=0.0` and 50% of requests specifying
|
||||
`target_token`, and for these requests - and *only* these requests - we
|
||||
expect that on a "cuda" device the output will look something like:
|
||||
|
||||
Generated Outputs:
|
||||
------------------------------------------------------------
|
||||
Prompt: 'Hello, my name is'
|
||||
Output: " ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '"
|
||||
------------------------------------------------------------
|
||||
Prompt: 'The president of the United States is'
|
||||
Output: " not a racist. He is a racist.\nHe's a racist because he"
|
||||
------------------------------------------------------------
|
||||
Prompt: 'The capital of France is'
|
||||
Output: ' also also also also also also also also also also also also also
|
||||
also also also'
|
||||
------------------------------------------------------------
|
||||
Prompt: 'The future of AI is'
|
||||
Output: ' in the hands of the people.\n\nThe future of AI is in the'
|
||||
------------------------------------------------------------
|
||||
|
||||
which indicates that the logits processor is running. However, on a non-"cuda"
|
||||
device, the first and third requests would not repeat the same token.
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.sample.logits_processor import (
|
||||
AdapterLogitsProcessor,
|
||||
RequestLogitsProcessor,
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class DummyPerReqLogitsProcessor:
|
||||
"""The request-level logits processor masks out all logits except the
|
||||
token id identified by `target_token`"""
|
||||
|
||||
def __init__(self, target_token: int) -> None:
|
||||
"""Specify `target_token`"""
|
||||
self.target_token = target_token
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
output_ids: list[int],
|
||||
logits: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
val_to_keep = logits[self.target_token].item()
|
||||
logits[:] = float("-inf")
|
||||
logits[self.target_token] = val_to_keep
|
||||
return logits
|
||||
|
||||
|
||||
class WrappedPerReqLogitsProcessor(AdapterLogitsProcessor):
|
||||
"""Example of overriding the wrapper class `__init__()` in order to utilize
|
||||
info about the device type"""
|
||||
|
||||
@classmethod
|
||||
def validate_params(cls, params: SamplingParams):
|
||||
target_token = params.extra_args and params.extra_args.get("target_token")
|
||||
if target_token is not None and not isinstance(target_token, int):
|
||||
raise ValueError(
|
||||
f"`target_token` has to be an integer, got {target_token}."
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self, vllm_config: VllmConfig, device: torch.device, is_pin_memory: bool
|
||||
):
|
||||
super().__init__(vllm_config, device, is_pin_memory)
|
||||
self.is_cuda = device.type == "cuda"
|
||||
|
||||
def is_argmax_invariant(self) -> bool:
|
||||
return False
|
||||
|
||||
def new_req_logits_processor(
|
||||
self,
|
||||
params: SamplingParams,
|
||||
) -> RequestLogitsProcessor | None:
|
||||
"""This method returns a new request-level logits processor, customized
|
||||
to the `target_token` value associated with a particular request.
|
||||
|
||||
Returns None if the logits processor should not be applied to the
|
||||
particular request. To use the logits processor the request must have
|
||||
a "target_token" custom argument with an integer value, and the device
|
||||
must be "cuda"-type
|
||||
|
||||
Args:
|
||||
params: per-request sampling params
|
||||
|
||||
Returns:
|
||||
`Callable` request logits processor, or None
|
||||
"""
|
||||
        if not self.is_cuda:
            return None
        target_token = params.extra_args and params.extra_args.get("target_token")
        if target_token is None:
            return None
        return DummyPerReqLogitsProcessor(target_token)
|
||||
|
||||
|
||||
# Sample prompts.
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
"The president of the United States is",
|
||||
"The capital of France is",
|
||||
"The future of AI is",
|
||||
]
|
||||
# Create a mixture of requests which do and don't utilize the dummy logitproc
|
||||
sampling_params_list = [
|
||||
SamplingParams(temperature=0.0, extra_args={"target_token": 128}),
|
||||
SamplingParams(temperature=0.0),
|
||||
SamplingParams(temperature=0.0, extra_args={"target_token": 67}),
|
||||
SamplingParams(temperature=0.0),
|
||||
]
|
||||
|
||||
|
||||
def main():
|
||||
# Create an LLM.
|
||||
llm = LLM(
|
||||
model="facebook/opt-125m",
|
||||
logits_processors=[WrappedPerReqLogitsProcessor],
|
||||
)
|
||||
# Generate texts from the prompts.
|
||||
# The output is a list of RequestOutput objects
|
||||
# that contain the prompt, generated text, and other information.
|
||||
outputs = llm.generate(prompts, sampling_params_list)
|
||||
# Print the outputs.
|
||||
print("\nGenerated Outputs:\n" + "-" * 60)
|
||||
for output in outputs:
|
||||
prompt = output.prompt
|
||||
generated_text = output.outputs[0].text
|
||||
print(f"Prompt: {prompt!r}")
|
||||
print(f"Output: {generated_text!r}")
|
||||
print("-" * 60)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
127
examples/offline_inference/lora_with_quantization_inference.py
Normal file
@@ -0,0 +1,127 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
This example shows how to use LoRA with different quantization techniques
|
||||
for offline inference.
|
||||
|
||||
Requires HuggingFace credentials for access.
|
||||
"""
|
||||
|
||||
import gc
|
||||
|
||||
import torch
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
from vllm import EngineArgs, LLMEngine, RequestOutput, SamplingParams
|
||||
from vllm.lora.request import LoRARequest
|
||||
|
||||
|
||||
def create_test_prompts(
|
||||
lora_path: str,
|
||||
) -> list[tuple[str, SamplingParams, LoRARequest | None]]:
|
||||
return [
|
||||
# this is an example of using quantization without LoRA
|
||||
(
|
||||
"My name is",
|
||||
SamplingParams(temperature=0.0, logprobs=1, max_tokens=128),
|
||||
None,
|
||||
),
|
||||
# the next three examples use quantization with LoRA
|
||||
(
|
||||
"my name is",
|
||||
SamplingParams(temperature=0.0, logprobs=1, max_tokens=128),
|
||||
LoRARequest("lora-test-1", 1, lora_path),
|
||||
),
|
||||
(
|
||||
"The capital of USA is",
|
||||
SamplingParams(temperature=0.0, logprobs=1, max_tokens=128),
|
||||
LoRARequest("lora-test-2", 1, lora_path),
|
||||
),
|
||||
(
|
||||
"The capital of France is",
|
||||
SamplingParams(temperature=0.0, logprobs=1, max_tokens=128),
|
||||
LoRARequest("lora-test-3", 1, lora_path),
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def process_requests(
|
||||
engine: LLMEngine,
|
||||
test_prompts: list[tuple[str, SamplingParams, LoRARequest | None]],
|
||||
):
|
||||
"""Continuously process a list of prompts and handle the outputs."""
|
||||
request_id = 0
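    # Feed prompts into the engine one at a time and keep stepping it until
    # every request has finished.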
|
||||
|
||||
while test_prompts or engine.has_unfinished_requests():
|
||||
if test_prompts:
|
||||
prompt, sampling_params, lora_request = test_prompts.pop(0)
|
||||
engine.add_request(
|
||||
str(request_id), prompt, sampling_params, lora_request=lora_request
|
||||
)
|
||||
request_id += 1
|
||||
|
||||
request_outputs: list[RequestOutput] = engine.step()
|
||||
for request_output in request_outputs:
|
||||
if request_output.finished:
|
||||
print("----------------------------------------------------")
|
||||
print(f"Prompt: {request_output.prompt}")
|
||||
print(f"Output: {request_output.outputs[0].text}")
|
||||
|
||||
|
||||
def initialize_engine(
|
||||
model: str, quantization: str, lora_repo: str | None
|
||||
) -> LLMEngine:
|
||||
"""Initialize the LLMEngine."""
|
||||
|
||||
engine_args = EngineArgs(
|
||||
model=model,
|
||||
quantization=quantization,
|
||||
enable_lora=True,
|
||||
max_lora_rank=64,
|
||||
max_loras=4,
|
||||
)
|
||||
return LLMEngine.from_engine_args(engine_args)
|
||||
|
||||
|
||||
def main():
|
||||
"""Main function that sets up and runs the prompt processing."""
|
||||
|
||||
test_configs = [
|
||||
# QLoRA (https://arxiv.org/abs/2305.14314)
|
||||
{
|
||||
"name": "qlora_inference_example",
|
||||
"model": "huggyllama/llama-7b",
|
||||
"quantization": "bitsandbytes",
|
||||
"lora_repo": "timdettmers/qlora-flan-7b",
|
||||
},
|
||||
{
|
||||
"name": "AWQ_inference_with_lora_example",
|
||||
"model": "TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ",
|
||||
"quantization": "awq",
|
||||
"lora_repo": "jashing/tinyllama-colorist-lora",
|
||||
},
|
||||
{
|
||||
"name": "GPTQ_inference_with_lora_example",
|
||||
"model": "TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ",
|
||||
"quantization": "gptq",
|
||||
"lora_repo": "jashing/tinyllama-colorist-lora",
|
||||
},
|
||||
]
|
||||
|
||||
for test_config in test_configs:
|
||||
print(f"~~~~~~~~~~~~~~~~ Running: {test_config['name']} ~~~~~~~~~~~~~~~~")
|
||||
engine = initialize_engine(
|
||||
test_config["model"], test_config["quantization"], test_config["lora_repo"]
|
||||
)
|
||||
lora_path = snapshot_download(repo_id=test_config["lora_repo"])
|
||||
test_prompts = create_test_prompts(lora_path)
|
||||
process_requests(engine, test_prompts)
|
||||
|
||||
# Clean up the GPU memory for the next test
|
||||
del engine
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
50
examples/offline_inference/metrics.py
Normal file
@@ -0,0 +1,50 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.v1.metrics.reader import Counter, Gauge, Histogram, Vector
|
||||
|
||||
# Sample prompts.
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
"The president of the United States is",
|
||||
"The capital of France is",
|
||||
"The future of AI is",
|
||||
]
|
||||
# Create a sampling params object.
|
||||
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
|
||||
|
||||
|
||||
def main():
|
||||
# Create an LLM.
|
||||
llm = LLM(model="facebook/opt-125m", disable_log_stats=False)
|
||||
|
||||
# Generate texts from the prompts.
|
||||
outputs = llm.generate(prompts, sampling_params)
|
||||
|
||||
# Print the outputs.
|
||||
print("-" * 50)
|
||||
for output in outputs:
|
||||
prompt = output.prompt
|
||||
generated_text = output.outputs[0].text
|
||||
print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
|
||||
print("-" * 50)
|
||||
|
||||
# Dump all metrics
|
||||
for metric in llm.get_metrics():
|
||||
if isinstance(metric, Gauge):
|
||||
print(f"{metric.name} (gauge) = {metric.value}")
|
||||
elif isinstance(metric, Counter):
|
||||
print(f"{metric.name} (counter) = {metric.value}")
|
||||
elif isinstance(metric, Vector):
|
||||
print(f"{metric.name} (vector) = {metric.values}")
|
||||
elif isinstance(metric, Histogram):
|
||||
print(f"{metric.name} (histogram)")
|
||||
print(f" sum = {metric.sum}")
|
||||
print(f" count = {metric.count}")
|
||||
for bucket_le, value in metric.buckets.items():
|
||||
print(f" {bucket_le} = {value}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
186
examples/offline_inference/mistral-small.py
Normal file
@@ -0,0 +1,186 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# ruff: noqa
|
||||
import argparse
|
||||
|
||||
from vllm import LLM
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.assets.image import ImageAsset
|
||||
|
||||
# This script is an offline demo for running Mistral-Small-3.1
|
||||
#
|
||||
# If you want to run a server/client setup, use the following commands:
|
||||
#
|
||||
# - Server:
|
||||
#
|
||||
# ```bash
|
||||
# # Mistral format
|
||||
# vllm serve mistralai/Mistral-Small-3.1-24B-Instruct-2503 \
|
||||
# --tokenizer-mode mistral --config-format mistral --load-format mistral \
|
||||
# --limit-mm-per-prompt '{"image":4}' --max-model-len 16384
|
||||
#
|
||||
# # HF format
|
||||
# vllm serve mistralai/Mistral-Small-3.1-24B-Instruct-2503 \
|
||||
# --limit-mm-per-prompt '{"image":4}' --max-model-len 16384
|
||||
# ```
|
||||
#
|
||||
# - Client:
|
||||
#
|
||||
# ```bash
|
||||
# curl --location 'http://<your-node-url>:8000/v1/chat/completions' \
|
||||
# --header 'Content-Type: application/json' \
|
||||
# --header 'Authorization: Bearer token' \
|
||||
# --data '{
|
||||
# "model": "mistralai/Mistral-Small-3.1-24B-Instruct-2503",
|
||||
# "messages": [
|
||||
# {
|
||||
# "role": "user",
|
||||
# "content": [
|
||||
# {"type" : "text", "text": "Describe this image in detail please."},
|
||||
# {"type": "image_url", "image_url": {"url": "https://s3.amazonaws.com/cms.ipressroom.com/338/files/201808/5b894ee1a138352221103195_A680%7Ejogging-edit/A680%7Ejogging-edit_hero.jpg"}},
|
||||
# {"type" : "text", "text": "and this one as well. Answer in French."},
|
||||
# {"type": "image_url", "image_url": {"url": "https://www.wolframcloud.com/obj/resourcesystem/images/a0e/a0ee3983-46c6-4c92-b85d-059044639928/6af8cfb971db031b.png"}}
|
||||
# ]
|
||||
# }
|
||||
# ]
|
||||
# }'
|
||||
# ```
|
||||
#
|
||||
# Usage:
|
||||
# python demo.py simple
|
||||
# python demo.py advanced
|
||||
|
||||
# Lower max_model_len and/or max_num_seqs on low-VRAM GPUs.
|
||||
# These scripts have been tested on 2x L40 GPUs
|
||||
|
||||
|
||||
def run_simple_demo(args: argparse.Namespace):
|
||||
model_name = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
|
||||
sampling_params = SamplingParams(max_tokens=8192)
|
||||
|
||||
llm = LLM(
|
||||
model=model_name,
|
||||
tokenizer_mode="mistral" if args.format == "mistral" else "auto",
|
||||
config_format="mistral" if args.format == "mistral" else "auto",
|
||||
load_format="mistral" if args.format == "mistral" else "auto",
|
||||
limit_mm_per_prompt={"image": 1},
|
||||
max_model_len=4096,
|
||||
max_num_seqs=2,
|
||||
tensor_parallel_size=2,
|
||||
mm_processor_cache_gb=0 if args.disable_mm_processor_cache else 4,
|
||||
)
|
||||
|
||||
prompt = "Describe this image in one sentence."
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": prompt},
|
||||
{
|
||||
"type": "image_pil",
|
||||
"image_pil": ImageAsset("cherry_blossom").pil_image,
|
||||
},
|
||||
],
|
||||
},
|
||||
]
|
||||
outputs = llm.chat(messages, sampling_params=sampling_params)
|
||||
print("-" * 50)
|
||||
print(outputs[0].outputs[0].text)
|
||||
print("-" * 50)
|
||||
|
||||
|
||||
def run_advanced_demo(args: argparse.Namespace):
|
||||
model_name = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
|
||||
max_img_per_msg = 3
|
||||
max_tokens_per_img = 4096
|
||||
|
||||
sampling_params = SamplingParams(max_tokens=8192, temperature=0.7)
|
||||
llm = LLM(
|
||||
model=model_name,
|
||||
tokenizer_mode="mistral" if args.format == "mistral" else "auto",
|
||||
config_format="mistral" if args.format == "mistral" else "auto",
|
||||
load_format="mistral" if args.format == "mistral" else "auto",
|
||||
limit_mm_per_prompt={"image": max_img_per_msg},
|
||||
max_model_len=max_img_per_msg * max_tokens_per_img,
|
||||
tensor_parallel_size=2,
|
||||
mm_processor_cache_gb=0 if args.disable_mm_processor_cache else 4,
|
||||
)
|
||||
|
||||
prompt = "Describe the following image."
|
||||
|
||||
url_1 = "https://huggingface.co/datasets/patrickvonplaten/random_img/resolve/main/yosemite.png"
|
||||
url_2 = "https://picsum.photos/seed/picsum/200/300"
|
||||
url_3 = "https://picsum.photos/id/32/512/512"
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": prompt},
|
||||
{"type": "image_url", "image_url": {"url": url_1}},
|
||||
{"type": "image_url", "image_url": {"url": url_2}},
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "The images show nature.",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "More details please and answer only in French!.",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image_url", "image_url": {"url": url_3}},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
outputs = llm.chat(messages=messages, sampling_params=sampling_params)
|
||||
print("-" * 50)
|
||||
print(outputs[0].outputs[0].text)
|
||||
print("-" * 50)
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Run a demo in simple or advanced mode."
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"mode",
|
||||
choices=["simple", "advanced"],
|
||||
help="Specify the demo mode: 'simple' or 'advanced'",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--format",
|
||||
choices=["mistral", "hf"],
|
||||
default="mistral",
|
||||
help="Specify the format of the model to load.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--disable-mm-processor-cache",
|
||||
action="store_true",
|
||||
help="If True, disables caching of multi-modal processor.",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
|
||||
if args.mode == "simple":
|
||||
print("Running simple demo...")
|
||||
run_simple_demo(args)
|
||||
elif args.mode == "advanced":
|
||||
print("Running advanced demo...")
|
||||
run_advanced_demo(args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
72
examples/offline_inference/mlpspeculator.py
Normal file
@@ -0,0 +1,72 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
This file demonstrates the usage of text generation with an LLM model,
|
||||
comparing the performance with and without speculative decoding.
|
||||
|
||||
Note that this example is out of date and not supported in vLLM v1.
|
||||
"""
|
||||
|
||||
import gc
|
||||
import time
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
|
||||
|
||||
def time_generation(
|
||||
llm: LLM, prompts: list[str], sampling_params: SamplingParams, title: str
|
||||
):
|
||||
# Generate texts from the prompts. The output is a list of RequestOutput
|
||||
# objects that contain the prompt, generated text, and other information.
|
||||
# Warmup first
|
||||
llm.generate(prompts, sampling_params)
|
||||
llm.generate(prompts, sampling_params)
|
||||
start = time.time()
|
||||
outputs = llm.generate(prompts, sampling_params)
|
||||
end = time.time()
|
||||
print("-" * 50)
|
||||
print(title)
|
||||
print("time: ", (end - start) / sum(len(o.outputs[0].token_ids) for o in outputs))
|
||||
# Print the outputs.
|
||||
for output in outputs:
|
||||
generated_text = output.outputs[0].text
|
||||
print(f"text: {generated_text!r}")
|
||||
print("-" * 50)
|
||||
|
||||
|
||||
def main():
|
||||
template = (
|
||||
"Below is an instruction that describes a task. Write a response "
|
||||
"that appropriately completes the request.\n\n### Instruction:\n{}"
|
||||
"\n\n### Response:\n"
|
||||
)
|
||||
|
||||
# Sample prompts.
|
||||
prompts = [
|
||||
"Write about the president of the United States.",
|
||||
]
|
||||
prompts = [template.format(prompt) for prompt in prompts]
|
||||
# Create a sampling params object.
|
||||
sampling_params = SamplingParams(temperature=0.0, max_tokens=200)
|
||||
|
||||
# Create an LLM without spec decoding
|
||||
llm = LLM(model="meta-llama/Llama-2-13b-chat-hf")
|
||||
|
||||
time_generation(llm, prompts, sampling_params, "Without speculation")
|
||||
|
||||
del llm
|
||||
gc.collect()
|
||||
|
||||
# Create an LLM with spec decoding
|
||||
llm = LLM(
|
||||
model="meta-llama/Llama-2-13b-chat-hf",
|
||||
speculative_config={
|
||||
"model": "ibm-ai-platform/llama-13b-accelerator",
|
||||
},
|
||||
)
|
||||
|
||||
time_generation(llm, prompts, sampling_params, "With speculation")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
106
examples/offline_inference/multilora_inference.py
Normal file
@@ -0,0 +1,106 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
This example shows how to use the multi-LoRA functionality
|
||||
for offline inference.
|
||||
|
||||
Requires HuggingFace credentials for access to Llama2.
|
||||
"""
|
||||
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
from vllm import EngineArgs, LLMEngine, RequestOutput, SamplingParams
|
||||
from vllm.lora.request import LoRARequest
|
||||
|
||||
|
||||
def create_test_prompts(
|
||||
lora_path: str,
|
||||
) -> list[tuple[str, SamplingParams, LoRARequest | None]]:
|
||||
"""Create a list of test prompts with their sampling parameters.
|
||||
|
||||
2 requests for base model, 4 requests for the LoRA. We define 2
|
||||
different LoRA adapters (using the same model for demo purposes).
|
||||
Since we also set `max_loras=1`, the expectation is that the requests
|
||||
with the second LoRA adapter will be run after all requests with the
|
||||
first adapter have finished.
|
||||
"""
|
||||
return [
|
||||
(
|
||||
"A robot may not injure a human being",
|
||||
SamplingParams(temperature=0.0, logprobs=1, max_tokens=128),
|
||||
None,
|
||||
),
|
||||
(
|
||||
"To be or not to be,",
|
||||
SamplingParams(
|
||||
temperature=0.8, top_k=5, presence_penalty=0.2, max_tokens=128
|
||||
),
|
||||
None,
|
||||
),
|
||||
(
|
||||
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]", # noqa: E501
|
||||
SamplingParams(temperature=0.0, logprobs=1, max_tokens=128),
|
||||
LoRARequest("sql-lora", 1, lora_path),
|
||||
),
|
||||
(
|
||||
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]", # noqa: E501
|
||||
SamplingParams(temperature=0.0, logprobs=1, max_tokens=128),
|
||||
LoRARequest("sql-lora2", 2, lora_path),
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def process_requests(
|
||||
engine: LLMEngine,
|
||||
test_prompts: list[tuple[str, SamplingParams, LoRARequest | None]],
|
||||
):
|
||||
"""Continuously process a list of prompts and handle the outputs."""
|
||||
request_id = 0
|
||||
|
||||
print("-" * 50)
|
||||
while test_prompts or engine.has_unfinished_requests():
|
||||
if test_prompts:
|
||||
prompt, sampling_params, lora_request = test_prompts.pop(0)
|
||||
engine.add_request(
|
||||
str(request_id), prompt, sampling_params, lora_request=lora_request
|
||||
)
|
||||
request_id += 1
|
||||
|
||||
request_outputs: list[RequestOutput] = engine.step()
|
||||
|
||||
for request_output in request_outputs:
|
||||
if request_output.finished:
|
||||
print(request_output)
|
||||
print("-" * 50)
|
||||
|
||||
|
||||
def initialize_engine() -> LLMEngine:
|
||||
"""Initialize the LLMEngine."""
|
||||
# max_loras: controls the number of LoRAs that can be used in the same
|
||||
# batch. Larger numbers will cause higher memory usage, as each LoRA
|
||||
# slot requires its own preallocated tensor.
|
||||
# max_lora_rank: controls the maximum supported rank of all LoRAs. Larger
|
||||
# numbers will cause higher memory usage. If you know that all LoRAs will
|
||||
# use the same rank, it is recommended to set this as low as possible.
|
||||
# max_cpu_loras: controls the size of the CPU LoRA cache.
|
||||
engine_args = EngineArgs(
|
||||
model="meta-llama/Llama-3.2-3B-Instruct",
|
||||
enable_lora=True,
|
||||
max_loras=1,
|
||||
max_lora_rank=8,
|
||||
max_cpu_loras=2,
|
||||
max_num_seqs=256,
|
||||
)
|
||||
return LLMEngine.from_engine_args(engine_args)
|
||||
|
||||
|
||||
def main():
|
||||
"""Main function that sets up and runs the prompt processing."""
|
||||
engine = initialize_engine()
|
||||
lora_path = snapshot_download(repo_id="jeeejeee/llama32-3b-text2sql-spider")
|
||||
test_prompts = create_test_prompts(lora_path)
|
||||
process_requests(engine, test_prompts)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
276
examples/offline_inference/openai_batch/README.md
Normal file
@@ -0,0 +1,276 @@
|
||||
# Offline Inference with the OpenAI Batch file format
|
||||
|
||||
```{important}
|
||||
This is a guide to performing batch inference using the OpenAI batch file format, **not** the complete Batch (REST) API.
|
||||
```
|
||||
|
||||
## File Format
|
||||
|
||||
The OpenAI batch file format consists of a series of JSON objects, one per line (JSON Lines).
|
||||
|
||||
[See here for an example file.](https://github.com/vllm-project/vllm/blob/main/examples/offline_inference/openai_batch/openai_example_batch.jsonl)
|
||||
|
||||
Each line represents a separate request. See the [OpenAI package reference](https://platform.openai.com/docs/api-reference/batch/requestInput) for more details.
|
||||
|
||||
```{note}
|
||||
We currently support `/v1/chat/completions`, `/v1/embeddings`, and `/v1/score` endpoints (completions coming soon).
|
||||
```
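
If you prefer to build a batch file programmatically, each line is just a JSON object with `custom_id`, `method`, `url`, and `body` fields. Here is a minimal sketch; the file name and prompts below are only illustrative:

```python
import json

# Hypothetical prompts; each entry becomes one independent request line.
prompts = ["Hello world!", "What is the capital of France?"]

with open("my_batch.jsonl", "w") as f:
    for i, prompt in enumerate(prompts, start=1):
        request = {
            "custom_id": f"request-{i}",
            "method": "POST",
            "url": "/v1/chat/completions",
            "body": {
                "model": "meta-llama/Meta-Llama-3-8B-Instruct",
                "messages": [{"role": "user", "content": prompt}],
                "max_completion_tokens": 1000,
            },
        }
        f.write(json.dumps(request) + "\n")
```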
|
||||
|
||||
## Pre-requisites
|
||||
|
||||
* The examples in this document use `meta-llama/Meta-Llama-3-8B-Instruct`.
|
||||
* Create a [user access token](https://huggingface.co/docs/hub/en/security-tokens)
|
||||
* Install the token on your machine (Run `huggingface-cli login`).
|
||||
* Get access to the gated model by [visiting the model card](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) and agreeing to the terms and conditions.
|
||||
|
||||
## Example 1: Running with a local file
|
||||
|
||||
### Step 1: Create your batch file
|
||||
|
||||
To follow along with this example, you can download the example batch, or create your own batch file in your working directory.
|
||||
|
||||
```bash
|
||||
wget https://raw.githubusercontent.com/vllm-project/vllm/main/examples/offline_inference/openai_batch/openai_example_batch.jsonl
|
||||
```
|
||||
|
||||
Once you've created your batch file, it should look like this:
|
||||
|
||||
```bash
|
||||
cat offline_inference/openai_batch/openai_example_batch.jsonl
|
||||
{"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "meta-llama/Meta-Llama-3-8B-Instruct", "messages": [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "Hello world!"}],"max_completion_tokens": 1000}}
|
||||
{"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "meta-llama/Meta-Llama-3-8B-Instruct", "messages": [{"role": "system", "content": "You are an unhelpful assistant."},{"role": "user", "content": "Hello world!"}],"max_completion_tokens": 1000}}
|
||||
```
|
||||
|
||||
### Step 2: Run the batch
|
||||
|
||||
The batch running tool is designed to be used from the command line.
|
||||
|
||||
You can run the batch with the following command, which will write its results to a file called `results.jsonl`
|
||||
|
||||
```bash
|
||||
python -m vllm.entrypoints.openai.run_batch \
|
||||
-i offline_inference/openai_batch/openai_example_batch.jsonl \
|
||||
-o results.jsonl \
|
||||
--model meta-llama/Meta-Llama-3-8B-Instruct
|
||||
```
|
||||
|
||||
or use the `vllm` command-line interface:
|
||||
|
||||
```bash
|
||||
vllm run-batch \
|
||||
-i offline_inference/openai_batch/openai_example_batch.jsonl \
|
||||
-o results.jsonl \
|
||||
--model meta-llama/Meta-Llama-3-8B-Instruct
|
||||
```
|
||||
|
||||
### Step 3: Check your results
|
||||
|
||||
You should now have your results at `results.jsonl`. You can check your results by running `cat results.jsonl`
|
||||
|
||||
```bash
|
||||
cat results.jsonl
|
||||
{"id":"vllm-383d1c59835645aeb2e07d004d62a826","custom_id":"request-1","response":{"id":"cmpl-61c020e54b964d5a98fa7527bfcdd378","object":"chat.completion","created":1715633336,"model":"meta-llama/Meta-Llama-3-8B-Instruct","choices":[{"index":0,"message":{"role":"assistant","content":"Hello! It's great to meet you! I'm here to help with any questions or tasks you may have. What's on your mind today?"},"logprobs":null,"finish_reason":"stop","stop_reason":null}],"usage":{"prompt_tokens":25,"total_tokens":56,"completion_tokens":31}},"error":null}
|
||||
{"id":"vllm-42e3d09b14b04568afa3f1797751a267","custom_id":"request-2","response":{"id":"cmpl-f44d049f6b3a42d4b2d7850bb1e31bcc","object":"chat.completion","created":1715633336,"model":"meta-llama/Meta-Llama-3-8B-Instruct","choices":[{"index":0,"message":{"role":"assistant","content":"*silence*"},"logprobs":null,"finish_reason":"stop","stop_reason":null}],"usage":{"prompt_tokens":27,"total_tokens":32,"completion_tokens":5}},"error":null}
|
||||
```
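
Since each result line is itself a JSON object, it is easy to post-process. A small sketch, assuming chat completion requests and the response layout shown above:

```python
import json

with open("results.jsonl") as f:
    for line in f:
        result = json.loads(line)
        # For /v1/chat/completions requests, the generated text is in the
        # first choice's message content.
        message = result["response"]["choices"][0]["message"]["content"]
        print(f"{result['custom_id']}: {message}")
```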
|
||||
|
||||
## Example 2: Using remote files
|
||||
|
||||
The batch runner supports remote input and output urls that are accessible via http/https.
|
||||
|
||||
For example, to run against our example input file located at `https://raw.githubusercontent.com/vllm-project/vllm/main/examples/offline_inference/openai_batch/openai_example_batch.jsonl`, you can run
|
||||
|
||||
```bash
|
||||
python -m vllm.entrypoints.openai.run_batch \
|
||||
-i https://raw.githubusercontent.com/vllm-project/vllm/main/examples/offline_inference/openai_batch/openai_example_batch.jsonl \
|
||||
-o results.jsonl \
|
||||
--model meta-llama/Meta-Llama-3-8B-Instruct
|
||||
```
|
||||
|
||||
or use the `vllm` command-line interface:
|
||||
|
||||
```bash
|
||||
vllm run-batch \
|
||||
-i https://raw.githubusercontent.com/vllm-project/vllm/main/examples/offline_inference/openai_batch/openai_example_batch.jsonl \
|
||||
-o results.jsonl \
|
||||
--model meta-llama/Meta-Llama-3-8B-Instruct
|
||||
```
|
||||
|
||||
## Example 3: Integrating with AWS S3
|
||||
|
||||
To integrate with cloud blob storage, we recommend using presigned urls.
|
||||
|
||||
You can learn more about S3 presigned urls in the AWS S3 documentation.
|
||||
|
||||
### Additional prerequisites
|
||||
|
||||
* [Create an S3 bucket](https://docs.aws.amazon.com/AmazonS3/latest/userguide/creating-bucket.html).
|
||||
* The `awscli` package (Run `pip install awscli`) to configure your credentials and interactively use s3.
|
||||
* [Configure your credentials](https://docs.aws.amazon.com/cli/latest/userguide/getting-started-quickstart.html).
|
||||
* The `boto3` python package (Run `pip install boto3`) to generate presigned urls.
|
||||
|
||||
### Step 1: Upload your input script
|
||||
|
||||
To follow along with this example, you can download the example batch, or create your own batch file in your working directory.
|
||||
|
||||
```bash
|
||||
wget https://raw.githubusercontent.com/vllm-project/vllm/main/examples/offline_inference/openai_batch/openai_example_batch.jsonl
|
||||
```
|
||||
|
||||
Once you've created your batch file, it should look like this:
|
||||
|
||||
```bash
|
||||
cat offline_inference/openai_batch/openai_example_batch.jsonl
|
||||
{"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "meta-llama/Meta-Llama-3-8B-Instruct", "messages": [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "Hello world!"}],"max_completion_tokens": 1000}}
|
||||
{"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "meta-llama/Meta-Llama-3-8B-Instruct", "messages": [{"role": "system", "content": "You are an unhelpful assistant."},{"role": "user", "content": "Hello world!"}],"max_completion_tokens": 1000}}
|
||||
```
|
||||
|
||||
Now upload your batch file to your S3 bucket.
|
||||
|
||||
```bash
|
||||
aws s3 cp offline_inference/openai_batch/openai_example_batch.jsonl s3://MY_BUCKET/MY_INPUT_FILE.jsonl
|
||||
```
|
||||
|
||||
### Step 2: Generate your presigned urls
|
||||
|
||||
Presigned urls can only be generated via the SDK. You can run the following python script to generate your presigned urls. Be sure to replace the `MY_BUCKET`, `MY_INPUT_FILE.jsonl`, and `MY_OUTPUT_FILE.jsonl` placeholders with your bucket and file names.
|
||||
|
||||
(The script is adapted from <https://github.com/awsdocs/aws-doc-sdk-examples/blob/main/python/example_code/s3/s3_basics/presigned_url.py>)
|
||||
|
||||
```python
|
||||
import boto3
|
||||
from botocore.exceptions import ClientError
|
||||
|
||||
def generate_presigned_url(s3_client, client_method, method_parameters, expires_in):
|
||||
"""
|
||||
Generate a presigned Amazon S3 URL that can be used to perform an action.
|
||||
|
||||
:param s3_client: A Boto3 Amazon S3 client.
|
||||
:param client_method: The name of the client method that the URL performs.
|
||||
:param method_parameters: The parameters of the specified client method.
|
||||
:param expires_in: The number of seconds the presigned URL is valid for.
|
||||
:return: The presigned URL.
|
||||
"""
|
||||
try:
|
||||
url = s3_client.generate_presigned_url(
|
||||
ClientMethod=client_method,
|
||||
Params=method_parameters,
|
||||
ExpiresIn=expires_in,
|
||||
)
|
||||
except ClientError:
|
||||
raise
|
||||
return url
|
||||
|
||||
|
||||
s3_client = boto3.client("s3")
|
||||
input_url = generate_presigned_url(
|
||||
s3_client,
|
||||
"get_object",
|
||||
{"Bucket": "MY_BUCKET", "Key": "MY_INPUT_FILE.jsonl"},
|
||||
expires_in=3600,
|
||||
)
|
||||
output_url = generate_presigned_url(
|
||||
s3_client,
|
||||
"put_object",
|
||||
{"Bucket": "MY_BUCKET", "Key": "MY_OUTPUT_FILE.jsonl"},
|
||||
expires_in=3600,
|
||||
)
|
||||
print(f"{input_url=}")
|
||||
print(f"{output_url=}")
|
||||
```
|
||||
|
||||
This script should output
|
||||
|
||||
```text
|
||||
input_url='https://s3.us-west-2.amazonaws.com/MY_BUCKET/MY_INPUT_FILE.jsonl?AWSAccessKeyId=ABCDEFGHIJKLMNOPQRST&Signature=abcdefghijklmnopqrstuvwxyz12345&Expires=1715800091'
|
||||
output_url='https://s3.us-west-2.amazonaws.com/MY_BUCKET/MY_OUTPUT_FILE.jsonl?AWSAccessKeyId=ABCDEFGHIJKLMNOPQRST&Signature=abcdefghijklmnopqrstuvwxyz12345&Expires=1715800091'
|
||||
```
|
||||
|
||||
### Step 3: Run the batch runner using your presigned urls
|
||||
|
||||
You can now run the batch runner, using the urls generated in the previous section.
|
||||
|
||||
```bash
|
||||
python -m vllm.entrypoints.openai.run_batch \
|
||||
-i "https://s3.us-west-2.amazonaws.com/MY_BUCKET/MY_INPUT_FILE.jsonl?AWSAccessKeyId=ABCDEFGHIJKLMNOPQRST&Signature=abcdefghijklmnopqrstuvwxyz12345&Expires=1715800091" \
|
||||
-o "https://s3.us-west-2.amazonaws.com/MY_BUCKET/MY_OUTPUT_FILE.jsonl?AWSAccessKeyId=ABCDEFGHIJKLMNOPQRST&Signature=abcdefghijklmnopqrstuvwxyz12345&Expires=1715800091" \
|
||||
    --model meta-llama/Meta-Llama-3-8B-Instruct
|
||||
```
|
||||
|
||||
or use the `vllm` command-line interface:
|
||||
|
||||
```bash
|
||||
vllm run-batch \
|
||||
-i "https://s3.us-west-2.amazonaws.com/MY_BUCKET/MY_INPUT_FILE.jsonl?AWSAccessKeyId=ABCDEFGHIJKLMNOPQRST&Signature=abcdefghijklmnopqrstuvwxyz12345&Expires=1715800091" \
|
||||
-o "https://s3.us-west-2.amazonaws.com/MY_BUCKET/MY_OUTPUT_FILE.jsonl?AWSAccessKeyId=ABCDEFGHIJKLMNOPQRST&Signature=abcdefghijklmnopqrstuvwxyz12345&Expires=1715800091" \
|
||||
    --model meta-llama/Meta-Llama-3-8B-Instruct
|
||||
```
|
||||
|
||||
### Step 4: View your results
|
||||
|
||||
Your results are now on S3. You can view them in your terminal by running
|
||||
|
||||
```bash
|
||||
aws s3 cp s3://MY_BUCKET/MY_OUTPUT_FILE.jsonl -
|
||||
```
|
||||
|
||||
## Example 4: Using embeddings endpoint
|
||||
|
||||
### Additional prerequisites
|
||||
|
||||
* Ensure you are using `vllm >= 0.5.5`.
|
||||
|
||||
### Step 1: Create your batch file
|
||||
|
||||
Add embedding requests to your batch file. The following is an example:
|
||||
|
||||
```text
|
||||
{"custom_id": "request-1", "method": "POST", "url": "/v1/embeddings", "body": {"model": "intfloat/e5-mistral-7b-instruct", "input": "You are a helpful assistant."}}
|
||||
{"custom_id": "request-2", "method": "POST", "url": "/v1/embeddings", "body": {"model": "intfloat/e5-mistral-7b-instruct", "input": "You are an unhelpful assistant."}}
|
||||
```
|
||||
|
||||
You can even mix chat completion and embedding requests in the batch file, as long as the model you are using supports both chat completion and embeddings (note that all requests must use the same model).
|
||||
|
||||
### Step 2: Run the batch
|
||||
|
||||
You can run the batch using the same command as in earlier examples.
|
||||
|
||||
### Step 3: Check your results
|
||||
|
||||
You can check your results by running `cat results.jsonl`
|
||||
|
||||
```bash
|
||||
cat results.jsonl
|
||||
{"id":"vllm-db0f71f7dec244e6bce530e0b4ef908b","custom_id":"request-1","response":{"status_code":200,"request_id":"vllm-batch-3580bf4d4ae54d52b67eee266a6eab20","body":{"id":"embd-33ac2efa7996430184461f2e38529746","object":"list","created":444647,"model":"intfloat/e5-mistral-7b-instruct","data":[{"index":0,"object":"embedding","embedding":[0.016204833984375,0.0092010498046875,0.0018358230590820312,-0.0028228759765625,0.001422882080078125,-0.0031147003173828125,...]}],"usage":{"prompt_tokens":8,"total_tokens":8,"completion_tokens":0}}},"error":null}
|
||||
...
|
||||
```
|
||||
|
||||
## Example 5: Using score endpoint
|
||||
|
||||
### Additional prerequisites
|
||||
|
||||
* Ensure you are using `vllm >= 0.7.0`.
|
||||
|
||||
### Step 1: Create your batch file
|
||||
|
||||
Add score requests to your batch file. The following is an example:
|
||||
|
||||
```text
|
||||
{"custom_id": "request-1", "method": "POST", "url": "/v1/score", "body": {"model": "BAAI/bge-reranker-v2-m3", "text_1": "What is the capital of France?", "text_2": ["The capital of Brazil is Brasilia.", "The capital of France is Paris."]}}
|
||||
{"custom_id": "request-2", "method": "POST", "url": "/v1/score", "body": {"model": "BAAI/bge-reranker-v2-m3", "text_1": "What is the capital of France?", "text_2": ["The capital of Brazil is Brasilia.", "The capital of France is Paris."]}}
|
||||
```
|
||||
|
||||
You can mix chat completion, embedding, and score requests in the batch file, as long as the model you are using supports them all (note that all requests must use the same model).
|
||||
|
||||
### Step 2: Run the batch
|
||||
|
||||
You can run the batch using the same command as in earlier examples.
|
||||
|
||||
### Step 3: Check your results
|
||||
|
||||
You can check your results by running `cat results.jsonl`
|
||||
|
||||
```bash
|
||||
cat results.jsonl
|
||||
{"id":"vllm-f87c5c4539184f618e555744a2965987","custom_id":"request-1","response":{"status_code":200,"request_id":"vllm-batch-806ab64512e44071b37d3f7ccd291413","body":{"id":"score-4ee45236897b4d29907d49b01298cdb1","object":"list","created":1737847944,"model":"BAAI/bge-reranker-v2-m3","data":[{"index":0,"object":"score","score":0.0010900497436523438},{"index":1,"object":"score","score":1.0}],"usage":{"prompt_tokens":37,"total_tokens":37,"completion_tokens":0,"prompt_tokens_details":null}}},"error":null}
|
||||
{"id":"vllm-41990c51a26d4fac8419077f12871099","custom_id":"request-2","response":{"status_code":200,"request_id":"vllm-batch-73ce66379026482699f81974e14e1e99","body":{"id":"score-13f2ffe6ba40460fbf9f7f00ad667d75","object":"list","created":1737847944,"model":"BAAI/bge-reranker-v2-m3","data":[{"index":0,"object":"score","score":0.001094818115234375},{"index":1,"object":"score","score":1.0}],"usage":{"prompt_tokens":37,"total_tokens":37,"completion_tokens":0,"prompt_tokens_details":null}}},"error":null}
|
||||
```
|
||||
2
examples/offline_inference/openai_batch/openai_example_batch.jsonl
Normal file
@@ -0,0 +1,2 @@
|
||||
{"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "meta-llama/Meta-Llama-3-8B-Instruct", "messages": [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "Hello world!"}],"max_completion_tokens": 1000}}
|
||||
{"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "meta-llama/Meta-Llama-3-8B-Instruct", "messages": [{"role": "system", "content": "You are an unhelpful assistant."},{"role": "user", "content": "Hello world!"}],"max_completion_tokens": 1000}}
|
||||
98
examples/offline_inference/prefix_caching.py
Normal file
@@ -0,0 +1,98 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.distributed import cleanup_dist_env_and_memory
|
||||
|
||||
# NOTE: This is just a running example. For benchmarking purposes,
|
||||
# please see benchmarks/benchmark_prefix_caching.py
|
||||
|
||||
# Common prefix.
|
||||
prefix = (
|
||||
"You are an expert school principal, skilled in effectively managing "
|
||||
"faculty and staff. Draft 10-15 questions for a potential first grade "
|
||||
"Head Teacher for my K-12, all-girls', independent school that emphasizes "
|
||||
"community, joyful discovery, and life-long learning. The candidate is "
|
||||
"coming in for a first-round panel interview for a 8th grade Math "
|
||||
"teaching role. They have 5 years of previous teaching experience "
|
||||
"as an assistant teacher at a co-ed, public school with experience "
|
||||
"in middle school math teaching. Based on these information, fulfill "
|
||||
"the following paragraph: "
|
||||
)
|
||||
|
||||
# Sample prompts.
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
"The president of the United States is",
|
||||
"The capital of France is",
|
||||
"The future of AI is",
|
||||
]
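# Every request below shares the same long prefix, so its KV cache can be
# reused once prefix caching is enabled.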
|
||||
|
||||
generating_prompts = [prefix + prompt for prompt in prompts]
|
||||
|
||||
# Create a sampling params object.
|
||||
sampling_params = SamplingParams(temperature=0.0)
|
||||
|
||||
|
||||
def main():
|
||||
# Create an LLM without prefix caching as a baseline.
|
||||
regular_llm = LLM(model="facebook/opt-125m", gpu_memory_utilization=0.4)
|
||||
|
||||
print("Results without `enable_prefix_caching`")
|
||||
|
||||
# ruff: noqa: E501
|
||||
# Generate texts from the prompts. The output is a list of RequestOutput objects
|
||||
# that contain the prompt, generated text, and other information.
|
||||
outputs = regular_llm.generate(generating_prompts, sampling_params)
|
||||
|
||||
regular_generated_texts = []
|
||||
# Print the outputs.
|
||||
print("-" * 50)
|
||||
for output in outputs:
|
||||
prompt = output.prompt
|
||||
generated_text = output.outputs[0].text
|
||||
regular_generated_texts.append(generated_text)
|
||||
print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
|
||||
print("-" * 50)
|
||||
|
||||
# Destroy the LLM object and free up the GPU memory.
|
||||
del regular_llm
|
||||
cleanup_dist_env_and_memory()
|
||||
|
||||
# Create an LLM with prefix caching enabled.
|
||||
prefix_cached_llm = LLM(
|
||||
model="facebook/opt-125m",
|
||||
enable_prefix_caching=True,
|
||||
gpu_memory_utilization=0.4,
|
||||
)
|
||||
|
||||
# Warmup so that the shared prompt's KV cache is computed.
|
||||
prefix_cached_llm.generate(generating_prompts[0], sampling_params)
|
||||
|
||||
# Generate with prefix caching.
|
||||
outputs = prefix_cached_llm.generate(generating_prompts, sampling_params)
|
||||
|
||||
print("Results with `enable_prefix_caching`")
|
||||
|
||||
cached_generated_texts = []
|
||||
# Print the outputs. You should see the same outputs as before.
|
||||
print("-" * 50)
|
||||
for output in outputs:
|
||||
prompt = output.prompt
|
||||
generated_text = output.outputs[0].text
|
||||
cached_generated_texts.append(generated_text)
|
||||
print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
|
||||
print("-" * 50)
|
||||
|
||||
# Compare the results and display the speedup
|
||||
generated_same = all(
|
||||
[
|
||||
regular_generated_texts[i] == cached_generated_texts[i]
|
||||
for i in range(len(prompts))
|
||||
]
|
||||
)
|
||||
print(f"Generated answers are the same: {generated_same}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
97
examples/offline_inference/prompt_embed_inference.py
Normal file
@@ -0,0 +1,97 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Demonstrates how to generate prompt embeddings using
|
||||
Hugging Face Transformers and use them as input to vLLM
|
||||
for both single and batch inference.
|
||||
|
||||
Model: meta-llama/Llama-3.2-1B-Instruct
|
||||
Note: This model is gated on Hugging Face Hub.
|
||||
You must request access to use it:
|
||||
https://huggingface.co/meta-llama/Llama-3.2-1B-Instruct
|
||||
|
||||
Requirements:
|
||||
- vLLM
|
||||
- transformers
|
||||
|
||||
Run:
|
||||
python examples/offline_inference/prompt_embed_inference.py
|
||||
"""
|
||||
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizer
|
||||
|
||||
from vllm import LLM
|
||||
|
||||
|
||||
def init_tokenizer_and_llm(model_name: str):
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
transformers_model = AutoModelForCausalLM.from_pretrained(model_name)
|
||||
embedding_layer = transformers_model.get_input_embeddings()
|
||||
llm = LLM(model=model_name, enable_prompt_embeds=True)
|
||||
return tokenizer, embedding_layer, llm
|
||||
|
||||
|
||||
def get_prompt_embeds(
|
||||
chat: list[dict[str, str]],
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
embedding_layer: torch.nn.Module,
|
||||
):
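    # Apply the model's chat template to obtain token ids, then map them through
    # the input embedding layer to get a (seq_len, hidden_size) tensor.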
|
||||
token_ids = tokenizer.apply_chat_template(
|
||||
chat, add_generation_prompt=True, return_tensors="pt"
|
||||
)
|
||||
prompt_embeds = embedding_layer(token_ids).squeeze(0)
|
||||
return prompt_embeds
|
||||
|
||||
|
||||
def single_prompt_inference(
|
||||
llm: LLM, tokenizer: PreTrainedTokenizer, embedding_layer: torch.nn.Module
|
||||
):
|
||||
chat = [{"role": "user", "content": "Please tell me about the capital of France."}]
|
||||
prompt_embeds = get_prompt_embeds(chat, tokenizer, embedding_layer)
|
||||
|
||||
outputs = llm.generate(
|
||||
{
|
||||
"prompt_embeds": prompt_embeds,
|
||||
}
|
||||
)
|
||||
|
||||
print("\n[Single Inference Output]")
|
||||
print("-" * 30)
|
||||
for o in outputs:
|
||||
print(o.outputs[0].text)
|
||||
print("-" * 30)
|
||||
|
||||
|
||||
def batch_prompt_inference(
|
||||
llm: LLM, tokenizer: PreTrainedTokenizer, embedding_layer: torch.nn.Module
|
||||
):
|
||||
chats = [
|
||||
[{"role": "user", "content": "Please tell me about the capital of France."}],
|
||||
[{"role": "user", "content": "When is the day longest during the year?"}],
|
||||
[{"role": "user", "content": "Where is bigger, the moon or the sun?"}],
|
||||
]
|
||||
|
||||
prompt_embeds_list = [
|
||||
get_prompt_embeds(chat, tokenizer, embedding_layer) for chat in chats
|
||||
]
|
||||
|
||||
outputs = llm.generate([{"prompt_embeds": embeds} for embeds in prompt_embeds_list])
|
||||
|
||||
print("\n[Batch Inference Outputs]")
|
||||
print("-" * 30)
|
||||
for i, o in enumerate(outputs):
|
||||
print(f"Q{i + 1}: {chats[i][0]['content']}")
|
||||
print(f"A{i + 1}: {o.outputs[0].text}\n")
|
||||
print("-" * 30)
|
||||
|
||||
|
||||
def main():
|
||||
model_name = "meta-llama/Llama-3.2-1B-Instruct"
|
||||
tokenizer, embedding_layer, llm = init_tokenizer_and_llm(model_name)
|
||||
single_prompt_inference(llm, tokenizer, embedding_layer)
|
||||
batch_prompt_inference(llm, tokenizer, embedding_layer)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
40
examples/offline_inference/qwen2_5_omni/README.md
Normal file
@@ -0,0 +1,40 @@
|
||||
# Qwen2.5-Omni Offline Inference Examples
|
||||
|
||||
This folder provides several example scripts showing how to run offline inference with Qwen2.5-Omni.
|
||||
|
||||
## Thinker Only
|
||||
|
||||
```bash
|
||||
# Audio + image + video
|
||||
python examples/offline_inference/qwen2_5_omni/only_thinker.py \
|
||||
-q mixed_modalities
|
||||
|
||||
# Read vision and audio inputs from a single video file
|
||||
# NOTE: V1 engine does not support interleaved modalities yet.
|
||||
python examples/offline_inference/qwen2_5_omni/only_thinker.py \
|
||||
-q use_audio_in_video
|
||||
|
||||
# Multiple audios
|
||||
python examples/offline_inference/qwen2_5_omni/only_thinker.py \
|
||||
-q multi_audios
|
||||
```
|
||||
|
||||
This script runs the thinker part of Qwen2.5-Omni and generates a text response.
|
||||
|
||||
You can also test Qwen2.5-Omni on a single modality:
|
||||
|
||||
```bash
|
||||
# Process audio inputs
|
||||
python examples/offline_inference/audio_language.py \
|
||||
--model-type qwen2_5_omni
|
||||
|
||||
# Process image inputs
|
||||
python examples/offline_inference/vision_language.py \
|
||||
--modality image \
|
||||
--model-type qwen2_5_omni
|
||||
|
||||
# Process video inputs
|
||||
python examples/offline_inference/vision_language.py \
|
||||
--modality video \
|
||||
--model-type qwen2_5_omni
|
||||
```
|
||||
170
examples/offline_inference/qwen2_5_omni/only_thinker.py
Normal file
@@ -0,0 +1,170 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
This example shows how to use vLLM for running offline inference
|
||||
with the correct prompt format on Qwen2.5-Omni (thinker only).
|
||||
"""
|
||||
|
||||
from typing import NamedTuple
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.assets.audio import AudioAsset
|
||||
from vllm.assets.image import ImageAsset
|
||||
from vllm.assets.video import VideoAsset
|
||||
from vllm.multimodal.image import convert_image_mode
|
||||
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||
|
||||
|
||||
class QueryResult(NamedTuple):
|
||||
inputs: dict
|
||||
limit_mm_per_prompt: dict[str, int]
|
||||
|
||||
|
||||
# NOTE: The default `max_num_seqs` and `max_model_len` may result in OOM on
|
||||
# lower-end GPUs.
|
||||
# Unless specified, these settings have been tested to work on a single L4.
|
||||
|
||||
default_system = (
|
||||
"You are Qwen, a virtual human developed by the Qwen Team, Alibaba "
|
||||
"Group, capable of perceiving auditory and visual inputs, as well as "
|
||||
"generating text and speech."
|
||||
)
|
||||
|
||||
|
||||
def get_mixed_modalities_query() -> QueryResult:
|
||||
question = (
|
||||
"What is recited in the audio? "
|
||||
"What is the content of this image? Why is this video funny?"
|
||||
)
|
||||
prompt = (
|
||||
f"<|im_start|>system\n{default_system}<|im_end|>\n"
|
||||
"<|im_start|>user\n<|audio_bos|><|AUDIO|><|audio_eos|>"
|
||||
"<|vision_bos|><|IMAGE|><|vision_eos|>"
|
||||
"<|vision_bos|><|VIDEO|><|vision_eos|>"
|
||||
f"{question}<|im_end|>\n"
|
||||
f"<|im_start|>assistant\n"
|
||||
)
|
||||
return QueryResult(
|
||||
inputs={
|
||||
"prompt": prompt,
|
||||
"multi_modal_data": {
|
||||
"audio": AudioAsset("mary_had_lamb").audio_and_sample_rate,
|
||||
"image": convert_image_mode(
|
||||
ImageAsset("cherry_blossom").pil_image, "RGB"
|
||||
),
|
||||
"video": VideoAsset(name="baby_reading", num_frames=16).np_ndarrays,
|
||||
},
|
||||
},
|
||||
limit_mm_per_prompt={"audio": 1, "image": 1, "video": 1},
|
||||
)
|
||||
|
||||
|
||||
def get_use_audio_in_video_query() -> QueryResult:
|
||||
question = (
|
||||
"Describe the content of the video, then convert what the baby say into text."
|
||||
)
|
||||
prompt = (
|
||||
f"<|im_start|>system\n{default_system}<|im_end|>\n"
|
||||
"<|im_start|>user\n<|vision_bos|><|VIDEO|><|vision_eos|>"
|
||||
f"{question}<|im_end|>\n"
|
||||
f"<|im_start|>assistant\n"
|
||||
)
|
||||
asset = VideoAsset(name="baby_reading", num_frames=16)
|
||||
audio = asset.get_audio(sampling_rate=16000)
|
||||
|
||||
return QueryResult(
|
||||
inputs={
|
||||
"prompt": prompt,
|
||||
"multi_modal_data": {
|
||||
"video": asset.np_ndarrays,
|
||||
"audio": audio,
|
||||
},
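            # Ask the multimodal processor to also use the audio track embedded
            # in the video.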
|
||||
"mm_processor_kwargs": {
|
||||
"use_audio_in_video": True,
|
||||
},
|
||||
},
|
||||
limit_mm_per_prompt={"audio": 1, "video": 1},
|
||||
)
|
||||
|
||||
|
||||
def get_multi_audios_query() -> QueryResult:
|
||||
question = "Are these two audio clips the same?"
|
||||
prompt = (
|
||||
f"<|im_start|>system\n{default_system}<|im_end|>\n"
|
||||
"<|im_start|>user\n<|audio_bos|><|AUDIO|><|audio_eos|>"
|
||||
"<|audio_bos|><|AUDIO|><|audio_eos|>"
|
||||
f"{question}<|im_end|>\n"
|
||||
f"<|im_start|>assistant\n"
|
||||
)
|
||||
return QueryResult(
|
||||
inputs={
|
||||
"prompt": prompt,
|
||||
"multi_modal_data": {
|
||||
"audio": [
|
||||
AudioAsset("winning_call").audio_and_sample_rate,
|
||||
AudioAsset("mary_had_lamb").audio_and_sample_rate,
|
||||
],
|
||||
},
|
||||
},
|
||||
limit_mm_per_prompt={
|
||||
"audio": 2,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
query_map = {
|
||||
"mixed_modalities": get_mixed_modalities_query,
|
||||
"use_audio_in_video": get_use_audio_in_video_query,
|
||||
"multi_audios": get_multi_audios_query,
|
||||
}
|
||||
|
||||
|
||||
def main(args):
|
||||
model_name = "Qwen/Qwen2.5-Omni-7B"
|
||||
query_result = query_map[args.query_type]()
|
||||
|
||||
llm = LLM(
|
||||
model=model_name,
|
||||
max_model_len=5632,
|
||||
max_num_seqs=5,
|
||||
limit_mm_per_prompt=query_result.limit_mm_per_prompt,
|
||||
seed=args.seed,
|
||||
)
|
||||
|
||||
# We set temperature to 0.2 so that outputs can be different
|
||||
# even when all prompts are identical when running batch inference.
|
||||
sampling_params = SamplingParams(temperature=0.2, max_tokens=64)
|
||||
|
||||
outputs = llm.generate(query_result.inputs, sampling_params=sampling_params)
|
||||
|
||||
for o in outputs:
|
||||
generated_text = o.outputs[0].text
|
||||
print(generated_text)
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = FlexibleArgumentParser(
|
||||
description="Demo on using vLLM for offline inference with "
|
||||
"audio language models"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--query-type",
|
||||
"-q",
|
||||
type=str,
|
||||
default="mixed_modalities",
|
||||
choices=query_map.keys(),
|
||||
help="Query type.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--seed",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Set the seed when initializing `vllm.LLM`.",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
main(args)
|
||||
170
examples/offline_inference/qwen3_omni/only_thinker.py
Normal file
@@ -0,0 +1,170 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
This example shows how to use vLLM for running offline inference
|
||||
with the correct prompt format on Qwen2.5-Omni (thinker only).
|
||||
"""
|
||||
|
||||
from typing import NamedTuple
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.assets.audio import AudioAsset
|
||||
from vllm.assets.image import ImageAsset
|
||||
from vllm.assets.video import VideoAsset
|
||||
from vllm.multimodal.image import convert_image_mode
|
||||
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||
|
||||
|
||||
class QueryResult(NamedTuple):
|
||||
inputs: dict
|
||||
limit_mm_per_prompt: dict[str, int]
|
||||
|
||||
|
||||
# NOTE: The default `max_num_seqs` and `max_model_len` may result in OOM on
|
||||
# lower-end GPUs.
|
||||
# Unless specified, these settings have been tested to work on a single L4.
|
||||
|
||||
default_system = (
|
||||
"You are Qwen, a virtual human developed by the Qwen Team, Alibaba "
|
||||
"Group, capable of perceiving auditory and visual inputs, as well as "
|
||||
"generating text and speech."
|
||||
)
|
||||
|
||||
|
||||
def get_mixed_modalities_query() -> QueryResult:
|
||||
question = (
|
||||
"What is recited in the audio? "
|
||||
"What is the content of this image? Why is this video funny?"
|
||||
)
|
||||
prompt = (
|
||||
f"<|im_start|>system\n{default_system}<|im_end|>\n"
|
||||
"<|im_start|>user\n<|audio_start|><|audio_pad|><|audio_end|>"
|
||||
"<|vision_start|><|image_pad|><|vision_end|>"
|
||||
"<|vision_start|><|video_pad|><|vision_end|>"
|
||||
f"{question}<|im_end|>\n"
|
||||
f"<|im_start|>assistant\n"
|
||||
)
|
||||
return QueryResult(
|
||||
inputs={
|
||||
"prompt": prompt,
|
||||
"multi_modal_data": {
|
||||
"audio": AudioAsset("mary_had_lamb").audio_and_sample_rate,
|
||||
"image": convert_image_mode(
|
||||
ImageAsset("cherry_blossom").pil_image, "RGB"
|
||||
),
|
||||
"video": VideoAsset(name="baby_reading", num_frames=16).np_ndarrays,
|
||||
},
|
||||
},
|
||||
limit_mm_per_prompt={"audio": 1, "image": 1, "video": 1},
|
||||
)
|
||||
|
||||
|
||||
def get_use_audio_in_video_query() -> QueryResult:
|
||||
question = (
|
||||
"Describe the content of the video in details, then convert what the "
|
||||
"baby say into text."
|
||||
)
|
||||
prompt = (
|
||||
f"<|im_start|>system\n{default_system}<|im_end|>\n"
|
||||
"<|im_start|>user\n<|vision_start|><|video_pad|><|vision_end|>"
|
||||
f"{question}<|im_end|>\n"
|
||||
f"<|im_start|>assistant\n"
|
||||
)
|
||||
asset = VideoAsset(name="baby_reading", num_frames=16)
|
||||
audio = asset.get_audio(sampling_rate=16000)
|
||||
return QueryResult(
|
||||
inputs={
|
||||
"prompt": prompt,
|
||||
"multi_modal_data": {
|
||||
"video": asset.np_ndarrays,
|
||||
"audio": audio,
|
||||
},
|
||||
"mm_processor_kwargs": {
|
||||
"use_audio_in_video": True,
|
||||
},
|
||||
},
|
||||
limit_mm_per_prompt={"audio": 1, "video": 1},
|
||||
)
|
||||
|
||||
|
||||
def get_multi_audios_query() -> QueryResult:
|
||||
question = "Are these two audio clips the same?"
|
||||
prompt = (
|
||||
f"<|im_start|>system\n{default_system}<|im_end|>\n"
|
||||
"<|im_start|>user\n<|audio_start|><|audio_pad|><|audio_end|>"
|
||||
"<|audio_start|><|audio_pad|><|audio_end|>"
|
||||
f"{question}<|im_end|>\n"
|
||||
f"<|im_start|>assistant\n"
|
||||
)
|
||||
return QueryResult(
|
||||
inputs={
|
||||
"prompt": prompt,
|
||||
"multi_modal_data": {
|
||||
"audio": [
|
||||
AudioAsset("winning_call").audio_and_sample_rate,
|
||||
AudioAsset("mary_had_lamb").audio_and_sample_rate,
|
||||
],
|
||||
},
|
||||
},
|
||||
limit_mm_per_prompt={
|
||||
"audio": 2,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
query_map = {
|
||||
"mixed_modalities": get_mixed_modalities_query,
|
||||
"use_audio_in_video": get_use_audio_in_video_query,
|
||||
"multi_audios": get_multi_audios_query,
|
||||
}
|
||||
|
||||
|
||||
def main(args):
|
||||
model_name = "Qwen/Qwen3-Omni-30B-A3B-Instruct"
|
||||
query_result = query_map[args.query_type]()
|
||||
|
||||
llm = LLM(
|
||||
model=model_name,
|
||||
max_model_len=12800,
|
||||
max_num_seqs=5,
|
||||
limit_mm_per_prompt=query_result.limit_mm_per_prompt,
|
||||
seed=args.seed,
|
||||
)
|
||||
|
||||
# We set temperature to 0.2 so that outputs can be different
|
||||
# even when all prompts are identical when running batch inference.
|
||||
sampling_params = SamplingParams(temperature=0.2, max_tokens=256)
|
||||
|
||||
outputs = llm.generate(query_result.inputs, sampling_params=sampling_params)
|
||||
|
||||
for o in outputs:
|
||||
generated_text = o.outputs[0].text
|
||||
print(generated_text)
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = FlexibleArgumentParser(
|
||||
description="Demo on using vLLM for offline inference with "
|
||||
"audio language models"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--query-type",
|
||||
"-q",
|
||||
type=str,
|
||||
default="mixed_modalities",
|
||||
choices=query_map.keys(),
|
||||
help="Query type.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--seed",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Set the seed when initializing `vllm.LLM`.",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
main(args)
70
examples/offline_inference/qwen_1m.py
Normal file
@@ -0,0 +1,70 @@
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import os
|
||||
from urllib.request import urlopen
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
|
||||
os.environ["VLLM_ALLOW_LONG_MAX_MODEL_LEN"] = "1"
|
||||
|
||||
|
||||
def load_prompt() -> str:
|
||||
# Test cases with various lengths can be found at:
|
||||
#
|
||||
# https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/64k.txt
|
||||
# https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/200k.txt
|
||||
# https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/600k.txt
|
||||
# https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/1m.txt
|
||||
|
||||
with urlopen(
|
||||
"https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/600k.txt",
|
||||
timeout=5,
|
||||
) as response:
|
||||
prompt = response.read().decode("utf-8")
|
||||
return prompt
|
||||
|
||||
|
||||
# Processing the prompt.
|
||||
def process_requests(llm: LLM, prompts: list[str]) -> None:
|
||||
# Create a sampling params object.
|
||||
sampling_params = SamplingParams(
|
||||
temperature=0.7,
|
||||
top_p=0.8,
|
||||
top_k=20,
|
||||
repetition_penalty=1.05,
|
||||
detokenize=True,
|
||||
max_tokens=256,
|
||||
)
|
||||
# Generate texts from the prompts.
|
||||
outputs = llm.generate(prompts, sampling_params)
|
||||
# Print the outputs.
|
||||
for output in outputs:
|
||||
prompt_token_ids = output.prompt_token_ids
|
||||
generated_text = output.outputs[0].text
|
||||
print(
|
||||
f"Prompt length: {len(prompt_token_ids)}, "
|
||||
f"Generated text: {generated_text!r}"
|
||||
)
|
||||
|
||||
|
||||
# Create an LLM.
|
||||
def initialize_engine() -> LLM:
|
||||
llm = LLM(
|
||||
model="Qwen/Qwen2.5-7B-Instruct-1M",
|
||||
max_model_len=1048576,
|
||||
tensor_parallel_size=4,
|
||||
enforce_eager=True,
|
||||
enable_chunked_prefill=True,
|
||||
max_num_batched_tokens=131072,
|
||||
)
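# With chunked prefill enabled, a full 1,048,576-token prompt is processed in
# about 1048576 / 131072 = 8 prefill chunks of max_num_batched_tokens each
# (illustrative arithmetic, not part of the original source).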
|
||||
return llm
|
||||
|
||||
|
||||
def main():
|
||||
llm = initialize_engine()
|
||||
prompt = load_prompt()
|
||||
process_requests(llm, [prompt])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
46
examples/offline_inference/reproducibility.py
Normal file
@@ -0,0 +1,46 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Demonstrates how to achieve reproducibility in vLLM.

Main article: https://docs.vllm.ai/en/latest/usage/reproducibility.html
"""

import os
import random

from vllm import LLM, SamplingParams

# Either:
## Turn off multiprocessing to make the scheduling deterministic, or
os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"
## Enable batch invariance to get consistent results regardless of scheduling.
os.environ["VLLM_BATCH_INVARIANT"] = "1"

prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)


def main():
    llm = LLM(model="facebook/opt-125m")
    outputs = llm.generate(prompts, sampling_params)
    print("-" * 50)
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
        print("-" * 50)

    # Try generating random numbers outside vLLM
    # The same number is output across runs, meaning that the random state
    # in the user code has been updated by vLLM
    print(random.randint(0, 100))


if __name__ == "__main__":
    main()
147
examples/offline_inference/rlhf.py
Normal file
@@ -0,0 +1,147 @@
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Demonstrates reinforcement learning from human feedback (RLHF) using vLLM and Ray.
|
||||
|
||||
The script separates training and inference workloads onto distinct GPUs
|
||||
so that Ray can manage process placement and inter-process communication.
|
||||
A Hugging Face Transformer model occupies GPU 0 for training, whereas a
|
||||
tensor-parallel vLLM inference engine occupies GPU 1–2.
|
||||
|
||||
The example performs the following steps:
|
||||
|
||||
* Load the training model on GPU 0.
|
||||
* Split the inference model across GPUs 1–2 using vLLM's tensor parallelism
|
||||
and Ray placement groups.
|
||||
* Generate text from a list of prompts using the inference engine.
|
||||
* Update the weights of the training model and broadcast the updated weights
|
||||
to the inference engine by using a Ray collective RPC group. Note that
|
||||
for demonstration purposes we simply zero out the weights.
|
||||
|
||||
For a production-ready implementation that supports multiple training and
|
||||
inference replicas, see the OpenRLHF framework:
|
||||
https://github.com/OpenRLHF/OpenRLHF
|
||||
|
||||
This example assumes a single-node cluster with three GPUs, but Ray
|
||||
supports multi-node clusters. vLLM expects the GPUs to be used only for vLLM
|
||||
workloads. Residual GPU activity interferes with vLLM memory profiling and
|
||||
causes unexpected behavior.
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
import ray
|
||||
import torch
|
||||
from ray.util.placement_group import placement_group
|
||||
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
|
||||
from rlhf_utils import stateless_init_process_group
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.utils.network_utils import get_ip, get_open_port
|
||||
|
||||
|
||||
class MyLLM(LLM):
|
||||
"""Configure the vLLM worker for Ray placement group execution."""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
# Remove the top-level CUDA_VISIBLE_DEVICES variable set by Ray
|
||||
# so that vLLM can manage its own device placement within the worker.
|
||||
os.environ.pop("CUDA_VISIBLE_DEVICES", None)
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
|
||||
# Load the OPT-125M model onto GPU 0 for the training workload.
|
||||
train_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
|
||||
train_model.to("cuda:0")
|
||||
|
||||
# Initialize Ray and set the visible devices. The vLLM engine will
|
||||
# be placed on GPUs 1 and 2.
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = "1,2"
|
||||
ray.init()
|
||||
|
||||
# Create a placement group that reserves GPU 1–2 for the vLLM inference engine.
|
||||
# Learn more about Ray placement groups:
|
||||
# https://docs.ray.io/en/latest/ray-core/scheduling/placement-group.html
|
||||
pg_inference = placement_group([{"GPU": 1, "CPU": 0}] * 2)
|
||||
ray.get(pg_inference.ready())
|
||||
scheduling_inference = PlacementGroupSchedulingStrategy(
|
||||
placement_group=pg_inference,
|
||||
placement_group_capture_child_tasks=True,
|
||||
placement_group_bundle_index=0,
|
||||
)
|
||||
|
||||
# Launch the vLLM inference engine. The `enforce_eager` flag reduces
|
||||
# start-up latency.
|
||||
llm = ray.remote(
|
||||
num_cpus=0,
|
||||
num_gpus=0,
|
||||
scheduling_strategy=scheduling_inference,
|
||||
)(MyLLM).remote(
|
||||
model="facebook/opt-125m",
|
||||
enforce_eager=True,
|
||||
worker_extension_cls="rlhf_utils.WorkerExtension",
|
||||
tensor_parallel_size=2,
|
||||
distributed_executor_backend="ray",
|
||||
)
|
||||
|
||||
# Generate text from the prompts.
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
"The president of the United States is",
|
||||
"The capital of France is",
|
||||
"The future of AI is",
|
||||
]
|
||||
|
||||
sampling_params = SamplingParams(temperature=0)
|
||||
|
||||
outputs = ray.get(llm.generate.remote(prompts, sampling_params))
|
||||
|
||||
print("-" * 50)
|
||||
for output in outputs:
|
||||
prompt = output.prompt
|
||||
generated_text = output.outputs[0].text
|
||||
print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
|
||||
print("-" * 50)
|
||||
|
||||
# Set up the communication channel between the training process and the
|
||||
# inference engine.
|
||||
master_address = get_ip()
|
||||
master_port = get_open_port()
|
||||
|
||||
handle = llm.collective_rpc.remote(
|
||||
"init_weight_update_group", args=(master_address, master_port, 1, 3)
|
||||
)
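# Note: with rank_offset=1 and world_size=3 as above, the training process
# takes rank 0 in this stateless group, while the two tensor-parallel vLLM
# workers take ranks 1 and 2.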
|
||||
|
||||
model_update_group = stateless_init_process_group(
|
||||
master_address, master_port, 0, 3, torch.device("cuda:0")
|
||||
)
|
||||
ray.get(handle)
|
||||
|
||||
# Simulate a training step by zeroing out all model weights.
|
||||
# In a real RLHF training loop the weights would be updated using the gradient
|
||||
# from an RL objective such as PPO on a reward model.
|
||||
for name, p in train_model.named_parameters():
|
||||
p.data.zero_()
|
||||
|
||||
# Synchronize the updated weights to the inference engine.
|
||||
for name, p in train_model.named_parameters():
|
||||
dtype_name = str(p.dtype).split(".")[-1]
|
||||
handle = llm.collective_rpc.remote(
|
||||
"update_weight", args=(name, dtype_name, p.shape)
|
||||
)
|
||||
model_update_group.broadcast(p, src=0, stream=torch.cuda.current_stream())
|
||||
ray.get(handle)
|
||||
|
||||
# Verify that the inference weights have been updated.
|
||||
assert all(ray.get(llm.collective_rpc.remote("check_weights_changed")))
|
||||
|
||||
# Generate text with the updated model. The output is expected to be nonsense
|
||||
# because the weights are zero.
|
||||
outputs_updated = ray.get(llm.generate.remote(prompts, sampling_params))
|
||||
print("-" * 50)
|
||||
for output in outputs_updated:
|
||||
prompt = output.prompt
|
||||
generated_text = output.outputs[0].text
|
||||
print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
|
||||
print("-" * 50)
251
examples/offline_inference/rlhf_colocate.py
Normal file
@@ -0,0 +1,251 @@
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Demonstrates how to co-locate a vLLM inference worker and training
|
||||
actors on the same set of GPUs for reinforcement learning from human feedback
|
||||
(RLHF) workloads.
|
||||
|
||||
Ray serves as the distributed execution framework in this example. Ray
|
||||
placement groups allocate both training actors and vLLM workers to the
|
||||
same GPU bundles, enabling fast, in-GPU communication between the two
|
||||
components.
|
||||
|
||||
The script shows how to do the following:
|
||||
|
||||
* Configure environment variables (`VLLM_RAY_PER_WORKER_GPUS` and
|
||||
`VLLM_RAY_BUNDLE_INDICES`) so that vLLM workers land on the desired
|
||||
devices.
|
||||
* Exchange tensors between processes by means of CUDA inter-process
|
||||
communication (IPC). CUDA IPC sidesteps NCCL limitations that occur
|
||||
when multiple processes share a single GPU.
|
||||
|
||||
Note that this example assumes a single-node cluster with four GPUs, but Ray
|
||||
supports multi-node clusters. vLLM expects exclusive use of the GPUs during
|
||||
its initialization for memory profiling. Residual GPU activity interferes
|
||||
with vLLM memory profiling and causes unexpected behavior.
|
||||
|
||||
Learn more about Ray placement groups:
|
||||
https://docs.ray.io/en/latest/placement-groups.html
|
||||
"""
|
||||
|
||||
import gc
|
||||
import os
|
||||
|
||||
import ray
|
||||
import torch
|
||||
import zmq
|
||||
from ray.util.placement_group import placement_group
|
||||
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
|
||||
from torch.multiprocessing.reductions import reduce_tensor
|
||||
|
||||
from vllm import LLM
|
||||
|
||||
|
||||
class MyLLM(LLM):
|
||||
"""Configure the vLLM worker for Ray placement group execution.
|
||||
|
||||
The constructor sets environment variables that allow multiple vLLM
|
||||
workers to share a single physical GPU and that encode the bundle
|
||||
indices assigned by the placement group.
|
||||
|
||||
Args:
|
||||
*args: Positional arguments forwarded to `vllm.LLM`.
|
||||
bundle_indices (list[int]): Placement-group bundle indices
|
||||
assigned to this worker.
|
||||
**kwargs: Keyword arguments forwarded to `vllm.LLM`.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, bundle_indices: list[int], **kwargs):
|
||||
# Prevent Ray from manipulating the top-level CUDA_VISIBLE_DEVICES variable
|
||||
# so that vLLM can manage its own device placement inside the worker.
|
||||
os.environ.pop("CUDA_VISIBLE_DEVICES", None)
|
||||
# Each worker uses 0.4 GPU so that two instances fit on the same GPUs.
|
||||
os.environ["VLLM_RAY_PER_WORKER_GPUS"] = "0.4"
|
||||
os.environ["VLLM_RAY_BUNDLE_INDICES"] = ",".join(map(str, bundle_indices))
|
||||
print(f"creating LLM with bundle_indices={bundle_indices}")
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
|
||||
class RayTrainingActor:
|
||||
"""Training actor that hosts a Facebook OPT-125M model from Hugging Face.
|
||||
|
||||
The model is loaded onto the first GPU assigned to this actor, and exposes
|
||||
the CUDA IPC handles so that colocated vLLM workers can map tensors
|
||||
directly.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
# Ray sets CUDA_VISIBLE_DEVICES to the GPUs assigned to this actor.
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
self.model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
|
||||
self.model.to("cuda:0")
|
||||
# Zero out all the parameters.
|
||||
for name, p in self.model.named_parameters():
|
||||
p.data.zero_()
|
||||
torch.cuda.synchronize()
|
||||
# The argument for `get_device_uuid` is the index of the GPU in the
|
||||
# list of visible devices.
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
self.device_uuid = current_platform.get_device_uuid(0)
|
||||
self.zmq_context = zmq.Context()
|
||||
self.zmq_address_counter = 0
|
||||
self.zmq_handle = None
|
||||
|
||||
def report_device_id(self) -> str:
|
||||
return self.device_uuid
|
||||
|
||||
def get_zmq_handles(self) -> dict[str, str]:
|
||||
suffix = f"{self.device_uuid}-{self.zmq_address_counter}"
|
||||
self.zmq_handle = f"ipc:///tmp/rl-colocate-zmq-{suffix}.sock"
|
||||
self.zmq_address_counter += 1
|
||||
return {self.device_uuid: self.zmq_handle}
|
||||
|
||||
def update_weights(self):
|
||||
# align size to avoid misaligned address
|
||||
align_size = 256
|
||||
|
||||
def get_size(p: torch.Tensor) -> int:
|
||||
return (p.nbytes + align_size - 1) // align_size * align_size
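# For example (illustrative numbers): with align_size=256, a 1000-byte
# tensor is padded to 1024 bytes in the shared IPC buffer.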
|
||||
|
||||
named_parameters: dict[str, torch.nn.Parameter] = dict(
|
||||
self.model.named_parameters()
|
||||
)
|
||||
max_tensor_size = max(get_size(p) for p in named_parameters.values())
|
||||
# use max_tensor_size * 2 as buffer size
|
||||
buffer = torch.empty(max_tensor_size * 2, dtype=torch.uint8, device="cuda:0")
|
||||
s = self.zmq_context.socket(zmq.REQ)
|
||||
s.bind(self.zmq_handle)
|
||||
handle = reduce_tensor(buffer)
|
||||
|
||||
offset = 0
|
||||
buckets: list[tuple[list[dict], list[torch.Tensor]]] = []
|
||||
named_tensors: list[dict] = []
|
||||
real_tensors: list[torch.Tensor] = []
|
||||
for name, p in named_parameters.items():
|
||||
size = get_size(p)
|
||||
if offset + size > buffer.numel():
|
||||
buckets.append((named_tensors, real_tensors))
|
||||
named_tensors, real_tensors = [], []
|
||||
offset = 0
|
||||
# assume tensors are contiguous
|
||||
named_tensors.append(
|
||||
{"name": name, "dtype": p.dtype, "shape": p.shape, "offset": offset}
|
||||
)
|
||||
real_tensors.append(p)
|
||||
offset += size
|
||||
if named_tensors:
|
||||
buckets.append((named_tensors, real_tensors))
|
||||
s.send_pyobj(handle)
|
||||
s.recv()
|
||||
for named_tensors, real_tensors in buckets:
|
||||
offset = 0
|
||||
for p in real_tensors:
|
||||
buffer[offset : offset + p.nbytes].data.copy_(
|
||||
p.data.view(-1).view(dtype=torch.uint8), non_blocking=True
|
||||
)
|
||||
offset += get_size(p)
|
||||
torch.cuda.synchronize()
|
||||
s.send_pyobj(named_tensors)
|
||||
s.recv()
|
||||
s.send_pyobj(None)
|
||||
s.recv()
|
||||
s.close()
|
||||
del buffer
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
# Ray manages four GPUs.
|
||||
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
|
||||
ray.init()
|
||||
|
||||
# Co-locate vLLM instances and training actors on the same set of GPUs:
|
||||
# * GPU 0 and 1: training actor 0, training actor 1, and vLLM instance 0
|
||||
# (tensor parallelism = 2).
|
||||
# * GPU 2 and 3: training actor 2, training actor 3, and vLLM instance 1
|
||||
# (tensor parallelism = 2).
|
||||
|
||||
pg = placement_group([{"GPU": 1, "CPU": 0}] * 4)
|
||||
ray.get(pg.ready())
|
||||
print(f"placement group has bundles {pg.bundle_specs=}")
|
||||
|
||||
training_actors = []
|
||||
training_actor_device_ids = []
|
||||
inference_engines = []
|
||||
inference_engine_device_ids = []
|
||||
|
||||
for bundle_index in [0, 1, 2, 3]:
|
||||
training_actor = ray.remote(
|
||||
num_cpus=0,
|
||||
num_gpus=0.4,
|
||||
scheduling_strategy=PlacementGroupSchedulingStrategy(
|
||||
placement_group=pg,
|
||||
placement_group_capture_child_tasks=True,
|
||||
placement_group_bundle_index=bundle_index,
|
||||
),
|
||||
)(RayTrainingActor).remote()
|
||||
training_actors.append(training_actor)
|
||||
|
||||
for bundle_index, training_actor in enumerate(training_actors):
|
||||
device_id = ray.get(training_actor.report_device_id.remote())
|
||||
print(f"training actor {bundle_index} is on {device_id}")
|
||||
training_actor_device_ids.append(device_id)
|
||||
|
||||
for i, bundle_indices in enumerate([[0, 1], [2, 3]]):
|
||||
# Use the following syntax instead of the @ray.remote decorator so that
|
||||
# the placement group is customized for each bundle.
|
||||
llm = ray.remote(
|
||||
num_cpus=0,
|
||||
num_gpus=0,
|
||||
scheduling_strategy=PlacementGroupSchedulingStrategy(
|
||||
placement_group=pg,
|
||||
placement_group_capture_child_tasks=True,
|
||||
),
|
||||
)(MyLLM).remote(
|
||||
model="facebook/opt-125m",
|
||||
enforce_eager=True,
|
||||
worker_extension_cls="rlhf_utils.ColocateWorkerExtension",
|
||||
tensor_parallel_size=2,
|
||||
distributed_executor_backend="ray",
|
||||
gpu_memory_utilization=0.4,
|
||||
bundle_indices=bundle_indices,
|
||||
)
|
||||
inference_engines.append(llm)
|
||||
# Do not call any method on the inference engine at this point; the call
|
||||
# blocks until the vLLM instance finishes initialization.
|
||||
|
||||
for i, llm in enumerate(inference_engines):
|
||||
inference_engine_device_ids.append(
|
||||
ray.get(llm.collective_rpc.remote("report_device_id", args=tuple()))
|
||||
)
|
||||
print(f"inference engine {i} is on {inference_engine_device_ids[-1]}")
|
||||
|
||||
# Verify placement: the first two training actors share the same GPUs as
|
||||
# the first inference engine.
|
||||
assert training_actor_device_ids[:2] == inference_engine_device_ids[0]
|
||||
# Verify placement: the last two training actors share the same GPUs as
|
||||
# the second inference engine.
|
||||
assert training_actor_device_ids[2:] == inference_engine_device_ids[1]
|
||||
|
||||
print("Gather all the ZMQ handles from the training actors.")
|
||||
zmq_handles = {}
|
||||
for actor in training_actors:
|
||||
zmq_handles.update(ray.get(actor.get_zmq_handles.remote()))
|
||||
|
||||
print(f"ZMQ handles: {zmq_handles}")
|
||||
|
||||
print("Update the weights of the inference engines.")
|
||||
ray.get(
|
||||
[actor.update_weights.remote() for actor in training_actors]
|
||||
+ [
|
||||
llm.collective_rpc.remote("update_weights_from_ipc", args=(zmq_handles,))
|
||||
for llm in inference_engines
|
||||
]
|
||||
)
|
||||
|
||||
print("Check if the weights are updated.")
|
||||
for llm in inference_engines:
|
||||
assert ray.get(llm.collective_rpc.remote("check_weights_changed", args=tuple()))
162
examples/offline_inference/rlhf_online_quant.py
Normal file
@@ -0,0 +1,162 @@
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Demonstrates reinforcement learning from human feedback (RLHF) using vLLM and Ray.
|
||||
|
||||
The script separates training and inference workloads onto distinct GPUs
|
||||
so that Ray can manage process placement and inter-process communication.
|
||||
A Hugging Face Transformer model occupies GPU 0 for training, whereas a
|
||||
tensor-parallel vLLM inference engine occupies GPU 1–2.
|
||||
|
||||
The example performs the following steps:
|
||||
|
||||
* Load the training model on GPU 0.
|
||||
* Split the inference model across GPUs 1–2 using vLLM's tensor parallelism
|
||||
and Ray placement groups.
|
||||
* Generate text from a list of prompts using the inference engine.
|
||||
* Update the weights of the training model and broadcast the updated weights
|
||||
to the inference engine by using a Ray collective RPC group. Note that
|
||||
for demonstration purposes we simply zero out the weights.
|
||||
|
||||
For a production-ready implementation that supports multiple training and
|
||||
inference replicas, see the OpenRLHF framework:
|
||||
https://github.com/OpenRLHF/OpenRLHF
|
||||
|
||||
This example assumes a single-node cluster with three GPUs, but Ray
|
||||
supports multi-node clusters. vLLM expects the GPUs to be used only for vLLM
|
||||
workloads. Residual GPU activity interferes with vLLM memory profiling and
|
||||
causes unexpected behavior.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
|
||||
import ray
|
||||
import torch
|
||||
from ray.util.placement_group import placement_group
|
||||
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
|
||||
from rlhf_utils import stateless_init_process_group
|
||||
from torchao.core.config import config_to_dict
|
||||
from torchao.quantization import (
|
||||
Float8DynamicActivationFloat8WeightConfig,
|
||||
PerRow,
|
||||
)
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.utils.network_utils import get_ip, get_open_port
|
||||
|
||||
|
||||
class MyLLM(LLM):
|
||||
"""Configure the vLLM worker for Ray placement group execution."""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
# Remove the top-level CUDA_VISIBLE_DEVICES variable set by Ray
|
||||
# so that vLLM can manage its own device placement within the worker.
|
||||
os.environ.pop("CUDA_VISIBLE_DEVICES", None)
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
|
||||
# Load the OPT-125M model onto GPU 0 for the training workload.
|
||||
train_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
|
||||
train_model.to("cuda:0")
|
||||
|
||||
# Initialize Ray and set the visible devices. The vLLM engine will
|
||||
# be placed on GPUs 1 and 2.
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = "1,2"
|
||||
ray.init()
|
||||
|
||||
# Create a placement group that reserves GPU 1–2 for the vLLM inference engine.
|
||||
# Learn more about Ray placement groups:
|
||||
# https://docs.ray.io/en/latest/ray-core/scheduling/placement-group.html
|
||||
pg_inference = placement_group([{"GPU": 1, "CPU": 0}] * 2)
|
||||
ray.get(pg_inference.ready())
|
||||
scheduling_inference = PlacementGroupSchedulingStrategy(
|
||||
placement_group=pg_inference,
|
||||
placement_group_capture_child_tasks=True,
|
||||
placement_group_bundle_index=0,
|
||||
)
|
||||
|
||||
# Launch the vLLM inference engine. The `enforce_eager` flag reduces
|
||||
# start-up latency.
|
||||
|
||||
# generate torchao quantization config for RL rollout
|
||||
# see https://github.com/vllm-project/vllm/pull/23014 for instructions to
|
||||
# use serialized config files instead of passing around json string
|
||||
config = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
|
||||
|
||||
json_str = json.dumps(config_to_dict(config))
|
||||
|
||||
llm = ray.remote(
|
||||
num_cpus=0,
|
||||
num_gpus=0,
|
||||
scheduling_strategy=scheduling_inference,
|
||||
)(MyLLM).remote(
|
||||
model="facebook/opt-125m",
|
||||
hf_overrides={"quantization_config_dict_json": json_str},
|
||||
enforce_eager=True,
|
||||
worker_extension_cls="rlhf_utils.WorkerExtension",
|
||||
tensor_parallel_size=2,
|
||||
distributed_executor_backend="ray",
|
||||
)
|
||||
|
||||
# Generate text from the prompts.
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
"The president of the United States is",
|
||||
"The capital of France is",
|
||||
"The future of AI is",
|
||||
]
|
||||
|
||||
sampling_params = SamplingParams(temperature=0)
|
||||
|
||||
outputs = ray.get(llm.generate.remote(prompts, sampling_params))
|
||||
|
||||
print("-" * 50)
|
||||
for output in outputs:
|
||||
prompt = output.prompt
|
||||
generated_text = output.outputs[0].text
|
||||
print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
|
||||
print("-" * 50)
|
||||
|
||||
# Set up the communication channel between the training process and the
|
||||
# inference engine.
|
||||
master_address = get_ip()
|
||||
master_port = get_open_port()
|
||||
|
||||
handle = llm.collective_rpc.remote(
|
||||
"init_weight_update_group", args=(master_address, master_port, 1, 3)
|
||||
)
|
||||
|
||||
model_update_group = stateless_init_process_group(
|
||||
master_address, master_port, 0, 3, torch.device("cuda:0")
|
||||
)
|
||||
ray.get(handle)
|
||||
|
||||
# Simulate a training step by zeroing out all model weights.
|
||||
# In a real RLHF training loop the weights would be updated using the gradient
|
||||
# from an RL objective such as PPO on a reward model.
|
||||
for name, p in train_model.named_parameters():
|
||||
p.data.zero_()
|
||||
|
||||
# Synchronize the updated weights to the inference engine.
|
||||
for name, p in train_model.named_parameters():
|
||||
dtype_name = str(p.dtype).split(".")[-1]
|
||||
handle = llm.collective_rpc.remote(
|
||||
"update_weight", args=(name, dtype_name, p.shape)
|
||||
)
|
||||
model_update_group.broadcast(p, src=0, stream=torch.cuda.current_stream())
|
||||
ray.get(handle)
|
||||
|
||||
# Verify that the inference weights have been updated.
|
||||
assert all(ray.get(llm.collective_rpc.remote("check_weights_changed")))
|
||||
|
||||
# Generate text with the updated model. The output is expected to be nonsense
|
||||
# because the weights are zero.
|
||||
outputs_updated = ray.get(llm.generate.remote(prompts, sampling_params))
|
||||
print("-" * 50)
|
||||
for output in outputs_updated:
|
||||
prompt = output.prompt
|
||||
generated_text = output.outputs[0].text
|
||||
print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
|
||||
print("-" * 50)
168
examples/offline_inference/rlhf_utils.py
Normal file
@@ -0,0 +1,168 @@
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import gc
|
||||
from collections.abc import Callable
|
||||
from typing import TypedDict
|
||||
|
||||
import torch
|
||||
import zmq
|
||||
|
||||
|
||||
def stateless_init_process_group(master_address, master_port, rank, world_size, device):
|
||||
"""
|
||||
vLLM provides `StatelessProcessGroup` to create a process group
|
||||
without considering the global process group in torch.distributed.
|
||||
It is recommended to create `StatelessProcessGroup`, and then initialize
|
||||
the data-plane communication (NCCL) between external (train processes)
|
||||
and vLLM workers.
|
||||
"""
|
||||
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
|
||||
from vllm.distributed.utils import StatelessProcessGroup
|
||||
|
||||
pg = StatelessProcessGroup.create(
|
||||
host=master_address, port=master_port, rank=rank, world_size=world_size
|
||||
)
|
||||
pynccl = PyNcclCommunicator(pg, device=device)
|
||||
return pynccl
|
||||
|
||||
|
||||
class WorkerExtension:
|
||||
"""
|
||||
The class for vLLM's worker to inherit from.
|
||||
By defining an extension class, the code works regardless of
the underlying worker class.
|
||||
|
||||
NOTE: we define this class in a separate module, and the main module
|
||||
should pass the fully qualified name as the `worker_extension_cls` argument.
|
||||
"""
|
||||
|
||||
def init_weight_update_group(
|
||||
self, master_address, master_port, rank_offset, world_size
|
||||
):
|
||||
from vllm.distributed.parallel_state import get_world_group
|
||||
|
||||
rank = get_world_group().rank + rank_offset
|
||||
self.model_update_group = stateless_init_process_group(
|
||||
master_address,
|
||||
master_port,
|
||||
rank,
|
||||
world_size,
|
||||
self.device,
|
||||
)
|
||||
|
||||
def update_weight(self, name, dtype_name, shape):
|
||||
dtype = getattr(torch, dtype_name)
|
||||
weight = torch.empty(shape, dtype=dtype, device="cuda")
|
||||
self.model_update_group.broadcast(
|
||||
weight, src=0, stream=torch.cuda.current_stream()
|
||||
)
|
||||
|
||||
self.model_runner.model.load_weights(weights=[(name, weight)])
|
||||
|
||||
del weight
|
||||
|
||||
def check_weights_changed(self):
|
||||
"""
|
||||
Check if the weights are updated to 0.
|
||||
"""
|
||||
weights_updated = True
|
||||
for name, p in self.model_runner.model.named_parameters():
|
||||
weights_updated = weights_updated and torch.allclose(p, torch.zeros_like(p))
|
||||
return weights_updated
|
||||
|
||||
|
||||
def rebuild_ipc(
|
||||
handle: tuple[Callable, tuple], device_id: int | None = None
|
||||
) -> torch.Tensor:
|
||||
func, args = handle
|
||||
list_args = list(args)
|
||||
if device_id is not None:
|
||||
# the key is to change device id to the current device id
|
||||
# in case two processes have different CUDA_VISIBLE_DEVICES
|
||||
list_args[6] = device_id
|
||||
buffer = func(*list_args)
|
||||
return buffer
|
||||
|
||||
|
||||
class FlattenedTensorMetadata(TypedDict):
|
||||
name: str
|
||||
shape: torch.Size
|
||||
dtype: torch.dtype
|
||||
# specify the start offset of this tensor in shared ipc_buffer tensor
|
||||
offset: int
|
||||
|
||||
|
||||
class ColocateWorkerExtension:
|
||||
"""
|
||||
The class for vLLM's worker to inherit from, in the colocate setting.
|
||||
By defining an extension class, the code works regardless of
the underlying worker class.
|
||||
|
||||
NOTE: we define this class in a separate module, and the main module
|
||||
should pass the fully qualified name as the `worker_extension_cls` argument.
|
||||
"""
|
||||
|
||||
def update_weights_from_ipc(self, zmq_handles: dict[str, str]):
|
||||
from vllm.model_executor.model_loader.utils import process_weights_after_loading
|
||||
|
||||
assert self.device is not None
|
||||
if not hasattr(self, "_zmq_ctx") or self._zmq_ctx is None:
|
||||
self._zmq_ctx = zmq.Context()
|
||||
socket = self._zmq_ctx.socket(zmq.REP)
|
||||
socket.connect(zmq_handles[self.report_device_id()])
|
||||
buffer: torch.Tensor | None = None
|
||||
while True:
|
||||
payload: tuple[Callable, tuple] | list[FlattenedTensorMetadata] | None = (
|
||||
socket.recv_pyobj()
|
||||
)
|
||||
if payload is None:
|
||||
# means the update is done
|
||||
process_weights_after_loading(
|
||||
self.model_runner.model, self.model_config, self.device
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
socket.send(b"")
|
||||
break
|
||||
if isinstance(payload, tuple):
|
||||
# an ipc handle that vLLM can use `func, args = handle`
|
||||
# and `func(*args)` to rebuild GPU tensor.
|
||||
buffer = rebuild_ipc(payload, self.device.index)
|
||||
assert buffer.dtype == torch.uint8
|
||||
socket.send(b"")
|
||||
continue
|
||||
assert isinstance(payload, list)
|
||||
assert buffer is not None
|
||||
weights = []
|
||||
for item in payload:
|
||||
shape = item["shape"]
|
||||
if isinstance(shape, (list, tuple)):
|
||||
shape = torch.Size(shape)
|
||||
assert isinstance(shape, torch.Size)
|
||||
dtype, offset = item["dtype"], item["offset"]
|
||||
size = dtype.itemsize * shape.numel()
|
||||
tensor = buffer[offset : offset + size].view(dtype=dtype).view(shape)
|
||||
weights.append((item["name"], tensor))
|
||||
self.model_runner.model.load_weights(weights=weights)
|
||||
del weights
|
||||
torch.cuda.synchronize()
|
||||
socket.send(b"")
|
||||
|
||||
socket.close()
|
||||
del buffer
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def report_device_id(self) -> str:
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
self.device_uuid = current_platform.get_device_uuid(self.device.index)
|
||||
return self.device_uuid
|
||||
|
||||
def check_weights_changed(self):
|
||||
"""
|
||||
Check if the weights are updated to 0.
|
||||
"""
|
||||
weights_updated = True
|
||||
for name, p in self.model_runner.model.named_parameters():
|
||||
weights_updated = weights_updated and torch.allclose(p, torch.zeros_like(p))
|
||||
return weights_updated
87
examples/offline_inference/save_sharded_state.py
Normal file
@@ -0,0 +1,87 @@
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Saves each worker's model state dict directly to a checkpoint, which enables a
|
||||
fast load path for large tensor-parallel models where each worker only needs to
|
||||
read its own shard rather than the entire checkpoint.
|
||||
|
||||
Example usage:
|
||||
|
||||
python save_sharded_state.py \
|
||||
--model /path/to/load \
|
||||
--quantization deepspeedfp \
|
||||
--tensor-parallel-size 8 \
|
||||
--output /path/to/save
|
||||
|
||||
Then, the model can be loaded with
|
||||
|
||||
llm = LLM(
|
||||
model="/path/to/save",
|
||||
load_format="sharded_state",
|
||||
quantization="deepspeedfp",
|
||||
tensor_parallel_size=8,
|
||||
)
|
||||
"""
|
||||
|
||||
import dataclasses
|
||||
import os
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
from vllm import LLM, EngineArgs
|
||||
from vllm.model_executor.model_loader import ShardedStateLoader
|
||||
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = FlexibleArgumentParser()
|
||||
EngineArgs.add_cli_args(parser)
|
||||
parser.add_argument(
|
||||
"--output", "-o", required=True, type=str, help="path to output checkpoint"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--file-pattern",
|
||||
type=str,
|
||||
default=ShardedStateLoader.DEFAULT_PATTERN,
|
||||
help="string pattern of saved filenames",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-file-size",
|
||||
type=int,
|
||||
default=5 * 1024**3,
|
||||
help="max size (in bytes) of each safetensors file",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main(args):
|
||||
engine_args = EngineArgs.from_cli_args(args)
|
||||
if engine_args.enable_lora:
|
||||
raise ValueError("Saving with enable_lora=True is not supported!")
|
||||
model_path = engine_args.model
|
||||
if not Path(model_path).is_dir():
|
||||
raise ValueError("model path must be a local directory")
|
||||
# Create LLM instance from arguments
|
||||
llm = LLM(**dataclasses.asdict(engine_args))
|
||||
# Prepare output directory
|
||||
Path(args.output).mkdir(exist_ok=True)
|
||||
# Dump worker states to output directory
|
||||
|
||||
llm.llm_engine.engine_core.save_sharded_state(
|
||||
path=args.output, pattern=args.file_pattern, max_size=args.max_file_size
|
||||
)
|
||||
|
||||
# Copy metadata files to output directory
|
||||
for file in os.listdir(model_path):
|
||||
if os.path.splitext(file)[1] not in (".bin", ".pt", ".safetensors"):
|
||||
if os.path.isdir(os.path.join(model_path, file)):
|
||||
shutil.copytree(
|
||||
os.path.join(model_path, file), os.path.join(args.output, file)
|
||||
)
|
||||
else:
|
||||
shutil.copy(os.path.join(model_path, file), args.output)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
main(args)
52
examples/offline_inference/simple_profiling.py
Normal file
@@ -0,0 +1,52 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import time

from vllm import LLM, SamplingParams

# Sample prompts.
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)


def main():
    # Create an LLM.
    llm = LLM(
        model="facebook/opt-125m",
        tensor_parallel_size=1,
        profiler_config={
            "profiler": "torch",
            "torch_profiler_dir": "./vllm_profile",
        },
    )

    llm.start_profile()

    # Generate texts from the prompts. The output is a list of RequestOutput
    # objects that contain the prompt, generated text, and other information.
    outputs = llm.generate(prompts, sampling_params)

    llm.stop_profile()

    # Print the outputs.
    print("-" * 50)
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
        print("-" * 50)

    # Add a buffer to wait for profiler in the background process
    # (in case MP is on) to finish writing profiling output.
    time.sleep(10)


if __name__ == "__main__":
    main()
@@ -0,0 +1,53 @@
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from vllm import LLM, RequestOutput, SamplingParams
|
||||
|
||||
# Sample prompts.
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
"The president of the United States is",
|
||||
"The capital of France is",
|
||||
"The future of AI is",
|
||||
]
|
||||
# Create a sampling params object.
|
||||
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
|
||||
|
||||
|
||||
def print_prompts_and_outputs(outputs: list[RequestOutput]) -> None:
|
||||
print("-" * 60)
|
||||
for output in outputs:
|
||||
prompt = output.prompt
|
||||
generated_text = output.outputs[0].text
|
||||
print(f"Prompt: {prompt!r}")
|
||||
print(f"Output: {generated_text!r}")
|
||||
print("-" * 60)
|
||||
|
||||
|
||||
def main():
|
||||
# Create an LLM without loading real weights
|
||||
llm = LLM(
|
||||
model="Qwen/Qwen3-0.6B",
|
||||
load_format="dummy",
|
||||
enforce_eager=True,
|
||||
tensor_parallel_size=4,
|
||||
)
|
||||
outputs = llm.generate(prompts, sampling_params)
|
||||
print("\nOutputs do not make sense:")
|
||||
print_prompts_and_outputs(outputs)
|
||||
|
||||
# Update load format from `dummy` to `auto`
|
||||
llm.collective_rpc(
|
||||
"update_config", args=({"load_config": {"load_format": "auto"}},)
|
||||
)
|
||||
# Now reload real weights inplace
|
||||
llm.collective_rpc("reload_weights")
|
||||
|
||||
# Check outputs make sense
|
||||
outputs = llm.generate(prompts, sampling_params)
|
||||
print("\nOutputs make sense after loading real weights:")
|
||||
print_prompts_and_outputs(outputs)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
234
examples/offline_inference/spec_decode.py
Normal file
@@ -0,0 +1,234 @@
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.benchmarks.datasets import add_dataset_parser, get_samples
|
||||
from vllm.inputs import TokensPrompt
|
||||
from vllm.v1.metrics.reader import Counter, Vector
|
||||
|
||||
try:
|
||||
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||
except ImportError:
|
||||
from argparse import ArgumentParser as FlexibleArgumentParser
|
||||
|
||||
|
||||
QUESTION = "What is the content of each image?"
|
||||
IMAGE_URLS = [
|
||||
"https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/duck.jpg",
|
||||
"https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/lion.jpg",
|
||||
"https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/flycatcher.jpeg",
|
||||
"https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/somefish.jpg",
|
||||
"https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/starfish.jpg",
|
||||
"https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/snail.jpg",
|
||||
"https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/thistle.jpg",
|
||||
"https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/husky.jpg",
|
||||
"https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/orangetabbycat.jpg",
|
||||
"https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/guineapig.jpg",
|
||||
"https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/rabbit.jpg",
|
||||
"https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/horsepony.jpg",
|
||||
]
|
||||
|
||||
|
||||
def get_custom_mm_prompts(num_prompts):
|
||||
prompts = []
|
||||
for url in IMAGE_URLS:
|
||||
prompts.append(
|
||||
[
|
||||
{"type": "image_url", "image_url": {"url": url}},
|
||||
{"type": "text", "text": QUESTION},
|
||||
]
|
||||
)
|
||||
if num_prompts > len(IMAGE_URLS):
|
||||
prompts = prompts * (num_prompts // len(IMAGE_URLS) + 1)
|
||||
|
||||
return [[{"role": "user", "content": prompt}] for prompt in prompts[:num_prompts]]
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = FlexibleArgumentParser()
|
||||
add_dataset_parser(parser)
|
||||
parser.add_argument("--test", action="store_true")
|
||||
parser.add_argument(
|
||||
"--method",
|
||||
type=str,
|
||||
default="eagle",
|
||||
choices=["ngram", "eagle", "eagle3", "mtp"],
|
||||
)
|
||||
parser.add_argument("--num-spec-tokens", type=int, default=2)
|
||||
parser.add_argument("--prompt-lookup-max", type=int, default=5)
|
||||
parser.add_argument("--prompt-lookup-min", type=int, default=2)
|
||||
parser.add_argument("--tp", type=int, default=1)
|
||||
parser.add_argument("--enforce-eager", action="store_true")
|
||||
parser.add_argument("--enable-chunked-prefill", action="store_true")
|
||||
parser.add_argument("--max-model-len", type=int, default=16384)
|
||||
parser.add_argument("--temp", type=float, default=0)
|
||||
parser.add_argument("--top-p", type=float, default=1.0)
|
||||
parser.add_argument("--top-k", type=int, default=-1)
|
||||
parser.add_argument("--print-output", action="store_true")
|
||||
parser.add_argument("--output-len", type=int, default=256)
|
||||
parser.add_argument("--model-dir", type=str, default=None)
|
||||
parser.add_argument("--eagle-dir", type=str, default=None)
|
||||
parser.add_argument("--custom-mm-prompts", action="store_true")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main(args):
|
||||
args.endpoint_type = "openai-chat"
|
||||
|
||||
model_dir = args.model_dir
|
||||
if args.model_dir is None:
|
||||
if args.custom_mm_prompts:
|
||||
raise ValueError(
|
||||
"custom_mm_prompts requires mm based models"
|
||||
"default llama3.1-8b-instruct is not mm based"
|
||||
"please specify model_dir to give a mm based model"
|
||||
)
|
||||
model_dir = "meta-llama/Llama-3.1-8B-Instruct"
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_dir)
|
||||
args.custom_skip_chat_template = True
|
||||
|
||||
if not args.custom_mm_prompts:
|
||||
prompts = get_samples(args, tokenizer)
|
||||
# add_special_tokens is False to avoid adding bos twice
|
||||
# when using chat templates
|
||||
prompt_ids = [
|
||||
tokenizer.encode(prompt.prompt, add_special_tokens=False)
|
||||
for prompt in prompts
|
||||
]
|
||||
else:
|
||||
prompts = get_custom_mm_prompts(args.num_prompts)
|
||||
|
||||
if args.method == "eagle" or args.method == "eagle3":
|
||||
eagle_dir = args.eagle_dir
|
||||
if args.method == "eagle" and eagle_dir is None:
|
||||
eagle_dir = "yuhuili/EAGLE-LLaMA3.1-Instruct-8B"
|
||||
|
||||
elif args.method == "eagle3" and eagle_dir is None:
|
||||
eagle_dir = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"
|
||||
speculative_config = {
|
||||
"method": args.method,
|
||||
"model": eagle_dir,
|
||||
"num_speculative_tokens": args.num_spec_tokens,
|
||||
}
|
||||
elif args.method == "ngram":
|
||||
speculative_config = {
|
||||
"method": "ngram",
|
||||
"num_speculative_tokens": args.num_spec_tokens,
|
||||
"prompt_lookup_max": args.prompt_lookup_max,
|
||||
"prompt_lookup_min": args.prompt_lookup_min,
|
||||
}
|
||||
elif args.method == "mtp":
|
||||
speculative_config = {
|
||||
"method": "mtp",
|
||||
"num_speculative_tokens": args.num_spec_tokens,
|
||||
}
|
||||
else:
|
||||
raise ValueError(f"unknown method: {args.method}")
|
||||
|
||||
llm = LLM(
|
||||
model=model_dir,
|
||||
trust_remote_code=True,
|
||||
tensor_parallel_size=args.tp,
|
||||
enable_chunked_prefill=args.enable_chunked_prefill,
|
||||
enforce_eager=args.enforce_eager,
|
||||
gpu_memory_utilization=0.9,
|
||||
speculative_config=speculative_config,
|
||||
disable_log_stats=False,
|
||||
max_model_len=args.max_model_len,
|
||||
limit_mm_per_prompt={"image": 5},
|
||||
disable_chunked_mm_input=True,
|
||||
)
|
||||
|
||||
sampling_params = SamplingParams(temperature=args.temp, max_tokens=args.output_len)
|
||||
if not args.custom_mm_prompts:
|
||||
outputs = llm.generate(
|
||||
[TokensPrompt(prompt_token_ids=x) for x in prompt_ids],
|
||||
sampling_params=sampling_params,
|
||||
)
|
||||
else:
|
||||
outputs = llm.chat(prompts, sampling_params=sampling_params)
|
||||
|
||||
# print the generated text
|
||||
if args.print_output:
|
||||
for output in outputs:
|
||||
print("-" * 50)
|
||||
print(f"prompt: {output.prompt}")
|
||||
print(f"generated text: {output.outputs[0].text}")
|
||||
print("-" * 50)
|
||||
|
||||
metrics = llm.get_metrics()
|
||||
|
||||
total_num_output_tokens = sum(
|
||||
len(output.outputs[0].token_ids) for output in outputs
|
||||
)
|
||||
num_drafts = 0
|
||||
num_draft_tokens = 0
|
||||
num_accepted_tokens = 0
|
||||
acceptance_counts = [0] * args.num_spec_tokens
|
||||
for metric in metrics:
|
||||
if metric.name == "vllm:spec_decode_num_drafts":
|
||||
assert isinstance(metric, Counter)
|
||||
num_drafts += metric.value
|
||||
elif metric.name == "vllm:spec_decode_num_draft_tokens":
|
||||
assert isinstance(metric, Counter)
|
||||
num_draft_tokens += metric.value
|
||||
elif metric.name == "vllm:spec_decode_num_accepted_tokens":
|
||||
assert isinstance(metric, Counter)
|
||||
num_accepted_tokens += metric.value
|
||||
elif metric.name == "vllm:spec_decode_num_accepted_tokens_per_pos":
|
||||
assert isinstance(metric, Vector)
|
||||
for pos in range(len(metric.values)):
|
||||
acceptance_counts[pos] += metric.values[pos]
|
||||
|
||||
print("-" * 50)
|
||||
print(f"total_num_output_tokens: {total_num_output_tokens}")
|
||||
print(f"num_drafts: {num_drafts}")
|
||||
print(f"num_draft_tokens: {num_draft_tokens}")
|
||||
print(f"num_accepted_tokens: {num_accepted_tokens}")
|
||||
acceptance_length = 1 + (num_accepted_tokens / num_drafts) if num_drafts > 0 else 1
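# For example (illustrative numbers, not real measurements): with
# num_drafts=100 and num_accepted_tokens=180, the mean acceptance length is
# 1 + 180 / 100 = 2.8, i.e. the bonus token plus the accepted draft tokens
# emitted per drafting step.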
|
||||
print(f"mean acceptance length: {acceptance_length:.2f}")
|
||||
print("-" * 50)
|
||||
|
||||
# print acceptance at each token position
|
||||
for i in range(len(acceptance_counts)):
|
||||
acceptance_rate = acceptance_counts[i] / num_drafts if num_drafts > 0 else 0
|
||||
print(f"acceptance at token {i}: {acceptance_rate:.2f}")
|
||||
|
||||
return acceptance_length
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
acceptance_length = main(args)
|
||||
|
||||
if args.test:
|
||||
# takes ~30s to run on 1xH100
|
||||
assert args.method in ["eagle", "eagle3"]
|
||||
assert args.tp == 1
|
||||
assert args.num_spec_tokens == 3
|
||||
assert args.dataset_name == "hf"
|
||||
assert args.dataset_path == "philschmid/mt-bench"
|
||||
assert args.num_prompts == 80
|
||||
assert args.temp == 0
|
||||
assert args.top_p == 1.0
|
||||
assert args.top_k == -1
|
||||
assert args.enable_chunked_prefill
|
||||
|
||||
# check acceptance length is within 2% of expected value
|
||||
rtol = 0.02
|
||||
expected_acceptance_length = 2.296 if args.method == "eagle" else 2.811
|
||||
|
||||
assert (
|
||||
acceptance_length <= (1 + rtol) * expected_acceptance_length
|
||||
and acceptance_length >= (1 - rtol) * expected_acceptance_length
|
||||
), (
|
||||
f"acceptance_length {acceptance_length} is not "
|
||||
f"within {rtol * 100}% of {expected_acceptance_length}"
|
||||
)
|
||||
|
||||
print(
|
||||
f"Test passed! Expected AL: "
|
||||
f"{expected_acceptance_length}, got {acceptance_length}"
|
||||
)
113
examples/offline_inference/structured_outputs.py
Normal file
@@ -0,0 +1,113 @@
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
This file demonstrates the example usage of structured outputs
|
||||
in vLLM. It shows how to apply different constraints such as choice,
|
||||
regex, json schema, and grammar to produce structured and formatted
|
||||
results based on specific prompts.
|
||||
"""
|
||||
|
||||
from enum import Enum
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.sampling_params import StructuredOutputsParams
|
||||
|
||||
MAX_TOKENS = 50
|
||||
|
||||
# Structured outputs by Choice (list of possible options)
|
||||
structured_outputs_params_choice = StructuredOutputsParams(
|
||||
choice=["Positive", "Negative"]
|
||||
)
|
||||
sampling_params_choice = SamplingParams(
|
||||
structured_outputs=structured_outputs_params_choice
|
||||
)
|
||||
prompt_choice = "Classify this sentiment: vLLM is wonderful!"
|
||||
|
||||
# Structured outputs by Regex
|
||||
structured_outputs_params_regex = StructuredOutputsParams(regex=r"\w+@\w+\.com\n")
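# This constrains outputs to match the regex, e.g. "aturing@enigma.com\n";
# note that "\w" does not match ".", so dotted local parts such as
# "alan.turing" cannot be generated.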
|
||||
sampling_params_regex = SamplingParams(
|
||||
structured_outputs=structured_outputs_params_regex,
|
||||
stop=["\n"],
|
||||
max_tokens=MAX_TOKENS,
|
||||
)
|
||||
prompt_regex = (
|
||||
"Generate an email address for Alan Turing, who works in Enigma."
|
||||
"End in .com and new line. Example result:"
|
||||
"alan.turing@enigma.com\n"
|
||||
)
|
||||
|
||||
|
||||
# Structured outputs by JSON using Pydantic schema
|
||||
class CarType(str, Enum):
|
||||
sedan = "sedan"
|
||||
suv = "SUV"
|
||||
truck = "Truck"
|
||||
coupe = "Coupe"
|
||||
|
||||
|
||||
class CarDescription(BaseModel):
|
||||
brand: str
|
||||
model: str
|
||||
car_type: CarType
|
||||
|
||||
|
||||
json_schema = CarDescription.model_json_schema()
|
||||
structured_outputs_params_json = StructuredOutputsParams(json=json_schema)
|
||||
sampling_params_json = SamplingParams(
|
||||
structured_outputs=structured_outputs_params_json, max_tokens=MAX_TOKENS
|
||||
)
|
||||
prompt_json = (
|
||||
"Generate a JSON with the brand, model and car_type of "
|
||||
"the most iconic car from the 90's"
|
||||
)
|
||||
|
||||
# Structured outputs by Grammar
|
||||
simplified_sql_grammar = """
|
||||
root ::= select_statement
|
||||
select_statement ::= "SELECT " column " from " table " where " condition
|
||||
column ::= "col_1 " | "col_2 "
|
||||
table ::= "table_1 " | "table_2 "
|
||||
condition ::= column "= " number
|
||||
number ::= "1 " | "2 "
|
||||
"""
|
||||
structured_outputs_params_grammar = StructuredOutputsParams(
|
||||
grammar=simplified_sql_grammar
|
||||
)
|
||||
sampling_params_grammar = SamplingParams(
|
||||
structured_outputs=structured_outputs_params_grammar,
|
||||
max_tokens=MAX_TOKENS,
|
||||
)
|
||||
prompt_grammar = (
|
||||
"Generate an SQL query to show the 'username' and 'email' from the 'users' table."
|
||||
)
|
||||
|
||||
|
||||
def format_output(title: str, output: str):
|
||||
print(f"{'-' * 50}\n{title}: {output}\n{'-' * 50}")
|
||||
|
||||
|
||||
def generate_output(prompt: str, sampling_params: SamplingParams, llm: LLM):
|
||||
outputs = llm.generate(prompt, sampling_params=sampling_params)
|
||||
return outputs[0].outputs[0].text
|
||||
|
||||
|
||||
def main():
|
||||
llm = LLM(model="Qwen/Qwen2.5-3B-Instruct", max_model_len=100)
|
||||
|
||||
choice_output = generate_output(prompt_choice, sampling_params_choice, llm)
|
||||
format_output("Structured outputs by Choice", choice_output)
|
||||
|
||||
regex_output = generate_output(prompt_regex, sampling_params_regex, llm)
|
||||
format_output("Structured outputs by Regex", regex_output)
|
||||
|
||||
json_output = generate_output(prompt_json, sampling_params_json, llm)
|
||||
format_output("Structured outputs by JSON", json_output)
|
||||
|
||||
grammar_output = generate_output(prompt_grammar, sampling_params_grammar, llm)
|
||||
format_output("Structured outputs by Grammar", grammar_output)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
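Since the JSON output is constrained to `CarDescription`'s schema, one natural follow-up (a sketch, not part of the file; it assumes it runs inside `main()` after `json_output` is produced and that the model emitted valid JSON) is to parse the text back into the Pydantic model:

```python
# Illustrative post-processing: validate the schema-constrained text back into
# the Pydantic model defined above (Pydantic v2 API).
car = CarDescription.model_validate_json(json_output)
print(car.brand, car.model, car.car_type.value)
```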
151
examples/offline_inference/torchrun_dp_example.py
Normal file
@@ -0,0 +1,151 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
experimental support for data-parallel inference with torchrun.
Note that data load balancing and distribution are handled outside the vLLM
engine; no internal load balancing is supported in external_launcher mode.

To run this example:
```bash
$ torchrun --nproc-per-node=2 examples/offline_inference/torchrun_dp_example.py
```

With custom parallelism settings:
```bash
$ torchrun --nproc-per-node=8 examples/offline_inference/torchrun_dp_example.py \
    --tp-size=2 --pp-size=1 --dp-size=4 --enable-ep
```
"""

import argparse

from vllm import LLM, SamplingParams


def parse_args():
    parser = argparse.ArgumentParser(
        description="Data-parallel inference with torchrun"
    )
    parser.add_argument(
        "--tp-size",
        type=int,
        default=1,
        help="Tensor parallel size (default: 1)",
    )
    parser.add_argument(
        "--pp-size",
        type=int,
        default=1,
        help="Pipeline parallel size (default: 1)",
    )
    parser.add_argument(
        "--dp-size",
        type=int,
        default=2,
        help="Data parallel size (default: 2)",
    )
    parser.add_argument(
        "--enable-ep",
        action="store_true",
        help="Enable expert parallel (default: False)",
    )
    parser.add_argument(
        "--model",
        type=str,
        default="microsoft/Phi-mini-MoE-instruct",
        help="Model name or path (default: microsoft/Phi-mini-MoE-instruct)",
    )
    parser.add_argument(
        "--max-model-len",
        type=int,
        default=4096,
        help="Maximum model length (default: 4096)",
    )
    parser.add_argument(
        "--gpu-memory-utilization",
        type=float,
        default=0.6,
        help="GPU memory utilization (default: 0.6)",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=1,
        help="Random seed (default: 1)",
    )
    return parser.parse_args()


args = parse_args()


# Create prompts, the same across all ranks
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]

# Create sampling parameters, the same across all ranks
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

# Use `distributed_executor_backend="external_launcher"` so that
# this llm engine/instance only creates one worker.
# it is important to set an explicit seed to make sure that
# all ranks have the same random seed, so that sampling can be
# deterministic across ranks.
llm = LLM(
    model=args.model,
    tensor_parallel_size=args.tp_size,
    data_parallel_size=args.dp_size,
    pipeline_parallel_size=args.pp_size,
    enable_expert_parallel=args.enable_ep,
    distributed_executor_backend="external_launcher",
    max_model_len=args.max_model_len,
    gpu_memory_utilization=args.gpu_memory_utilization,
    seed=args.seed,
)

dp_rank = llm.llm_engine.vllm_config.parallel_config.data_parallel_rank
dp_size = llm.llm_engine.vllm_config.parallel_config.data_parallel_size

prompts = [
    f"{idx}.{prompt}" for idx, prompt in enumerate(prompts) if idx % dp_size == dp_rank
]

outputs = llm.generate(prompts, sampling_params)

for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(
        f"DP Rank: {dp_rank} Prompt: {prompt!r}\nGenerated text: {generated_text!r}\n"
    )

"""
Further tips:

1. to communicate control messages across all ranks, use the cpu group,
a PyTorch ProcessGroup with GLOO backend.

```python
from vllm.distributed.parallel_state import get_world_group
cpu_group = get_world_group().cpu_group
torch_rank = dist.get_rank(group=cpu_group)
if torch_rank == 0:
    # do something for rank 0, e.g. saving the results to disk.
```

2. to communicate data across all ranks, use the model's device group,
a PyTorch ProcessGroup with NCCL backend.
```python
from vllm.distributed.parallel_state import get_world_group
device_group = get_world_group().device_group
```

3. to access the model directly in every rank, use the following code:
```python
llm.llm_engine.model_executor.driver_worker.worker.model_runner.model
```
"""
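To make the modulo-based prompt sharding above concrete, here is a tiny standalone illustration (not part of the file) of how four prompts split across two DP ranks:

```python
# Illustration only: the same round-robin split used in torchrun_dp_example.py.
prompts = ["p0", "p1", "p2", "p3"]
dp_size = 2
for dp_rank in range(dp_size):
    shard = [p for idx, p in enumerate(prompts) if idx % dp_size == dp_rank]
    print(dp_rank, shard)  # rank 0 -> ['p0', 'p2'], rank 1 -> ['p1', 'p3']
```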
76
examples/offline_inference/torchrun_example.py
Normal file
@@ -0,0 +1,76 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
experimental support for tensor-parallel inference with torchrun,
see https://github.com/vllm-project/vllm/issues/11400 for
the motivation and use case for this example.
run the script with `torchrun --nproc-per-node=4 torchrun_example.py`;
the argument 4 should match `tensor_parallel_size * pipeline_parallel_size` below.
see `tests/distributed/test_torchrun_example.py` for the unit test.
"""

import torch.distributed as dist

from vllm import LLM, SamplingParams

# Create prompts, the same across all ranks
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]

# Create sampling parameters, the same across all ranks
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

# Use `distributed_executor_backend="external_launcher"` so that
# this llm engine/instance only creates one worker.
# it is important to set an explicit seed to make sure that
# all ranks have the same random seed, so that sampling can be
# deterministic across ranks.
llm = LLM(
    model="meta-llama/Llama-3.1-8B",
    tensor_parallel_size=2,
    pipeline_parallel_size=2,
    distributed_executor_backend="external_launcher",
    max_model_len=32768,
    seed=1,
)

outputs = llm.generate(prompts, sampling_params)

# all ranks will have the same outputs
if dist.get_rank() == 0:
    print("-" * 50)
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}\n")
        print("-" * 50)
"""
Further tips:

1. to communicate control messages across all ranks, use the cpu group,
a PyTorch ProcessGroup with GLOO backend.

```python
from vllm.distributed.parallel_state import get_world_group
cpu_group = get_world_group().cpu_group
torch_rank = dist.get_rank(group=cpu_group)
if torch_rank == 0:
    # do something for rank 0, e.g. saving the results to disk.
```

2. to communicate data across all ranks, use the model's device group,
a PyTorch ProcessGroup with NCCL backend.
```python
from vllm.distributed.parallel_state import get_world_group
device_group = get_world_group().device_group
```

3. to access the model directly in every rank, use the following code:
```python
llm.llm_engine.model_executor.driver_worker.worker.model_runner.model
```
"""
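Since the comment above notes that all ranks produce the same outputs, a small sanity check one could bolt on (a sketch, not part of the file; it reuses the GLOO cpu group from the tips and assumes `outputs` and `dist` are in scope) is to gather every rank's generated text and compare against rank 0:

```python
# Illustrative check: gather each rank's generated texts over the GLOO cpu
# group and assert that all ranks agree.
from vllm.distributed.parallel_state import get_world_group

cpu_group = get_world_group().cpu_group
texts = [o.outputs[0].text for o in outputs]
gathered = [None] * dist.get_world_size(group=cpu_group)
dist.all_gather_object(gathered, texts, group=cpu_group)
assert all(t == gathered[0] for t in gathered), "ranks produced different outputs"
```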
2243
examples/offline_inference/vision_language.py
Executable file
File diff suppressed because it is too large
1542
examples/offline_inference/vision_language_multi_image.py
Executable file
File diff suppressed because it is too large