First commit

This commit is contained in:
2025-08-05 19:02:46 +08:00
parent 9efe891f99
commit 99fb9f5cb0
1412 changed files with 203615 additions and 0 deletions

View File

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@@ -0,0 +1,163 @@
"""
NOTE: This API server is used only for demonstrating usage of AsyncEngine
and simple performance benchmarks. It is not intended for production use.
For production use, we recommend using our OpenAI compatible server.
We are also not going to accept PRs modifying this file, please
change `vllm/entrypoints/openai/api_server.py` instead.
"""
import asyncio
import json
import ssl
from argparse import Namespace
from typing import Any, AsyncGenerator, Optional
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, Response, StreamingResponse
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.launcher import serve_http
from vllm.logger import init_logger
from vllm.sampling_params import SamplingParams
from vllm.usage.usage_lib import UsageContext
from vllm.utils import (FlexibleArgumentParser, iterate_with_cancellation,
random_uuid)
from vllm.version import __version__ as VLLM_VERSION
logger = init_logger("vllm.entrypoints.api_server")
TIMEOUT_KEEP_ALIVE = 5 # seconds.
app = FastAPI()
engine = None
@app.get("/health")
async def health() -> Response:
"""Health check."""
return Response(status_code=200)
@app.post("/generate")
async def generate(request: Request) -> Response:
"""Generate completion for the request.
The request should be a JSON object with the following fields:
- prompt: the prompt to use for the generation.
- stream: whether to stream the results or not.
- other fields: the sampling parameters (See `SamplingParams` for details).
"""
request_dict = await request.json()
prompt = request_dict.pop("prompt")
stream = request_dict.pop("stream", False)
sampling_params = SamplingParams(**request_dict)
request_id = random_uuid()
assert engine is not None
results_generator = engine.generate(prompt, sampling_params, request_id)
results_generator = iterate_with_cancellation(
results_generator, is_cancelled=request.is_disconnected)
# Streaming case
async def stream_results() -> AsyncGenerator[bytes, None]:
async for request_output in results_generator:
prompt = request_output.prompt
assert prompt is not None
text_outputs = [
prompt + output.text for output in request_output.outputs
]
ret = {"text": text_outputs}
yield (json.dumps(ret) + "\0").encode("utf-8")
if stream:
return StreamingResponse(stream_results())
# Non-streaming case
final_output = None
try:
async for request_output in results_generator:
final_output = request_output
except asyncio.CancelledError:
return Response(status_code=499)
assert final_output is not None
prompt = final_output.prompt
assert prompt is not None
text_outputs = [prompt + output.text for output in final_output.outputs]
ret = {"text": text_outputs}
return JSONResponse(ret)
def build_app(args: Namespace) -> FastAPI:
global app
app.root_path = args.root_path
return app
async def init_app(
args: Namespace,
llm_engine: Optional[AsyncLLMEngine] = None,
) -> FastAPI:
app = build_app(args)
global engine
engine_args = AsyncEngineArgs.from_cli_args(args)
engine = (llm_engine
if llm_engine is not None else AsyncLLMEngine.from_engine_args(
engine_args, usage_context=UsageContext.API_SERVER))
return app
async def run_server(args: Namespace,
llm_engine: Optional[AsyncLLMEngine] = None,
**uvicorn_kwargs: Any) -> None:
logger.info("vLLM API server version %s", VLLM_VERSION)
logger.info("args: %s", args)
app = await init_app(args, llm_engine)
assert engine is not None
shutdown_task = await serve_http(
app,
host=args.host,
port=args.port,
log_level=args.log_level,
timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
ssl_keyfile=args.ssl_keyfile,
ssl_certfile=args.ssl_certfile,
ssl_ca_certs=args.ssl_ca_certs,
ssl_cert_reqs=args.ssl_cert_reqs,
**uvicorn_kwargs,
)
await shutdown_task
if __name__ == "__main__":
parser = FlexibleArgumentParser()
parser.add_argument("--host", type=str, default=None)
parser.add_argument("--port", type=int, default=8000)
parser.add_argument("--ssl-keyfile", type=str, default=None)
parser.add_argument("--ssl-certfile", type=str, default=None)
parser.add_argument("--ssl-ca-certs",
type=str,
default=None,
help="The CA certificates file")
parser.add_argument(
"--ssl-cert-reqs",
type=int,
default=int(ssl.CERT_NONE),
help="Whether client certificate is required (see stdlib ssl module's)"
)
parser.add_argument(
"--root-path",
type=str,
default=None,
help="FastAPI root_path when app is behind a path based routing proxy")
parser.add_argument("--log-level", type=str, default="debug")
parser = AsyncEngineArgs.add_cli_args(parser)
args = parser.parse_args()
asyncio.run(run_server(args))

View File

@@ -0,0 +1,581 @@
import asyncio
import codecs
import json
from abc import ABC, abstractmethod
from collections import defaultdict
from functools import lru_cache, partial
from pathlib import Path
from typing import (Any, Awaitable, Dict, Generic, Iterable, List, Literal,
Mapping, Optional, Tuple, TypeVar, Union, cast)
# yapf conflicts with isort for this block
# yapf: disable
from openai.types.chat import (ChatCompletionAssistantMessageParam,
ChatCompletionContentPartImageParam)
from openai.types.chat import (
ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam)
from openai.types.chat import (ChatCompletionContentPartRefusalParam,
ChatCompletionContentPartTextParam)
from openai.types.chat import (
ChatCompletionMessageParam as OpenAIChatCompletionMessageParam)
from openai.types.chat import (ChatCompletionMessageToolCallParam,
ChatCompletionToolMessageParam)
# yapf: enable
# pydantic needs the TypedDict from typing_extensions
from pydantic import ConfigDict
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from typing_extensions import Required, TypeAlias, TypedDict
from vllm.config import ModelConfig
from vllm.logger import init_logger
from vllm.multimodal import MultiModalDataDict
from vllm.multimodal.utils import (async_get_and_parse_audio,
async_get_and_parse_image,
get_and_parse_audio, get_and_parse_image)
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
logger = init_logger(__name__)
class AudioURL(TypedDict, total=False):
url: Required[str]
"""
Either a URL of the audio or a data URL with base64 encoded audio data.
"""
class ChatCompletionContentPartAudioParam(TypedDict, total=False):
audio_url: Required[AudioURL]
type: Required[Literal["audio_url"]]
"""The type of the content part."""
class CustomChatCompletionContentPartParam(TypedDict, total=False):
__pydantic_config__ = ConfigDict(extra="allow") # type: ignore
type: Required[str]
"""The type of the content part."""
ChatCompletionContentPartParam: TypeAlias = Union[
OpenAIChatCompletionContentPartParam, ChatCompletionContentPartAudioParam,
ChatCompletionContentPartRefusalParam,
CustomChatCompletionContentPartParam]
class CustomChatCompletionMessageParam(TypedDict, total=False):
"""Enables custom roles in the Chat Completion API."""
role: Required[str]
"""The role of the message's author."""
content: Union[str, List[ChatCompletionContentPartParam]]
"""The contents of the message."""
name: str
"""An optional name for the participant.
Provides the model information to differentiate between participants of the
same role.
"""
tool_call_id: Optional[str]
"""Tool call that this message is responding to."""
tool_calls: Optional[Iterable[ChatCompletionMessageToolCallParam]]
"""The tool calls generated by the model, such as function calls."""
ChatCompletionMessageParam = Union[OpenAIChatCompletionMessageParam,
CustomChatCompletionMessageParam]
# TODO: Make fields ReadOnly once mypy supports it
class ConversationMessage(TypedDict, total=False):
role: Required[str]
"""The role of the message's author."""
content: Optional[str]
"""The contents of the message"""
tool_call_id: Optional[str]
"""Tool call that this message is responding to."""
name: Optional[str]
"""The name of the function to call"""
tool_calls: Optional[Iterable[ChatCompletionMessageToolCallParam]]
"""The tool calls generated by the model, such as function calls."""
ModalityStr = Literal["image", "audio", "video"]
_T = TypeVar("_T")
class BaseMultiModalItemTracker(ABC, Generic[_T]):
"""
Tracks multi-modal items in a given request and ensures that the number
of multi-modal items in a given request does not exceed the configured
maximum per prompt.
"""
def __init__(self, model_config: ModelConfig, tokenizer: AnyTokenizer):
super().__init__()
self._model_config = model_config
self._tokenizer = tokenizer
self._allowed_items = (model_config.multimodal_config.limit_per_prompt
if model_config.multimodal_config else {})
self._consumed_items = {k: 0 for k in self._allowed_items}
self._items: List[_T] = []
@staticmethod
@lru_cache(maxsize=None)
def _cached_token_str(tokenizer: AnyTokenizer, token_index: int) -> str:
return tokenizer.decode(token_index)
def _placeholder_str(self, modality: ModalityStr,
current_count: int) -> Optional[str]:
# TODO: Let user specify how to insert image tokens into prompt
# (similar to chat template)
hf_config = self._model_config.hf_config
model_type = hf_config.model_type
if modality == "image":
if model_type == "phi3_v":
# Workaround since this token is not defined in the tokenizer
return f"<|image_{current_count}|>"
if model_type == "minicpmv":
return "(<image>./</image>)"
if model_type in ("blip-2", "chatglm", "fuyu", "paligemma",
"pixtral"):
# These models do not use image tokens in the prompt
return None
if model_type == "qwen":
return f"Picture {current_count}: <img></img>"
if model_type.startswith("llava"):
return self._cached_token_str(self._tokenizer,
hf_config.image_token_index)
if model_type in ("chameleon", "internvl_chat", "NVLM_D"):
return "<image>"
if model_type == "mllama":
return "<|image|>"
if model_type == "qwen2_vl":
return "<|vision_start|><|image_pad|><|vision_end|>"
if model_type == "molmo":
return ""
raise TypeError(f"Unknown model type: {model_type}")
elif modality == "audio":
if model_type == "ultravox":
return "<|reserved_special_token_0|>"
raise TypeError(f"Unknown model type: {model_type}")
elif modality == "video":
if model_type == "qwen2_vl":
return "<|vision_start|><|video_pad|><|vision_end|>"
raise TypeError(f"Unknown model type: {model_type}")
else:
raise TypeError(f"Unknown modality: {modality}")
@staticmethod
def _combine(items: List[MultiModalDataDict]) -> MultiModalDataDict:
mm_lists: Mapping[str, List[object]] = defaultdict(list)
# Merge all the multi-modal items
for single_mm_data in items:
for mm_key, mm_item in single_mm_data.items():
if isinstance(mm_item, list):
mm_lists[mm_key].extend(mm_item)
else:
mm_lists[mm_key].append(mm_item)
# Unpack any single item lists for models that don't expect multiple.
return {
mm_key: mm_list[0] if len(mm_list) == 1 else mm_list
for mm_key, mm_list in mm_lists.items()
}
def add(self, modality: ModalityStr, item: _T) -> Optional[str]:
"""
Add a multi-modal item to the current prompt and returns the
placeholder string to use, if any.
"""
allowed_count = self._allowed_items.get(modality, 1)
current_count = self._consumed_items.get(modality, 0) + 1
if current_count > allowed_count:
raise ValueError(
f"At most {allowed_count} {modality}(s) may be provided in "
"one request.")
self._consumed_items[modality] = current_count
self._items.append(item)
return self._placeholder_str(modality, current_count)
@abstractmethod
def create_parser(self) -> "BaseMultiModalContentParser":
raise NotImplementedError
class MultiModalItemTracker(BaseMultiModalItemTracker[MultiModalDataDict]):
def all_mm_data(self) -> Optional[MultiModalDataDict]:
return self._combine(self._items) if self._items else None
def create_parser(self) -> "BaseMultiModalContentParser":
return MultiModalContentParser(self)
class AsyncMultiModalItemTracker(
BaseMultiModalItemTracker[Awaitable[MultiModalDataDict]]):
async def all_mm_data(self) -> Optional[MultiModalDataDict]:
if self._items:
items = await asyncio.gather(*self._items)
return self._combine(items)
return None
def create_parser(self) -> "BaseMultiModalContentParser":
return AsyncMultiModalContentParser(self)
class BaseMultiModalContentParser(ABC):
def __init__(self) -> None:
super().__init__()
# multimodal placeholder_string : count
self._placeholder_counts: Dict[str, int] = defaultdict(lambda: 0)
def _add_placeholder(self, placeholder: Optional[str]):
if placeholder:
self._placeholder_counts[placeholder] += 1
def mm_placeholder_counts(self) -> Dict[str, int]:
return dict(self._placeholder_counts)
@abstractmethod
def parse_image(self, image_url: str) -> None:
raise NotImplementedError
@abstractmethod
def parse_audio(self, audio_url: str) -> None:
raise NotImplementedError
class MultiModalContentParser(BaseMultiModalContentParser):
def __init__(self, tracker: MultiModalItemTracker) -> None:
super().__init__()
self._tracker = tracker
def parse_image(self, image_url: str) -> None:
image = get_and_parse_image(image_url)
placeholder = self._tracker.add("image", image)
self._add_placeholder(placeholder)
def parse_audio(self, audio_url: str) -> None:
audio = get_and_parse_audio(audio_url)
placeholder = self._tracker.add("audio", audio)
self._add_placeholder(placeholder)
class AsyncMultiModalContentParser(BaseMultiModalContentParser):
def __init__(self, tracker: AsyncMultiModalItemTracker) -> None:
super().__init__()
self._tracker = tracker
def parse_image(self, image_url: str) -> None:
image_coro = async_get_and_parse_image(image_url)
placeholder = self._tracker.add("image", image_coro)
self._add_placeholder(placeholder)
def parse_audio(self, audio_url: str) -> None:
audio_coro = async_get_and_parse_audio(audio_url)
placeholder = self._tracker.add("audio", audio_coro)
self._add_placeholder(placeholder)
def validate_chat_template(chat_template: Optional[Union[Path, str]]):
"""Raises if the provided chat template appears invalid."""
if chat_template is None:
return
elif isinstance(chat_template, Path) and not chat_template.exists():
raise FileNotFoundError(
"the supplied chat template path doesn't exist")
elif isinstance(chat_template, str):
JINJA_CHARS = "{}\n"
if not any(c in chat_template
for c in JINJA_CHARS) and not Path(chat_template).exists():
raise ValueError(
f"The supplied chat template string ({chat_template}) "
f"appears path-like, but doesn't exist!")
else:
raise TypeError(
f"{type(chat_template)} is not a valid chat template type")
def load_chat_template(
chat_template: Optional[Union[Path, str]]) -> Optional[str]:
if chat_template is None:
return None
try:
with open(chat_template, "r") as f:
resolved_chat_template = f.read()
except OSError as e:
if isinstance(chat_template, Path):
raise
JINJA_CHARS = "{}\n"
if not any(c in chat_template for c in JINJA_CHARS):
msg = (f"The supplied chat template ({chat_template}) "
f"looks like a file path, but it failed to be "
f"opened. Reason: {e}")
raise ValueError(msg) from e
# If opening a file fails, set chat template to be args to
# ensure we decode so our escape are interpreted correctly
resolved_chat_template = codecs.decode(chat_template, "unicode_escape")
logger.info("Using supplied chat template:\n%s", resolved_chat_template)
return resolved_chat_template
# TODO: Let user specify how to insert multimodal tokens into prompt
# (similar to chat template)
def _get_full_multimodal_text_prompt(placeholder_counts: Dict[str, int],
text_prompt: str) -> str:
"""Combine multimodal prompts for a multimodal language model."""
# Look through the text prompt to check for missing placeholders
missing_placeholders: List[str] = []
for placeholder in placeholder_counts:
# For any existing placeholder in the text prompt, we leave it as is
placeholder_counts[placeholder] -= text_prompt.count(placeholder)
if placeholder_counts[placeholder] < 0:
raise ValueError(
f"Found more '{placeholder}' placeholders in input prompt than "
"actual multimodal data items.")
missing_placeholders.extend([placeholder] *
placeholder_counts[placeholder])
# NOTE: For now we always add missing placeholders at the front of
# the prompt. This may change to be customizable in the future.
return "\n".join(missing_placeholders + [text_prompt])
# No need to validate using Pydantic again
_TextParser = partial(cast, ChatCompletionContentPartTextParam)
_ImageParser = partial(cast, ChatCompletionContentPartImageParam)
_AudioParser = partial(cast, ChatCompletionContentPartAudioParam)
_RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam)
MODEL_KEEP_MULTI_MODAL_CONTENT = {'mllama'}
def _parse_chat_message_content_parts(
role: str,
parts: Iterable[ChatCompletionContentPartParam],
mm_tracker: BaseMultiModalItemTracker,
) -> List[ConversationMessage]:
texts: List[str] = []
mm_parser = mm_tracker.create_parser()
keep_multimodal_content = \
mm_tracker._model_config.hf_config.model_type in \
MODEL_KEEP_MULTI_MODAL_CONTENT
has_image = False
for part in parts:
part_type = part["type"]
if part_type == "text":
text = _TextParser(part)["text"]
texts.append(text)
elif part_type == "image_url":
image_url = _ImageParser(part)["image_url"]
if image_url.get("detail", "auto") != "auto":
logger.warning(
"'image_url.detail' is currently not supported and "
"will be ignored.")
mm_parser.parse_image(image_url["url"])
has_image = True
elif part_type == "audio_url":
audio_url = _AudioParser(part)["audio_url"]
mm_parser.parse_audio(audio_url["url"])
elif part_type == "refusal":
text = _RefusalParser(part)["refusal"]
texts.append(text)
else:
raise NotImplementedError(f"Unknown part type: {part_type}")
text_prompt = "\n".join(texts)
if keep_multimodal_content:
text_prompt = "\n".join(texts)
role_content = [{'type': 'text', 'text': text_prompt}]
if has_image:
role_content = [{'type': 'image'}] + role_content
return [ConversationMessage(role=role,
content=role_content)] # type: ignore
else:
mm_placeholder_counts = mm_parser.mm_placeholder_counts()
if mm_placeholder_counts:
text_prompt = _get_full_multimodal_text_prompt(
mm_placeholder_counts, text_prompt)
return [ConversationMessage(role=role, content=text_prompt)]
# No need to validate using Pydantic again
_AssistantParser = partial(cast, ChatCompletionAssistantMessageParam)
_ToolParser = partial(cast, ChatCompletionToolMessageParam)
def _parse_chat_message_content(
message: ChatCompletionMessageParam,
mm_tracker: BaseMultiModalItemTracker,
) -> List[ConversationMessage]:
role = message["role"]
content = message.get("content")
if content is None:
content = []
elif isinstance(content, str):
content = [
ChatCompletionContentPartTextParam(type="text", text=content)
]
result = _parse_chat_message_content_parts(
role,
content, # type: ignore
mm_tracker,
)
for result_msg in result:
if role == 'assistant':
parsed_msg = _AssistantParser(message)
if "tool_calls" in parsed_msg:
result_msg["tool_calls"] = list(parsed_msg["tool_calls"])
elif role == "tool":
parsed_msg = _ToolParser(message)
if "tool_call_id" in parsed_msg:
result_msg["tool_call_id"] = parsed_msg["tool_call_id"]
if "name" in message and isinstance(message["name"], str):
result_msg["name"] = message["name"]
return result
def _postprocess_messages(messages: List[ConversationMessage]) -> None:
# per the Transformers docs & maintainers, tool call arguments in
# assistant-role messages with tool_calls need to be dicts not JSON str -
# this is how tool-use chat templates will expect them moving forwards
# so, for messages that have tool_calls, parse the string (which we get
# from openAI format) to dict
for message in messages:
if (message["role"] == "assistant" and "tool_calls" in message
and isinstance(message["tool_calls"], list)):
for item in message["tool_calls"]:
item["function"]["arguments"] = json.loads(
item["function"]["arguments"])
def parse_chat_messages(
messages: List[ChatCompletionMessageParam],
model_config: ModelConfig,
tokenizer: AnyTokenizer,
) -> Tuple[List[ConversationMessage], Optional[MultiModalDataDict]]:
conversation: List[ConversationMessage] = []
mm_tracker = MultiModalItemTracker(model_config, tokenizer)
for msg in messages:
sub_messages = _parse_chat_message_content(msg, mm_tracker)
conversation.extend(sub_messages)
_postprocess_messages(conversation)
return conversation, mm_tracker.all_mm_data()
def parse_chat_messages_futures(
messages: List[ChatCompletionMessageParam],
model_config: ModelConfig,
tokenizer: AnyTokenizer,
) -> Tuple[List[ConversationMessage], Awaitable[Optional[MultiModalDataDict]]]:
conversation: List[ConversationMessage] = []
mm_tracker = AsyncMultiModalItemTracker(model_config, tokenizer)
for msg in messages:
sub_messages = _parse_chat_message_content(msg, mm_tracker)
conversation.extend(sub_messages)
_postprocess_messages(conversation)
return conversation, mm_tracker.all_mm_data()
def apply_hf_chat_template(
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
conversation: List[ConversationMessage],
chat_template: Optional[str],
*,
tokenize: bool = False, # Different from HF's default
**kwargs: Any,
) -> str:
if chat_template is None and tokenizer.chat_template is None:
raise ValueError(
"As of transformers v4.44, default chat template is no longer "
"allowed, so you must provide a chat template if the tokenizer "
"does not define one.")
return tokenizer.apply_chat_template(
conversation=conversation, # type: ignore[arg-type]
chat_template=chat_template,
tokenize=tokenize,
**kwargs,
)
def apply_mistral_chat_template(
tokenizer: MistralTokenizer,
messages: List[ChatCompletionMessageParam],
chat_template: Optional[str] = None,
**kwargs: Any,
) -> List[int]:
if chat_template is not None:
logger.warning(
"'chat_template' cannot be overridden for mistral tokenizer.")
if "add_generation_prompt" in kwargs:
logger.warning(
"'add_generation_prompt' is not supported for mistral tokenizer, "
"so it will be ignored.")
if "continue_final_message" in kwargs:
logger.warning(
"'continue_final_message' is not supported for mistral tokenizer, "
"so it will be ignored.")
return tokenizer.apply_chat_template(
messages=messages,
**kwargs,
)

View File

@@ -0,0 +1,103 @@
import asyncio
import signal
from http import HTTPStatus
from typing import Any
import uvicorn
from fastapi import FastAPI, Request, Response
from vllm import envs
from vllm.engine.async_llm_engine import AsyncEngineDeadError
from vllm.engine.multiprocessing import MQEngineDeadError
from vllm.logger import init_logger
from vllm.utils import find_process_using_port
logger = init_logger(__name__)
async def serve_http(app: FastAPI, **uvicorn_kwargs: Any):
logger.info("Available routes are:")
for route in app.routes:
methods = getattr(route, "methods", None)
path = getattr(route, "path", None)
if methods is None or path is None:
continue
logger.info("Route: %s, Methods: %s", path, ', '.join(methods))
config = uvicorn.Config(app, **uvicorn_kwargs)
server = uvicorn.Server(config)
_add_shutdown_handlers(app, server)
loop = asyncio.get_running_loop()
server_task = loop.create_task(server.serve())
def signal_handler() -> None:
# prevents the uvicorn signal handler to exit early
server_task.cancel()
async def dummy_shutdown() -> None:
pass
loop.add_signal_handler(signal.SIGINT, signal_handler)
loop.add_signal_handler(signal.SIGTERM, signal_handler)
try:
await server_task
return dummy_shutdown()
except asyncio.CancelledError:
port = uvicorn_kwargs["port"]
process = find_process_using_port(port)
if process is not None:
logger.debug(
"port %s is used by process %s launched with command:\n%s",
port, process, " ".join(process.cmdline()))
logger.info("Shutting down FastAPI HTTP server.")
return server.shutdown()
def _add_shutdown_handlers(app: FastAPI, server: uvicorn.Server) -> None:
"""Adds handlers for fatal errors that should crash the server"""
@app.exception_handler(RuntimeError)
async def runtime_error_handler(request: Request, __):
"""On generic runtime error, check to see if the engine has died.
It probably has, in which case the server will no longer be able to
handle requests. Trigger a graceful shutdown with a SIGTERM."""
engine = request.app.state.engine_client
if (not envs.VLLM_KEEP_ALIVE_ON_ENGINE_DEATH and engine.errored
and not engine.is_running):
logger.fatal("AsyncLLMEngine has failed, terminating server "
"process")
# See discussions here on shutting down a uvicorn server
# https://github.com/encode/uvicorn/discussions/1103
# In this case we cannot await the server shutdown here because
# this handler must first return to close the connection for
# this request.
server.should_exit = True
return Response(status_code=HTTPStatus.INTERNAL_SERVER_ERROR)
@app.exception_handler(AsyncEngineDeadError)
async def async_engine_dead_handler(_, __):
"""Kill the server if the async engine is already dead. It will
not handle any further requests."""
if not envs.VLLM_KEEP_ALIVE_ON_ENGINE_DEATH:
logger.fatal("AsyncLLMEngine is already dead, terminating server "
"process")
server.should_exit = True
return Response(status_code=HTTPStatus.INTERNAL_SERVER_ERROR)
@app.exception_handler(MQEngineDeadError)
async def mq_engine_dead_handler(_, __):
"""Kill the server if the mq engine is already dead. It will
not handle any further requests."""
if not envs.VLLM_KEEP_ALIVE_ON_ENGINE_DEATH:
logger.fatal("MQLLMEngine is already dead, terminating server "
"process")
server.should_exit = True
return Response(status_code=HTTPStatus.INTERNAL_SERVER_ERROR)

909
vllm/entrypoints/llm.py Normal file
View File

@@ -0,0 +1,909 @@
import itertools
import warnings
from contextlib import contextmanager
from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Tuple,
Union, cast, overload)
from tqdm import tqdm
from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
BeamSearchSequence, get_beam_search_score)
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
apply_hf_chat_template,
apply_mistral_chat_template,
parse_chat_messages)
from vllm.inputs import PromptType, TextPrompt, TokensPrompt
from vllm.inputs.parse import parse_and_batch_prompt
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.guided_decoding.guided_fields import (
GuidedDecodingRequest, LLMGuidedOptions)
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
RequestOutputKind, SamplingParams)
from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
get_cached_tokenizer)
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
from vllm.usage.usage_lib import UsageContext
from vllm.utils import Counter, deprecate_kwargs, is_list_of
logger = init_logger(__name__)
class LLM:
"""An LLM for generating texts from given prompts and sampling parameters.
This class includes a tokenizer, a language model (possibly distributed
across multiple GPUs), and GPU memory space allocated for intermediate
states (aka KV cache). Given a batch of prompts and sampling parameters,
this class generates texts from the model, using an intelligent batching
mechanism and efficient memory management.
Args:
model: The name or path of a HuggingFace Transformers model.
tokenizer: The name or path of a HuggingFace Transformers tokenizer.
tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
if available, and "slow" will always use the slow tokenizer.
skip_tokenizer_init: If true, skip initialization of tokenizer and
detokenizer. Expect valid prompt_token_ids and None for prompt
from the input.
trust_remote_code: Trust remote code (e.g., from HuggingFace) when
downloading the model and tokenizer.
tensor_parallel_size: The number of GPUs to use for distributed
execution with tensor parallelism.
dtype: The data type for the model weights and activations. Currently,
we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
the `torch_dtype` attribute specified in the model config file.
However, if the `torch_dtype` in the config is `float32`, we will
use `float16` instead.
quantization: The method used to quantize the model weights. Currently,
we support "awq", "gptq", and "fp8" (experimental).
If None, we first check the `quantization_config` attribute in the
model config file. If that is None, we assume the model weights are
not quantized and use `dtype` to determine the data type of
the weights.
revision: The specific model version to use. It can be a branch name,
a tag name, or a commit id.
tokenizer_revision: The specific tokenizer version to use. It can be a
branch name, a tag name, or a commit id.
seed: The seed to initialize the random number generator for sampling.
gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
reserve for the model weights, activations, and KV cache. Higher
values will increase the KV cache size and thus improve the model's
throughput. However, if the value is too high, it may cause out-of-
memory (OOM) errors.
swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
This can be used for temporarily storing the states of the requests
when their `best_of` sampling parameters are larger than 1. If all
requests will have `best_of=1`, you can safely set this to 0.
Otherwise, too small values may cause out-of-memory (OOM) errors.
cpu_offload_gb: The size (GiB) of CPU memory to use for offloading
the model weights. This virtually increases the GPU memory space
you can use to hold the model weights, at the cost of CPU-GPU data
transfer for every forward pass.
enforce_eager: Whether to enforce eager execution. If True, we will
disable CUDA graph and always execute the model in eager mode.
If False, we will use CUDA graph and eager execution in hybrid.
max_context_len_to_capture: Maximum context len covered by CUDA graphs.
When a sequence has context length larger than this, we fall back
to eager mode (DEPRECATED. Use `max_seq_len_to_capture` instead).
max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
When a sequence has context length larger than this, we fall back
to eager mode. Additionally for encoder-decoder models, if the
sequence length of the encoder input is larger than this, we fall
back to the eager mode.
disable_custom_all_reduce: See ParallelConfig
**kwargs: Arguments for :class:`~vllm.EngineArgs`. (See
:ref:`engine_args`)
Note:
This class is intended to be used for offline inference. For online
serving, use the :class:`~vllm.AsyncLLMEngine` class instead.
"""
DEPRECATE_LEGACY: ClassVar[bool] = False
"""A flag to toggle whether to deprecate the legacy generate/encode API."""
@classmethod
@contextmanager
def deprecate_legacy_api(cls):
cls.DEPRECATE_LEGACY = True
yield
cls.DEPRECATE_LEGACY = False
def __init__(
self,
model: str,
tokenizer: Optional[str] = None,
tokenizer_mode: str = "auto",
skip_tokenizer_init: bool = False,
trust_remote_code: bool = False,
tensor_parallel_size: int = 1,
dtype: str = "auto",
quantization: Optional[str] = None,
revision: Optional[str] = None,
tokenizer_revision: Optional[str] = None,
seed: int = 0,
gpu_memory_utilization: float = 0.9,
swap_space: float = 4,
cpu_offload_gb: float = 0,
enforce_eager: Optional[bool] = None,
max_context_len_to_capture: Optional[int] = None,
max_seq_len_to_capture: int = 8192,
disable_custom_all_reduce: bool = False,
disable_async_output_proc: bool = False,
mm_processor_kwargs: Optional[Dict[str, Any]] = None,
**kwargs,
) -> None:
'''
LLM constructor.
Note: if enforce_eager is unset (enforce_eager is None)
it defaults to False.
'''
if "disable_log_stats" not in kwargs:
kwargs["disable_log_stats"] = True
engine_args = EngineArgs(
model=model,
tokenizer=tokenizer,
tokenizer_mode=tokenizer_mode,
skip_tokenizer_init=skip_tokenizer_init,
trust_remote_code=trust_remote_code,
tensor_parallel_size=tensor_parallel_size,
dtype=dtype,
quantization=quantization,
revision=revision,
tokenizer_revision=tokenizer_revision,
seed=seed,
gpu_memory_utilization=gpu_memory_utilization,
swap_space=swap_space,
cpu_offload_gb=cpu_offload_gb,
enforce_eager=enforce_eager,
max_context_len_to_capture=max_context_len_to_capture,
max_seq_len_to_capture=max_seq_len_to_capture,
disable_custom_all_reduce=disable_custom_all_reduce,
disable_async_output_proc=disable_async_output_proc,
mm_processor_kwargs=mm_processor_kwargs,
**kwargs,
)
self.llm_engine = LLMEngine.from_engine_args(
engine_args, usage_context=UsageContext.LLM_CLASS)
self.request_counter = Counter()
def get_tokenizer(self) -> AnyTokenizer:
return self.llm_engine.get_tokenizer_group(TokenizerGroup).tokenizer
def set_tokenizer(self, tokenizer: AnyTokenizer) -> None:
tokenizer_group = self.llm_engine.get_tokenizer_group(TokenizerGroup)
# While CachedTokenizer is dynamic, have no choice but
# compare class name. Misjudgment will arise from
# user-defined tokenizer started with 'Cached'
if tokenizer.__class__.__name__.startswith("Cached"):
tokenizer_group.tokenizer = tokenizer
else:
tokenizer_group.tokenizer = get_cached_tokenizer(tokenizer)
@overload # LEGACY: single (prompt + optional token ids)
def generate(
self,
prompts: str,
sampling_params: Optional[Union[SamplingParams,
List[SamplingParams]]] = None,
prompt_token_ids: Optional[List[int]] = None,
use_tqdm: bool = True,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
) -> List[RequestOutput]:
...
@overload # LEGACY: multi (prompt + optional token ids)
def generate(
self,
prompts: List[str],
sampling_params: Optional[Union[SamplingParams,
List[SamplingParams]]] = None,
prompt_token_ids: Optional[List[List[int]]] = None,
use_tqdm: bool = True,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
) -> List[RequestOutput]:
...
@overload # LEGACY: single (token ids + optional prompt)
def generate(
self,
prompts: Optional[str] = None,
sampling_params: Optional[Union[SamplingParams,
List[SamplingParams]]] = None,
*,
prompt_token_ids: List[int],
use_tqdm: bool = True,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
) -> List[RequestOutput]:
...
@overload # LEGACY: multi (token ids + optional prompt)
def generate(
self,
prompts: Optional[List[str]] = None,
sampling_params: Optional[Union[SamplingParams,
List[SamplingParams]]] = None,
*,
prompt_token_ids: List[List[int]],
use_tqdm: bool = True,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
) -> List[RequestOutput]:
...
@overload # LEGACY: single or multi token ids [pos-only]
def generate(
self,
prompts: None,
sampling_params: None,
prompt_token_ids: Union[List[int], List[List[int]]],
use_tqdm: bool = True,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
) -> List[RequestOutput]:
...
@overload
def generate(
self,
prompts: Union[PromptType, Sequence[PromptType]],
/,
*,
sampling_params: Optional[Union[SamplingParams,
Sequence[SamplingParams]]] = None,
use_tqdm: bool = True,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
) -> List[RequestOutput]:
...
@deprecate_kwargs(
"prompt_token_ids",
is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
additional_message="Please use the 'prompts' parameter instead.",
)
def generate(
self,
prompts: Union[Union[PromptType, Sequence[PromptType]],
Optional[Union[str, List[str]]]] = None,
sampling_params: Optional[Union[SamplingParams,
Sequence[SamplingParams]]] = None,
prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None,
use_tqdm: bool = True,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
guided_options_request: Optional[Union[LLMGuidedOptions,
GuidedDecodingRequest]] = None,
priority: Optional[List[int]] = None,
) -> List[RequestOutput]:
"""Generates the completions for the input prompts.
This class automatically batches the given prompts, considering
the memory constraint. For the best performance, put all of your prompts
into a single list and pass it to this method.
Args:
prompts: The prompts to the LLM. You may pass a sequence of prompts
for batch inference. See :class:`~vllm.inputs.PromptType`
for more details about the format of each prompts.
sampling_params: The sampling parameters for text generation. If
None, we use the default sampling parameters.
When it is a single value, it is applied to every prompt.
When it is a list, the list must have the same length as the
prompts and it is paired one by one with the prompt.
use_tqdm: Whether to use tqdm to display the progress bar.
lora_request: LoRA request to use for generation, if any.
prompt_adapter_request: Prompt Adapter request to use for
generation, if any.
priority: The priority of the requests, if any.
Only applicable when priority scheduling policy is enabled.
Returns:
A list of ``RequestOutput`` objects containing the
generated completions in the same order as the input prompts.
Note:
Using ``prompts`` and ``prompt_token_ids`` as keyword parameters is
considered legacy and may be deprecated in the future. You should
instead pass them via the ``inputs`` parameter.
"""
if self.llm_engine.model_config.embedding_mode:
raise ValueError(
"LLM.generate() is only supported for (conditional) generation "
"models (XForCausalLM, XForConditionalGeneration).")
if prompt_token_ids is not None:
parsed_prompts = self._convert_v1_inputs(
prompts=cast(Optional[Union[str, List[str]]], prompts),
prompt_token_ids=prompt_token_ids,
)
else:
parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
prompts)
if isinstance(guided_options_request, dict):
if len(guided_options_request) > 1:
raise ValueError(
"You can only use one guided decoding but multiple is "
f"specified: {guided_options_request}")
guided_options_request = GuidedDecodingRequest(
**guided_options_request)
if sampling_params is None:
# Use default sampling params.
sampling_params = SamplingParams()
self._validate_and_add_requests(
prompts=parsed_prompts,
params=sampling_params,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
guided_options=guided_options_request,
priority=priority)
outputs = self._run_engine(use_tqdm=use_tqdm)
return LLMEngine.validate_outputs(outputs, RequestOutput)
def beam_search(
self,
prompts: List[Union[str, List[int]]],
params: BeamSearchParams,
) -> List[BeamSearchOutput]:
"""
Generate sequences using beam search.
Args:
prompts: A list of prompts. Each prompt can be a string or a list
of token IDs.
params: The beam search parameters.
TODO: how does beam search work together with length penalty, frequency
penalty, and stopping criteria, etc.?
"""
beam_width = params.beam_width
max_tokens = params.max_tokens
temperature = params.temperature
ignore_eos = params.ignore_eos
length_penalty = params.length_penalty
def sort_beams_key(x: BeamSearchSequence) -> float:
return get_beam_search_score(x.tokens, x.cum_logprob,
tokenizer.eos_token_id,
length_penalty)
tokenizer = self.get_tokenizer()
# generate 2 * beam_width candidates at each step
# following the huggingface transformers implementation
# at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa
beam_search_params = SamplingParams(logprobs=2 * beam_width,
max_tokens=1,
temperature=temperature)
instances: List[BeamSearchInstance] = []
for prompt in prompts:
prompt_tokens = prompt if isinstance(
prompt, list) else tokenizer.encode(prompt)
instances.append(BeamSearchInstance(prompt_tokens))
for _ in range(max_tokens):
all_beams: List[BeamSearchSequence] = list(
sum((instance.beams for instance in instances), []))
pos = [0] + list(
itertools.accumulate(
len(instance.beams) for instance in instances))
instance_start_and_end: List[Tuple[int, int]] = list(
zip(pos[:-1], pos[1:]))
if len(all_beams) == 0:
break
prompts_batch = [
TokensPrompt(prompt_token_ids=beam.tokens)
for beam in all_beams
]
# only runs for one step
# we don't need to use tqdm here
output = self.generate(prompts_batch,
sampling_params=beam_search_params,
use_tqdm=False)
for (start, end), instance in zip(instance_start_and_end,
instances):
instance_new_beams = []
for i in range(start, end):
current_beam = all_beams[i]
result = output[i]
if result.outputs[0].logprobs is not None:
# if `result.outputs[0].logprobs` is None, it means
# the sequence is completed because of the max-model-len
# or abortion. we don't need to add it to the new beams.
logprobs = result.outputs[0].logprobs[0]
for token_id, logprob_obj in logprobs.items():
new_beam = BeamSearchSequence(
tokens=current_beam.tokens + [token_id],
cum_logprob=current_beam.cum_logprob +
logprob_obj.logprob)
if token_id == tokenizer.eos_token_id and \
not ignore_eos:
instance.completed.append(new_beam)
else:
instance_new_beams.append(new_beam)
sorted_beams = sorted(instance_new_beams,
key=sort_beams_key,
reverse=True)
instance.beams = sorted_beams[:beam_width]
outputs = []
for instance in instances:
instance.completed.extend(instance.beams)
sorted_completed = sorted(instance.completed,
key=sort_beams_key,
reverse=True)
best_beams = sorted_completed[:beam_width]
for beam in best_beams:
beam.text = tokenizer.decode(beam.tokens)
outputs.append(BeamSearchOutput(sequences=best_beams))
return outputs
def chat(
self,
messages: Union[List[ChatCompletionMessageParam],
List[List[ChatCompletionMessageParam]]],
sampling_params: Optional[Union[SamplingParams,
List[SamplingParams]]] = None,
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
chat_template: Optional[str] = None,
add_generation_prompt: bool = True,
continue_final_message: bool = False,
tools: Optional[List[Dict[str, Any]]] = None,
mm_processor_kwargs: Optional[Dict[str, Any]] = None,
) -> List[RequestOutput]:
"""
Generate responses for a chat conversation.
The chat conversation is converted into a text prompt using the
tokenizer and calls the :meth:`generate` method to generate the
responses.
Multi-modal inputs can be passed in the same way you would pass them
to the OpenAI API.
Args:
messages: A list of conversations or a single conversation.
- Each conversation is represented as a list of messages.
- Each message is a dictionary with 'role' and 'content' keys.
sampling_params: The sampling parameters for text generation.
If None, we use the default sampling parameters. When it
is a single value, it is applied to every prompt. When it
is a list, the list must have the same length as the
prompts and it is paired one by one with the prompt.
use_tqdm: Whether to use tqdm to display the progress bar.
lora_request: LoRA request to use for generation, if any.
chat_template: The template to use for structuring the chat.
If not provided, the model's default chat template will be used.
add_generation_prompt: If True, adds a generation template
to each message.
continue_final_message: If True, continues the final message in
the conversation instead of starting a new one. Cannot be `True`
if `add_generation_prompt` is also `True`.
mm_processor_kwargs: Multimodal processor kwarg overrides for this
chat request. Only used for offline requests.
Returns:
A list of ``RequestOutput`` objects containing the generated
responses in the same order as the input messages.
"""
list_of_messages: List[List[ChatCompletionMessageParam]]
# Handle multi and single conversations
if is_list_of(messages, list):
# messages is List[List[...]]
list_of_messages = cast(List[List[ChatCompletionMessageParam]],
messages)
else:
# messages is List[...]
list_of_messages = [
cast(List[ChatCompletionMessageParam], messages)
]
prompts: List[Union[TokensPrompt, TextPrompt]] = []
for msgs in list_of_messages:
tokenizer = self.get_tokenizer()
model_config = self.llm_engine.get_model_config()
# NOTE: _parse_chat_message_content_parts() currently doesn't
# handle mm_processor_kwargs, since there is no implementation in
# the chat message parsing for it.
conversation, mm_data = parse_chat_messages(
msgs, model_config, tokenizer)
prompt_data: Union[str, List[int]]
if isinstance(tokenizer, MistralTokenizer):
prompt_data = apply_mistral_chat_template(
tokenizer,
messages=msgs,
chat_template=chat_template,
add_generation_prompt=add_generation_prompt,
continue_final_message=continue_final_message,
tools=tools,
)
else:
prompt_data = apply_hf_chat_template(
tokenizer,
conversation=conversation,
chat_template=chat_template,
add_generation_prompt=add_generation_prompt,
continue_final_message=continue_final_message,
tools=tools,
)
prompt: Union[TokensPrompt, TextPrompt]
if is_list_of(prompt_data, int):
prompt = TokensPrompt(prompt_token_ids=prompt_data)
else:
prompt = TextPrompt(prompt=prompt_data)
if mm_data is not None:
prompt["multi_modal_data"] = mm_data
if mm_processor_kwargs is not None:
prompt["mm_processor_kwargs"] = mm_processor_kwargs
prompts.append(prompt)
return self.generate(
prompts,
sampling_params=sampling_params,
use_tqdm=use_tqdm,
lora_request=lora_request,
)
@overload # LEGACY: single (prompt + optional token ids)
def encode(
self,
prompts: str,
pooling_params: Optional[Union[PoolingParams,
Sequence[PoolingParams]]] = None,
prompt_token_ids: Optional[List[int]] = None,
use_tqdm: bool = True,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
) -> List[EmbeddingRequestOutput]:
...
@overload # LEGACY: multi (prompt + optional token ids)
def encode(
self,
prompts: List[str],
pooling_params: Optional[Union[PoolingParams,
Sequence[PoolingParams]]] = None,
prompt_token_ids: Optional[List[List[int]]] = None,
use_tqdm: bool = True,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
) -> List[EmbeddingRequestOutput]:
...
@overload # LEGACY: single (token ids + optional prompt)
def encode(
self,
prompts: Optional[str] = None,
pooling_params: Optional[Union[PoolingParams,
Sequence[PoolingParams]]] = None,
*,
prompt_token_ids: List[int],
use_tqdm: bool = True,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
) -> List[EmbeddingRequestOutput]:
...
@overload # LEGACY: multi (token ids + optional prompt)
def encode(
self,
prompts: Optional[List[str]] = None,
pooling_params: Optional[Union[PoolingParams,
Sequence[PoolingParams]]] = None,
*,
prompt_token_ids: List[List[int]],
use_tqdm: bool = True,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
) -> List[EmbeddingRequestOutput]:
...
@overload # LEGACY: single or multi token ids [pos-only]
def encode(
self,
prompts: None,
pooling_params: None,
prompt_token_ids: Union[List[int], List[List[int]]],
use_tqdm: bool = True,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
) -> List[EmbeddingRequestOutput]:
...
@overload
def encode(
self,
prompts: Union[PromptType, Sequence[PromptType]],
/,
*,
pooling_params: Optional[Union[PoolingParams,
Sequence[PoolingParams]]] = None,
use_tqdm: bool = True,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
) -> List[EmbeddingRequestOutput]:
...
@deprecate_kwargs(
"prompt_token_ids",
is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
additional_message="Please use the 'prompts' parameter instead.",
)
def encode(
self,
prompts: Union[Union[PromptType, Sequence[PromptType]],
Optional[Union[str, List[str]]]] = None,
pooling_params: Optional[Union[PoolingParams,
Sequence[PoolingParams]]] = None,
prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None,
use_tqdm: bool = True,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> List[EmbeddingRequestOutput]:
"""Generates the completions for the input prompts.
This class automatically batches the given prompts, considering
the memory constraint. For the best performance, put all of your prompts
into a single list and pass it to this method.
Args:
prompts: The prompts to the LLM. You may pass a sequence of prompts
for batch inference. See :class:`~vllm.inputs.PromptType`
for more details about the format of each prompts.
pooling_params: The pooling parameters for pooling. If None, we
use the default pooling parameters.
use_tqdm: Whether to use tqdm to display the progress bar.
lora_request: LoRA request to use for generation, if any.
prompt_adapter_request: Prompt Adapter request to use for
generation, if any.
Returns:
A list of `EmbeddingRequestOutput` objects containing the
generated embeddings in the same order as the input prompts.
Note:
Using ``prompts`` and ``prompt_token_ids`` as keyword parameters is
considered legacy and may be deprecated in the future. You should
instead pass them via the ``inputs`` parameter.
"""
if not self.llm_engine.model_config.embedding_mode:
raise ValueError(
"LLM.encode() is only supported for embedding models (XModel)."
)
if prompt_token_ids is not None:
parsed_prompts = self._convert_v1_inputs(
prompts=cast(Optional[Union[str, List[str]]], prompts),
prompt_token_ids=prompt_token_ids,
)
else:
parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
prompts)
if pooling_params is None:
# Use default pooling params.
pooling_params = PoolingParams()
self._validate_and_add_requests(
prompts=parsed_prompts,
params=pooling_params,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
)
outputs = self._run_engine(use_tqdm=use_tqdm)
return LLMEngine.validate_outputs(outputs, EmbeddingRequestOutput)
def start_profile(self) -> None:
self.llm_engine.start_profile()
def stop_profile(self) -> None:
self.llm_engine.stop_profile()
# LEGACY
def _convert_v1_inputs(
self,
prompts: Optional[Union[str, List[str]]],
prompt_token_ids: Optional[Union[List[int], List[List[int]]]],
):
# skip_tokenizer_init is now checked in engine
if prompts is not None:
prompts = [p["content"] for p in parse_and_batch_prompt(prompts)]
if prompt_token_ids is not None:
prompt_token_ids = [
p["content"] for p in parse_and_batch_prompt(prompt_token_ids)
]
num_requests = None
if prompts is not None:
num_requests = len(prompts)
if prompt_token_ids is not None:
if (num_requests is not None
and num_requests != len(prompt_token_ids)):
raise ValueError("The lengths of prompts and prompt_token_ids "
"must be the same.")
num_requests = len(prompt_token_ids)
if num_requests is None:
raise ValueError("Either prompts or prompt_token_ids must be "
"provided.")
parsed_prompts: List[PromptType] = []
for i in range(num_requests):
item: PromptType
if prompts is not None:
item = TextPrompt(prompt=prompts[i])
elif prompt_token_ids is not None:
item = TokensPrompt(prompt_token_ids=prompt_token_ids[i])
else:
raise AssertionError
parsed_prompts.append(item)
return parsed_prompts
def _validate_and_add_requests(
self,
prompts: Union[PromptType, Sequence[PromptType]],
params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams,
Sequence[PoolingParams]],
lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
prompt_adapter_request: Optional[PromptAdapterRequest],
guided_options: Optional[GuidedDecodingRequest] = None,
priority: Optional[List[int]] = None,
) -> None:
if guided_options is not None:
warnings.warn(
"guided_options_request is deprecated, use "
"SamplingParams.guided_decoding instead",
DeprecationWarning,
stacklevel=2,
)
if isinstance(prompts, (str, dict)):
# Convert a single prompt to a list.
prompts = [prompts]
num_requests = len(prompts)
if isinstance(params, list) and len(params) != num_requests:
raise ValueError("The lengths of prompts and params "
"must be the same.")
if isinstance(lora_request,
list) and len(lora_request) != num_requests:
raise ValueError("The lengths of prompts and lora_request "
"must be the same.")
for sp in params if isinstance(params, list) else (params, ):
if isinstance(sp, SamplingParams):
self._add_guided_params(sp, guided_options)
# We only care about the final output
sp.output_kind = RequestOutputKind.FINAL_ONLY
# Add requests to the engine.
for i, prompt in enumerate(prompts):
self._add_request(
prompt,
params[i] if isinstance(params, Sequence) else params,
lora_request=lora_request[i] if isinstance(
lora_request, Sequence) else lora_request,
prompt_adapter_request=prompt_adapter_request,
priority=priority[i] if priority else 0,
)
def _add_request(
self,
prompt: PromptType,
params: Union[SamplingParams, PoolingParams],
lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> None:
request_id = str(next(self.request_counter))
self.llm_engine.add_request(
request_id,
prompt,
params,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
priority=priority,
)
def _add_guided_params(
self,
params: SamplingParams,
guided_options: Optional[GuidedDecodingRequest] = None):
if guided_options is None:
return params
if params.guided_decoding is not None:
raise ValueError("Cannot set both guided_options_request and"
"params.guided_decoding.")
params.guided_decoding = GuidedDecodingParams(
json=guided_options.guided_json,
regex=guided_options.guided_regex,
choice=guided_options.guided_choice,
grammar=guided_options.guided_grammar,
json_object=guided_options.guided_json_object,
backend=guided_options.guided_decoding_backend,
whitespace_pattern=guided_options.guided_whitespace_pattern)
return params
def _run_engine(
self, *, use_tqdm: bool
) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
# Initialize tqdm.
if use_tqdm:
num_requests = self.llm_engine.get_num_unfinished_requests()
pbar = tqdm(
total=num_requests,
desc="Processed prompts",
dynamic_ncols=True,
postfix=(f"est. speed input: {0:.2f} toks/s, "
f"output: {0:.2f} toks/s"),
)
# Run the engine.
outputs: List[Union[RequestOutput, EmbeddingRequestOutput]] = []
total_in_toks = 0
total_out_toks = 0
while self.llm_engine.has_unfinished_requests():
step_outputs = self.llm_engine.step()
for output in step_outputs:
if output.finished:
outputs.append(output)
if use_tqdm:
if isinstance(output, RequestOutput):
# Calculate tokens only for RequestOutput
assert output.prompt_token_ids is not None
total_in_toks += len(output.prompt_token_ids)
in_spd = total_in_toks / pbar.format_dict["elapsed"]
total_out_toks += sum(
len(stp.token_ids) for stp in output.outputs)
out_spd = (total_out_toks /
pbar.format_dict["elapsed"])
pbar.postfix = (
f"est. speed input: {in_spd:.2f} toks/s, "
f"output: {out_spd:.2f} toks/s")
pbar.update(1)
if use_tqdm:
pbar.close()
# Sort the outputs by request ID.
# This is necessary because some requests may be finished earlier than
# its previous requests.
return sorted(outputs, key=lambda x: int(x.request_id))
def _is_encoder_decoder_model(self):
return self.llm_engine.is_encoder_decoder_model()
def _is_embedding_model(self):
return self.llm_engine.is_embedding_model()

View File

@@ -0,0 +1,42 @@
from typing import List, Optional, Union
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import BeamSearchParams, SamplingParams
logger = init_logger(__name__)
class RequestLogger:
def __init__(self, *, max_log_len: Optional[int]) -> None:
super().__init__()
self.max_log_len = max_log_len
def log_inputs(
self,
request_id: str,
prompt: Optional[str],
prompt_token_ids: Optional[List[int]],
params: Optional[Union[SamplingParams, PoolingParams,
BeamSearchParams]],
lora_request: Optional[LoRARequest],
prompt_adapter_request: Optional[PromptAdapterRequest],
) -> None:
max_log_len = self.max_log_len
if max_log_len is not None:
if prompt is not None:
prompt = prompt[:max_log_len]
if prompt_token_ids is not None:
prompt_token_ids = prompt_token_ids[:max_log_len]
logger.info(
"Received request %s: prompt: %r, "
"params: %s, prompt_token_ids: %s, "
"lora_request: %s, prompt_adapter_request: %s.", request_id,
prompt, params, prompt_token_ids, lora_request,
prompt_adapter_request)

View File

View File

@@ -0,0 +1,585 @@
import asyncio
import importlib
import inspect
import multiprocessing
import os
import re
import signal
import socket
import tempfile
from argparse import Namespace
from contextlib import asynccontextmanager
from functools import partial
from http import HTTPStatus
from typing import AsyncIterator, Set
import uvloop
from fastapi import APIRouter, FastAPI, Request
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, Response, StreamingResponse
from starlette.datastructures import State
from starlette.routing import Mount
from typing_extensions import assert_never
import vllm.envs as envs
from vllm.config import ModelConfig
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.engine.multiprocessing.client import MQLLMEngineClient
from vllm.engine.multiprocessing.engine import run_mp_engine
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.launcher import serve_http
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.cli_args import (make_arg_parser,
validate_parsed_serve_args)
# yapf conflicts with isort for this block
# yapf: disable
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
ChatCompletionResponse,
CompletionRequest,
CompletionResponse,
DetokenizeRequest,
DetokenizeResponse,
EmbeddingRequest,
EmbeddingResponse, ErrorResponse,
LoadLoraAdapterRequest,
TokenizeRequest,
TokenizeResponse,
UnloadLoraAdapterRequest)
# yapf: enable
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
from vllm.entrypoints.openai.serving_engine import BaseModelPath
from vllm.entrypoints.openai.serving_tokenization import (
OpenAIServingTokenization)
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
from vllm.logger import init_logger
from vllm.usage.usage_lib import UsageContext
from vllm.utils import FlexibleArgumentParser, get_open_zmq_ipc_path
from vllm.version import __version__ as VLLM_VERSION
TIMEOUT_KEEP_ALIVE = 5 # seconds
prometheus_multiproc_dir: tempfile.TemporaryDirectory
# Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765)
logger = init_logger('vllm.entrypoints.openai.api_server')
_running_tasks: Set[asyncio.Task] = set()
@asynccontextmanager
async def lifespan(app: FastAPI):
try:
if app.state.log_stats:
engine_client: EngineClient = app.state.engine_client
async def _force_log():
while True:
await asyncio.sleep(10.)
await engine_client.do_log_stats()
task = asyncio.create_task(_force_log())
_running_tasks.add(task)
task.add_done_callback(_running_tasks.remove)
else:
task = None
try:
yield
finally:
if task is not None:
task.cancel()
finally:
# Ensure app state including engine ref is gc'd
del app.state
@asynccontextmanager
async def build_async_engine_client(
args: Namespace) -> AsyncIterator[EngineClient]:
# Context manager to handle engine_client lifecycle
# Ensures everything is shutdown and cleaned up on error/exit
engine_args = AsyncEngineArgs.from_cli_args(args)
async with build_async_engine_client_from_engine_args(
engine_args, args.disable_frontend_multiprocessing) as engine:
yield engine
@asynccontextmanager
async def build_async_engine_client_from_engine_args(
engine_args: AsyncEngineArgs,
disable_frontend_multiprocessing: bool = False,
) -> AsyncIterator[EngineClient]:
"""
Create EngineClient, either:
- in-process using the AsyncLLMEngine Directly
- multiprocess using AsyncLLMEngine RPC
Returns the Client or None if the creation failed.
"""
# Fall back
# TODO: fill out feature matrix.
if (MQLLMEngineClient.is_unsupported_config(engine_args)
or disable_frontend_multiprocessing):
engine_config = engine_args.create_engine_config()
uses_ray = getattr(AsyncLLMEngine._get_executor_cls(engine_config),
"uses_ray", False)
build_engine = partial(AsyncLLMEngine.from_engine_args,
engine_args=engine_args,
engine_config=engine_config,
usage_context=UsageContext.OPENAI_API_SERVER)
if uses_ray:
# Must run in main thread with ray for its signal handlers to work
engine_client = build_engine()
else:
engine_client = await asyncio.get_running_loop().run_in_executor(
None, build_engine)
yield engine_client
return
# Otherwise, use the multiprocessing AsyncLLMEngine.
else:
if "PROMETHEUS_MULTIPROC_DIR" not in os.environ:
# Make TemporaryDirectory for prometheus multiprocessing
# Note: global TemporaryDirectory will be automatically
# cleaned up upon exit.
global prometheus_multiproc_dir
prometheus_multiproc_dir = tempfile.TemporaryDirectory()
os.environ[
"PROMETHEUS_MULTIPROC_DIR"] = prometheus_multiproc_dir.name
else:
logger.warning(
"Found PROMETHEUS_MULTIPROC_DIR was set by user. "
"This directory must be wiped between vLLM runs or "
"you will find inaccurate metrics. Unset the variable "
"and vLLM will properly handle cleanup.")
# Select random path for IPC.
ipc_path = get_open_zmq_ipc_path()
logger.info("Multiprocessing frontend to use %s for IPC Path.",
ipc_path)
# Start RPCServer in separate process (holds the LLMEngine).
# the current process might have CUDA context,
# so we need to spawn a new process
context = multiprocessing.get_context("spawn")
engine_process = context.Process(target=run_mp_engine,
args=(engine_args,
UsageContext.OPENAI_API_SERVER,
ipc_path))
engine_process.start()
logger.info("Started engine process with PID %d", engine_process.pid)
# Build RPCClient, which conforms to EngineClient Protocol.
# NOTE: Actually, this is not true yet. We still need to support
# embedding models via RPC (see TODO above)
engine_config = engine_args.create_engine_config()
mp_engine_client = MQLLMEngineClient(ipc_path, engine_config)
try:
while True:
try:
await mp_engine_client.setup()
break
except TimeoutError:
if not engine_process.is_alive():
raise RuntimeError(
"Engine process failed to start") from None
yield mp_engine_client # type: ignore[misc]
finally:
# Ensure rpc server process was terminated
engine_process.terminate()
# Close all open connections to the backend
mp_engine_client.close()
# Wait for engine process to join
engine_process.join(4)
if engine_process.exitcode is None:
# Kill if taking longer than 5 seconds to stop
engine_process.kill()
# Lazy import for prometheus multiprocessing.
# We need to set PROMETHEUS_MULTIPROC_DIR environment variable
# before prometheus_client is imported.
# See https://prometheus.github.io/client_python/multiprocess/
from prometheus_client import multiprocess
multiprocess.mark_process_dead(engine_process.pid)
router = APIRouter()
def mount_metrics(app: FastAPI):
# Lazy import for prometheus multiprocessing.
# We need to set PROMETHEUS_MULTIPROC_DIR environment variable
# before prometheus_client is imported.
# See https://prometheus.github.io/client_python/multiprocess/
from prometheus_client import (CollectorRegistry, make_asgi_app,
multiprocess)
prometheus_multiproc_dir_path = os.getenv("PROMETHEUS_MULTIPROC_DIR", None)
if prometheus_multiproc_dir_path is not None:
logger.info("vLLM to use %s as PROMETHEUS_MULTIPROC_DIR",
prometheus_multiproc_dir_path)
registry = CollectorRegistry()
multiprocess.MultiProcessCollector(registry)
# Add prometheus asgi middleware to route /metrics requests
metrics_route = Mount("/metrics", make_asgi_app(registry=registry))
else:
# Add prometheus asgi middleware to route /metrics requests
metrics_route = Mount("/metrics", make_asgi_app())
# Workaround for 307 Redirect for /metrics
metrics_route.path_regex = re.compile("^/metrics(?P<path>.*)$")
app.routes.append(metrics_route)
def chat(request: Request) -> OpenAIServingChat:
return request.app.state.openai_serving_chat
def completion(request: Request) -> OpenAIServingCompletion:
return request.app.state.openai_serving_completion
def tokenization(request: Request) -> OpenAIServingTokenization:
return request.app.state.openai_serving_tokenization
def embedding(request: Request) -> OpenAIServingEmbedding:
return request.app.state.openai_serving_embedding
def engine_client(request: Request) -> EngineClient:
return request.app.state.engine_client
@router.get("/health")
async def health(raw_request: Request) -> Response:
"""Health check."""
await engine_client(raw_request).check_health()
return Response(status_code=200)
@router.post("/tokenize")
async def tokenize(request: TokenizeRequest, raw_request: Request):
generator = await tokenization(raw_request).create_tokenize(request)
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.code)
elif isinstance(generator, TokenizeResponse):
return JSONResponse(content=generator.model_dump())
assert_never(generator)
@router.post("/detokenize")
async def detokenize(request: DetokenizeRequest, raw_request: Request):
generator = await tokenization(raw_request).create_detokenize(request)
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.code)
elif isinstance(generator, DetokenizeResponse):
return JSONResponse(content=generator.model_dump())
assert_never(generator)
@router.get("/v1/models")
async def show_available_models(raw_request: Request):
models = await completion(raw_request).show_available_models()
return JSONResponse(content=models.model_dump())
@router.get("/version")
async def show_version():
ver = {"version": VLLM_VERSION}
return JSONResponse(content=ver)
@router.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest,
raw_request: Request):
generator = await chat(raw_request).create_chat_completion(
request, raw_request)
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.code)
elif isinstance(generator, ChatCompletionResponse):
return JSONResponse(content=generator.model_dump())
return StreamingResponse(content=generator, media_type="text/event-stream")
@router.post("/v1/completions")
async def create_completion(request: CompletionRequest, raw_request: Request):
generator = await completion(raw_request).create_completion(
request, raw_request)
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.code)
elif isinstance(generator, CompletionResponse):
return JSONResponse(content=generator.model_dump())
return StreamingResponse(content=generator, media_type="text/event-stream")
@router.post("/v1/embeddings")
async def create_embedding(request: EmbeddingRequest, raw_request: Request):
generator = await embedding(raw_request).create_embedding(
request, raw_request)
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.code)
elif isinstance(generator, EmbeddingResponse):
return JSONResponse(content=generator.model_dump())
assert_never(generator)
if envs.VLLM_TORCH_PROFILER_DIR:
logger.warning(
"Torch Profiler is enabled in the API server. This should ONLY be "
"used for local development!")
@router.post("/start_profile")
async def start_profile(raw_request: Request):
logger.info("Starting profiler...")
await engine_client(raw_request).start_profile()
logger.info("Profiler started.")
return Response(status_code=200)
@router.post("/stop_profile")
async def stop_profile(raw_request: Request):
logger.info("Stopping profiler...")
await engine_client(raw_request).stop_profile()
logger.info("Profiler stopped.")
return Response(status_code=200)
if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
logger.warning(
"Lora dynamic loading & unloading is enabled in the API server. "
"This should ONLY be used for local development!")
@router.post("/v1/load_lora_adapter")
async def load_lora_adapter(request: LoadLoraAdapterRequest,
raw_request: Request):
response = await chat(raw_request).load_lora_adapter(request)
if isinstance(response, ErrorResponse):
return JSONResponse(content=response.model_dump(),
status_code=response.code)
response = await completion(raw_request).load_lora_adapter(request)
if isinstance(response, ErrorResponse):
return JSONResponse(content=response.model_dump(),
status_code=response.code)
return Response(status_code=200, content=response)
@router.post("/v1/unload_lora_adapter")
async def unload_lora_adapter(request: UnloadLoraAdapterRequest,
raw_request: Request):
response = await chat(raw_request).unload_lora_adapter(request)
if isinstance(response, ErrorResponse):
return JSONResponse(content=response.model_dump(),
status_code=response.code)
response = await completion(raw_request).unload_lora_adapter(request)
if isinstance(response, ErrorResponse):
return JSONResponse(content=response.model_dump(),
status_code=response.code)
return Response(status_code=200, content=response)
def build_app(args: Namespace) -> FastAPI:
if args.disable_fastapi_docs:
app = FastAPI(openapi_url=None,
docs_url=None,
redoc_url=None,
lifespan=lifespan)
else:
app = FastAPI(lifespan=lifespan)
app.include_router(router)
app.root_path = args.root_path
mount_metrics(app)
app.add_middleware(
CORSMiddleware,
allow_origins=args.allowed_origins,
allow_credentials=args.allow_credentials,
allow_methods=args.allowed_methods,
allow_headers=args.allowed_headers,
)
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(_, exc):
chat = app.state.openai_serving_chat
err = chat.create_error_response(message=str(exc))
return JSONResponse(err.model_dump(),
status_code=HTTPStatus.BAD_REQUEST)
if token := envs.VLLM_API_KEY or args.api_key:
@app.middleware("http")
async def authentication(request: Request, call_next):
root_path = "" if args.root_path is None else args.root_path
if request.method == "OPTIONS":
return await call_next(request)
if not request.url.path.startswith(f"{root_path}/v1"):
return await call_next(request)
if request.headers.get("Authorization") != "Bearer " + token:
return JSONResponse(content={"error": "Unauthorized"},
status_code=401)
return await call_next(request)
for middleware in args.middleware:
module_path, object_name = middleware.rsplit(".", 1)
imported = getattr(importlib.import_module(module_path), object_name)
if inspect.isclass(imported):
app.add_middleware(imported)
elif inspect.iscoroutinefunction(imported):
app.middleware("http")(imported)
else:
raise ValueError(f"Invalid middleware {middleware}. "
f"Must be a function or a class.")
return app
def init_app_state(
engine_client: EngineClient,
model_config: ModelConfig,
state: State,
args: Namespace,
) -> None:
if args.served_model_name is not None:
served_model_names = args.served_model_name
else:
served_model_names = [args.model]
if args.disable_log_requests:
request_logger = None
else:
request_logger = RequestLogger(max_log_len=args.max_log_len)
base_model_paths = [
BaseModelPath(name=name, model_path=args.model)
for name in served_model_names
]
state.engine_client = engine_client
state.log_stats = not args.disable_log_stats
state.openai_serving_chat = OpenAIServingChat(
engine_client,
model_config,
base_model_paths,
args.response_role,
lora_modules=args.lora_modules,
prompt_adapters=args.prompt_adapters,
request_logger=request_logger,
chat_template=args.chat_template,
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
enable_auto_tools=args.enable_auto_tool_choice,
tool_parser=args.tool_call_parser)
state.openai_serving_completion = OpenAIServingCompletion(
engine_client,
model_config,
base_model_paths,
lora_modules=args.lora_modules,
prompt_adapters=args.prompt_adapters,
request_logger=request_logger,
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
)
state.openai_serving_embedding = OpenAIServingEmbedding(
engine_client,
model_config,
base_model_paths,
request_logger=request_logger,
)
state.openai_serving_tokenization = OpenAIServingTokenization(
engine_client,
model_config,
base_model_paths,
lora_modules=args.lora_modules,
request_logger=request_logger,
chat_template=args.chat_template,
)
async def run_server(args, **uvicorn_kwargs) -> None:
logger.info("vLLM API server version %s", VLLM_VERSION)
logger.info("args: %s", args)
if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3:
ToolParserManager.import_tool_parser(args.tool_parser_plugin)
valide_tool_parses = ToolParserManager.tool_parsers.keys()
if args.enable_auto_tool_choice \
and args.tool_call_parser not in valide_tool_parses:
raise KeyError(f"invalid tool call parser: {args.tool_call_parser} "
f"(chose from {{ {','.join(valide_tool_parses)} }})")
# workaround to make sure that we bind the port before the engine is set up.
# This avoids race conditions with ray.
# see https://github.com/vllm-project/vllm/issues/8204
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.bind(("", args.port))
def signal_handler(*_) -> None:
# Interrupt server on sigterm while initializing
raise KeyboardInterrupt("terminated")
signal.signal(signal.SIGTERM, signal_handler)
async with build_async_engine_client(args) as engine_client:
app = build_app(args)
model_config = await engine_client.get_model_config()
init_app_state(engine_client, model_config, app.state, args)
shutdown_task = await serve_http(
app,
host=args.host,
port=args.port,
log_level=args.uvicorn_log_level,
timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
ssl_keyfile=args.ssl_keyfile,
ssl_certfile=args.ssl_certfile,
ssl_ca_certs=args.ssl_ca_certs,
ssl_cert_reqs=args.ssl_cert_reqs,
fd=sock.fileno(),
**uvicorn_kwargs,
)
# NB: Await server shutdown only after the backend context is exited
await shutdown_task
if __name__ == "__main__":
# NOTE(simon):
# This section should be in sync with vllm/scripts.py for CLI entrypoints.
parser = FlexibleArgumentParser(
description="vLLM OpenAI-Compatible RESTful API server.")
parser = make_arg_parser(parser)
args = parser.parse_args()
validate_parsed_serve_args(args)
uvloop.run(run_server(args))

View File

@@ -0,0 +1,252 @@
"""
This file contains the command line arguments for the vLLM's
OpenAI-compatible server. It is kept in a separate file for documentation
purposes.
"""
import argparse
import json
import ssl
from typing import List, Optional, Sequence, Union
from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
from vllm.entrypoints.chat_utils import validate_chat_template
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
PromptAdapterPath)
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
from vllm.utils import FlexibleArgumentParser
class LoRAParserAction(argparse.Action):
def __call__(
self,
parser: argparse.ArgumentParser,
namespace: argparse.Namespace,
values: Optional[Union[str, Sequence[str]]],
option_string: Optional[str] = None,
):
if values is None:
values = []
if isinstance(values, str):
raise TypeError("Expected values to be a list")
lora_list: List[LoRAModulePath] = []
for item in values:
if item in [None, '']: # Skip if item is None or empty string
continue
if '=' in item and ',' not in item: # Old format: name=path
name, path = item.split('=')
lora_list.append(LoRAModulePath(name, path))
else: # Assume JSON format
try:
lora_dict = json.loads(item)
lora = LoRAModulePath(**lora_dict)
lora_list.append(lora)
except json.JSONDecodeError:
parser.error(
f"Invalid JSON format for --lora-modules: {item}")
except TypeError as e:
parser.error(
f"Invalid fields for --lora-modules: {item} - {str(e)}"
)
setattr(namespace, self.dest, lora_list)
class PromptAdapterParserAction(argparse.Action):
def __call__(
self,
parser: argparse.ArgumentParser,
namespace: argparse.Namespace,
values: Optional[Union[str, Sequence[str]]],
option_string: Optional[str] = None,
):
if values is None:
values = []
if isinstance(values, str):
raise TypeError("Expected values to be a list")
adapter_list: List[PromptAdapterPath] = []
for item in values:
name, path = item.split('=')
adapter_list.append(PromptAdapterPath(name, path))
setattr(namespace, self.dest, adapter_list)
def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
parser.add_argument("--host",
type=nullable_str,
default=None,
help="host name")
parser.add_argument("--port", type=int, default=8000, help="port number")
parser.add_argument(
"--uvicorn-log-level",
type=str,
default="info",
choices=['debug', 'info', 'warning', 'error', 'critical', 'trace'],
help="log level for uvicorn")
parser.add_argument("--allow-credentials",
action="store_true",
help="allow credentials")
parser.add_argument("--allowed-origins",
type=json.loads,
default=["*"],
help="allowed origins")
parser.add_argument("--allowed-methods",
type=json.loads,
default=["*"],
help="allowed methods")
parser.add_argument("--allowed-headers",
type=json.loads,
default=["*"],
help="allowed headers")
parser.add_argument("--api-key",
type=nullable_str,
default=None,
help="If provided, the server will require this key "
"to be presented in the header.")
parser.add_argument(
"--lora-modules",
type=nullable_str,
default=None,
nargs='+',
action=LoRAParserAction,
help="LoRA module configurations in either 'name=path' format"
"or JSON format. "
"Example (old format): 'name=path' "
"Example (new format): "
"'{\"name\": \"name\", \"local_path\": \"path\", "
"\"base_model_name\": \"id\"}'")
parser.add_argument(
"--prompt-adapters",
type=nullable_str,
default=None,
nargs='+',
action=PromptAdapterParserAction,
help="Prompt adapter configurations in the format name=path. "
"Multiple adapters can be specified.")
parser.add_argument("--chat-template",
type=nullable_str,
default=None,
help="The file path to the chat template, "
"or the template in single-line form "
"for the specified model")
parser.add_argument("--response-role",
type=nullable_str,
default="assistant",
help="The role name to return if "
"`request.add_generation_prompt=true`.")
parser.add_argument("--ssl-keyfile",
type=nullable_str,
default=None,
help="The file path to the SSL key file")
parser.add_argument("--ssl-certfile",
type=nullable_str,
default=None,
help="The file path to the SSL cert file")
parser.add_argument("--ssl-ca-certs",
type=nullable_str,
default=None,
help="The CA certificates file")
parser.add_argument(
"--ssl-cert-reqs",
type=int,
default=int(ssl.CERT_NONE),
help="Whether client certificate is required (see stdlib ssl module's)"
)
parser.add_argument(
"--root-path",
type=nullable_str,
default=None,
help="FastAPI root_path when app is behind a path based routing proxy")
parser.add_argument(
"--middleware",
type=nullable_str,
action="append",
default=[],
help="Additional ASGI middleware to apply to the app. "
"We accept multiple --middleware arguments. "
"The value should be an import path. "
"If a function is provided, vLLM will add it to the server "
"using @app.middleware('http'). "
"If a class is provided, vLLM will add it to the server "
"using app.add_middleware(). ")
parser.add_argument(
"--return-tokens-as-token-ids",
action="store_true",
help="When --max-logprobs is specified, represents single tokens as "
"strings of the form 'token_id:{token_id}' so that tokens that "
"are not JSON-encodable can be identified.")
parser.add_argument(
"--disable-frontend-multiprocessing",
action="store_true",
help="If specified, will run the OpenAI frontend server in the same "
"process as the model serving engine.")
parser.add_argument(
"--enable-auto-tool-choice",
action="store_true",
default=False,
help=
"Enable auto tool choice for supported models. Use --tool-call-parser"
"to specify which parser to use")
valid_tool_parsers = ToolParserManager.tool_parsers.keys()
parser.add_argument(
"--tool-call-parser",
type=str,
metavar="{" + ",".join(valid_tool_parsers) + "} or name registered in "
"--tool-parser-plugin",
default=None,
help=
"Select the tool call parser depending on the model that you're using."
" This is used to parse the model-generated tool call into OpenAI API "
"format. Required for --enable-auto-tool-choice.")
parser.add_argument(
"--tool-parser-plugin",
type=str,
default="",
help=
"Special the tool parser plugin write to parse the model-generated tool"
" into OpenAI API format, the name register in this plugin can be used "
"in --tool-call-parser.")
parser = AsyncEngineArgs.add_cli_args(parser)
parser.add_argument('--max-log-len',
type=int,
default=None,
help='Max number of prompt characters or prompt '
'ID numbers being printed in log.'
'\n\nDefault: Unlimited')
parser.add_argument(
"--disable-fastapi-docs",
action='store_true',
default=False,
help="Disable FastAPI's OpenAPI schema, Swagger UI, and ReDoc endpoint"
)
return parser
def validate_parsed_serve_args(args: argparse.Namespace):
"""Quick checks for model serve args that raise prior to loading."""
if hasattr(args, "subparser") and args.subparser != "serve":
return
# Ensure that the chat template is valid; raises if it likely isn't
validate_chat_template(args.chat_template)
# Enable auto tool needs a tool call parser to be valid
if args.enable_auto_tool_choice and not args.tool_call_parser:
raise TypeError("Error: --enable-auto-tool-choice requires "
"--tool-call-parser")
def create_parser_for_docs() -> FlexibleArgumentParser:
parser_for_docs = FlexibleArgumentParser(
prog="-m vllm.entrypoints.openai.api_server")
return make_arg_parser(parser_for_docs)

View File

@@ -0,0 +1,86 @@
from functools import lru_cache, partial
from typing import Dict, FrozenSet, Iterable, List, Optional, Union
import torch
from vllm.sampling_params import LogitsProcessor
from vllm.transformers_utils.tokenizer import AnyTokenizer
class AllowedTokenIdsLogitsProcessor:
"""Logits processor for constraining generated tokens to a
specific set of token ids."""
def __init__(self, allowed_ids: Iterable[int]):
self.allowed_ids: Optional[List[int]] = list(allowed_ids)
self.mask: Optional[torch.Tensor] = None
def __call__(self, token_ids: List[int],
logits: torch.Tensor) -> torch.Tensor:
if self.mask is None:
self.mask = torch.ones((logits.shape[-1], ),
dtype=torch.bool,
device=logits.device)
self.mask[self.allowed_ids] = False
self.allowed_ids = None
logits.masked_fill_(self.mask, float("-inf"))
return logits
@lru_cache(maxsize=32)
def _get_allowed_token_ids_logits_processor(
allowed_token_ids: FrozenSet[int],
vocab_size: int,
) -> LogitsProcessor:
if not allowed_token_ids:
raise ValueError("Empty allowed_token_ids provided")
if not all(0 <= tid < vocab_size for tid in allowed_token_ids):
raise ValueError("allowed_token_ids contains "
"out-of-vocab token id")
return AllowedTokenIdsLogitsProcessor(allowed_token_ids)
def logit_bias_logits_processor(
logit_bias: Dict[int, float],
token_ids: List[int],
logits: torch.Tensor,
) -> torch.Tensor:
for token_id, bias in logit_bias.items():
logits[token_id] += bias
return logits
def get_logits_processors(
logit_bias: Optional[Union[Dict[int, float], Dict[str, float]]],
allowed_token_ids: Optional[List[int]],
tokenizer: AnyTokenizer,
) -> List[LogitsProcessor]:
logits_processors: List[LogitsProcessor] = []
if logit_bias:
try:
# Convert token_id to integer
# Clamp the bias between -100 and 100 per OpenAI API spec
clamped_logit_bias: Dict[int, float] = {
int(token_id): min(100.0, max(-100.0, bias))
for token_id, bias in logit_bias.items()
}
except ValueError as exc:
raise ValueError(
"Found token_id in logit_bias that is not "
"an integer or string representing an integer") from exc
# Check if token_id is within the vocab size
for token_id, bias in clamped_logit_bias.items():
if token_id < 0 or token_id >= tokenizer.vocab_size:
raise ValueError(f"token_id {token_id} in logit_bias contains "
"out-of-vocab token id")
logits_processors.append(
partial(logit_bias_logits_processor, clamped_logit_bias))
if allowed_token_ids is not None:
logits_processors.append(
_get_allowed_token_ids_logits_processor(
frozenset(allowed_token_ids), tokenizer.vocab_size))
return logits_processors

View File

@@ -0,0 +1,992 @@
# Adapted from
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
import time
from argparse import Namespace
from typing import Any, Dict, List, Literal, Optional, Union
import torch
from openai.types.chat import ChatCompletionContentPartParam
from pydantic import BaseModel, ConfigDict, Field, model_validator
from typing_extensions import Annotated, Required, TypedDict
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
RequestOutputKind, SamplingParams)
from vllm.sequence import Logprob
from vllm.utils import random_uuid
# torch is mocked during docs generation,
# so we have to provide the values as literals
_MOCK_LONG_INFO = Namespace(min=-9223372036854775808, max=9223372036854775807)
_LONG_INFO: Union["torch.iinfo", Namespace]
try:
from sphinx.ext.autodoc.mock import _MockModule
if isinstance(torch, _MockModule):
_LONG_INFO = _MOCK_LONG_INFO
else:
_LONG_INFO = torch.iinfo(torch.long)
except ModuleNotFoundError:
_LONG_INFO = torch.iinfo(torch.long)
assert _LONG_INFO.min == _MOCK_LONG_INFO.min
assert _LONG_INFO.max == _MOCK_LONG_INFO.max
class CustomChatCompletionMessageParam(TypedDict, total=False):
"""Enables custom roles in the Chat Completion API."""
role: Required[str]
"""The role of the message's author."""
content: Union[str, List[ChatCompletionContentPartParam]]
"""The contents of the message."""
name: str
"""An optional name for the participant.
Provides the model information to differentiate between participants of the
same role.
"""
tool_call_id: Optional[str]
tool_calls: Optional[List[dict]]
class OpenAIBaseModel(BaseModel):
# OpenAI API does not allow extra fields
model_config = ConfigDict(extra="forbid")
class ErrorResponse(OpenAIBaseModel):
object: str = "error"
message: str
type: str
param: Optional[str] = None
code: int
class ModelPermission(OpenAIBaseModel):
id: str = Field(default_factory=lambda: f"modelperm-{random_uuid()}")
object: str = "model_permission"
created: int = Field(default_factory=lambda: int(time.time()))
allow_create_engine: bool = False
allow_sampling: bool = True
allow_logprobs: bool = True
allow_search_indices: bool = False
allow_view: bool = True
allow_fine_tuning: bool = False
organization: str = "*"
group: Optional[str] = None
is_blocking: bool = False
class ModelCard(OpenAIBaseModel):
id: str
object: str = "model"
created: int = Field(default_factory=lambda: int(time.time()))
owned_by: str = "vllm"
root: Optional[str] = None
parent: Optional[str] = None
max_model_len: Optional[int] = None
permission: List[ModelPermission] = Field(default_factory=list)
class ModelList(OpenAIBaseModel):
object: str = "list"
data: List[ModelCard] = Field(default_factory=list)
class UsageInfo(OpenAIBaseModel):
prompt_tokens: int = 0
total_tokens: int = 0
completion_tokens: Optional[int] = 0
class RequestResponseMetadata(BaseModel):
request_id: str
final_usage_info: Optional[UsageInfo] = None
class JsonSchemaResponseFormat(OpenAIBaseModel):
name: str
description: Optional[str] = None
# schema is the field in openai but that causes conflicts with pydantic so
# instead use json_schema with an alias
json_schema: Optional[Dict[str, Any]] = Field(default=None, alias='schema')
strict: Optional[bool] = None
class ResponseFormat(OpenAIBaseModel):
# type must be "json_schema", "json_object" or "text"
type: Literal["text", "json_object", "json_schema"]
json_schema: Optional[JsonSchemaResponseFormat] = None
class StreamOptions(OpenAIBaseModel):
include_usage: Optional[bool] = True
continuous_usage_stats: Optional[bool] = True
class FunctionDefinition(OpenAIBaseModel):
name: str
description: Optional[str] = None
parameters: Optional[Dict[str, Any]] = None
class ChatCompletionToolsParam(OpenAIBaseModel):
type: Literal["function"] = "function"
function: FunctionDefinition
class ChatCompletionNamedFunction(OpenAIBaseModel):
name: str
class ChatCompletionNamedToolChoiceParam(OpenAIBaseModel):
function: ChatCompletionNamedFunction
type: Literal["function"] = "function"
class ChatCompletionRequest(OpenAIBaseModel):
# Ordered by official OpenAI API documentation
# https://platform.openai.com/docs/api-reference/chat/create
messages: List[ChatCompletionMessageParam]
model: str
frequency_penalty: Optional[float] = 0.0
logit_bias: Optional[Dict[str, float]] = None
logprobs: Optional[bool] = False
top_logprobs: Optional[int] = 0
max_tokens: Optional[int] = None
n: Optional[int] = 1
presence_penalty: Optional[float] = 0.0
response_format: Optional[ResponseFormat] = None
seed: Optional[int] = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
stream: Optional[bool] = False
stream_options: Optional[StreamOptions] = None
temperature: Optional[float] = 0.7
top_p: Optional[float] = 1.0
tools: Optional[List[ChatCompletionToolsParam]] = None
tool_choice: Optional[Union[Literal["none"], Literal["auto"],
ChatCompletionNamedToolChoiceParam]] = "none"
# NOTE this will be ignored by VLLM -- the model determines the behavior
parallel_tool_calls: Optional[bool] = False
user: Optional[str] = None
# doc: begin-chat-completion-sampling-params
best_of: Optional[int] = None
use_beam_search: bool = False
top_k: int = -1
min_p: float = 0.0
repetition_penalty: float = 1.0
length_penalty: float = 1.0
stop_token_ids: Optional[List[int]] = Field(default_factory=list)
include_stop_str_in_output: bool = False
ignore_eos: bool = False
min_tokens: int = 0
skip_special_tokens: bool = True
spaces_between_special_tokens: bool = True
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
prompt_logprobs: Optional[int] = None
# doc: end-chat-completion-sampling-params
# doc: begin-chat-completion-extra-params
echo: bool = Field(
default=False,
description=(
"If true, the new message will be prepended with the last message "
"if they belong to the same role."),
)
add_generation_prompt: bool = Field(
default=True,
description=
("If true, the generation prompt will be added to the chat template. "
"This is a parameter used by chat template in tokenizer config of the "
"model."),
)
continue_final_message: bool = Field(
default=False,
description=
("If this is set, the chat will be formatted so that the final "
"message in the chat is open-ended, without any EOS tokens. The "
"model will continue this message rather than starting a new one. "
"This allows you to \"prefill\" part of the model's response for it. "
"Cannot be used at the same time as `add_generation_prompt`."),
)
add_special_tokens: bool = Field(
default=False,
description=(
"If true, special tokens (e.g. BOS) will be added to the prompt "
"on top of what is added by the chat template. "
"For most models, the chat template takes care of adding the "
"special tokens so this should be set to false (as is the "
"default)."),
)
documents: Optional[List[Dict[str, str]]] = Field(
default=None,
description=
("A list of dicts representing documents that will be accessible to "
"the model if it is performing RAG (retrieval-augmented generation)."
" If the template does not support RAG, this argument will have no "
"effect. We recommend that each document should be a dict containing "
"\"title\" and \"text\" keys."),
)
chat_template: Optional[str] = Field(
default=None,
description=(
"A Jinja template to use for this conversion. "
"As of transformers v4.44, default chat template is no longer "
"allowed, so you must provide a chat template if the tokenizer "
"does not define one."),
)
chat_template_kwargs: Optional[Dict[str, Any]] = Field(
default=None,
description=("Additional kwargs to pass to the template renderer. "
"Will be accessible by the chat template."),
)
guided_json: Optional[Union[str, dict, BaseModel]] = Field(
default=None,
description=("If specified, the output will follow the JSON schema."),
)
guided_regex: Optional[str] = Field(
default=None,
description=(
"If specified, the output will follow the regex pattern."),
)
guided_choice: Optional[List[str]] = Field(
default=None,
description=(
"If specified, the output will be exactly one of the choices."),
)
guided_grammar: Optional[str] = Field(
default=None,
description=(
"If specified, the output will follow the context free grammar."),
)
guided_decoding_backend: Optional[str] = Field(
default=None,
description=(
"If specified, will override the default guided decoding backend "
"of the server for this specific request. If set, must be either "
"'outlines' / 'lm-format-enforcer'"))
guided_whitespace_pattern: Optional[str] = Field(
default=None,
description=(
"If specified, will override the default whitespace pattern "
"for guided json decoding."))
priority: int = Field(
default=0,
description=(
"The priority of the request (lower means earlier handling; "
"default: 0). Any priority other than 0 will raise an error "
"if the served model does not use priority scheduling."))
# doc: end-chat-completion-extra-params
def to_beam_search_params(self,
default_max_tokens: int) -> BeamSearchParams:
max_tokens = self.max_tokens
if max_tokens is None:
max_tokens = default_max_tokens
n = self.n if self.n is not None else 1
temperature = self.temperature if self.temperature is not None else 0.0
return BeamSearchParams(
beam_width=n,
max_tokens=max_tokens,
ignore_eos=self.ignore_eos,
temperature=temperature,
length_penalty=self.length_penalty,
)
def to_sampling_params(self, default_max_tokens: int) -> SamplingParams:
max_tokens = self.max_tokens
if max_tokens is None:
max_tokens = default_max_tokens
prompt_logprobs = self.prompt_logprobs
if prompt_logprobs is None and self.echo:
prompt_logprobs = self.top_logprobs
guided_json_object = None
if (self.response_format is not None
and self.response_format.type == "json_object"):
guided_json_object = True
guided_decoding = GuidedDecodingParams.from_optional(
json=self._get_guided_json_from_tool() or self.guided_json,
regex=self.guided_regex,
choice=self.guided_choice,
grammar=self.guided_grammar,
json_object=guided_json_object,
backend=self.guided_decoding_backend,
whitespace_pattern=self.guided_whitespace_pattern)
return SamplingParams.from_optional(
n=self.n,
best_of=self.best_of,
presence_penalty=self.presence_penalty,
frequency_penalty=self.frequency_penalty,
repetition_penalty=self.repetition_penalty,
temperature=self.temperature,
top_p=self.top_p,
top_k=self.top_k,
min_p=self.min_p,
seed=self.seed,
stop=self.stop,
stop_token_ids=self.stop_token_ids,
logprobs=self.top_logprobs if self.logprobs else None,
prompt_logprobs=prompt_logprobs,
ignore_eos=self.ignore_eos,
max_tokens=max_tokens,
min_tokens=self.min_tokens,
skip_special_tokens=self.skip_special_tokens,
spaces_between_special_tokens=self.spaces_between_special_tokens,
include_stop_str_in_output=self.include_stop_str_in_output,
truncate_prompt_tokens=self.truncate_prompt_tokens,
output_kind=RequestOutputKind.DELTA if self.stream \
else RequestOutputKind.FINAL_ONLY,
guided_decoding=guided_decoding,
logit_bias=self.logit_bias)
def _get_guided_json_from_tool(
self) -> Optional[Union[str, dict, BaseModel]]:
# user has chosen to not use any tool
if self.tool_choice == "none" or self.tools is None:
return None
# user has chosen to use a named tool
if type(self.tool_choice) is ChatCompletionNamedToolChoiceParam:
tool_name = self.tool_choice.function.name
tools = {tool.function.name: tool.function for tool in self.tools}
if tool_name not in tools:
raise ValueError(
f"Tool '{tool_name}' has not been passed in `tools`.")
tool = tools[tool_name]
return tool.parameters
return None
@model_validator(mode="before")
@classmethod
def validate_stream_options(cls, data):
if data.get("stream_options") and not data.get("stream"):
raise ValueError(
"Stream options can only be defined when `stream=True`.")
return data
@model_validator(mode="before")
@classmethod
def check_logprobs(cls, data):
if (prompt_logprobs := data.get("prompt_logprobs")) is not None:
if data.get("stream") and prompt_logprobs > 0:
raise ValueError(
"`prompt_logprobs` are not available when `stream=True`.")
if prompt_logprobs < 0:
raise ValueError("`prompt_logprobs` must be a positive value.")
if (top_logprobs := data.get("top_logprobs")) is not None:
if top_logprobs < 0:
raise ValueError("`top_logprobs` must be a positive value.")
if not data.get("logprobs"):
raise ValueError(
"when using `top_logprobs`, `logprobs` must be set to true."
)
return data
@model_validator(mode="before")
@classmethod
def check_guided_decoding_count(cls, data):
if isinstance(data, ValueError):
raise data
guide_count = sum([
"guided_json" in data and data["guided_json"] is not None,
"guided_regex" in data and data["guided_regex"] is not None,
"guided_choice" in data and data["guided_choice"] is not None
])
# you can only use one kind of guided decoding
if guide_count > 1:
raise ValueError(
"You can only use one kind of guided decoding "
"('guided_json', 'guided_regex' or 'guided_choice').")
# you can only either use guided decoding or tools, not both
if guide_count > 1 and data.get("tool_choice",
"none") not in ("none", "auto"):
raise ValueError(
"You can only either use guided decoding or tools, not both.")
return data
@model_validator(mode="before")
@classmethod
def check_tool_usage(cls, data):
# if "tool_choice" is not specified but tools are provided,
# default to "auto" tool_choice
if "tool_choice" not in data and data.get("tools"):
data["tool_choice"] = "auto"
# if "tool_choice" is specified -- validation
if "tool_choice" in data:
# ensure that if "tool choice" is specified, tools are present
if "tools" not in data or data["tools"] is None:
raise ValueError(
"When using `tool_choice`, `tools` must be set.")
# make sure that tool choice is either a named tool
# OR that it's set to "auto"
if data["tool_choice"] != "auto" and not isinstance(
data["tool_choice"], dict):
raise ValueError(
"`tool_choice` must either be a named tool or \"auto\". "
"`tool_choice=\"none\" is not supported.")
# ensure that if "tool_choice" is specified as an object,
# it matches a valid tool
if isinstance(data["tool_choice"], dict):
valid_tool = False
specified_function = data["tool_choice"]["function"]
if not specified_function:
raise ValueError(
"Incorrectly formatted `tool_choice`. Should be like "
"`{\"type\": \"function\","
" \"function\": {\"name\": \"my_function\"}}`")
specified_function_name = specified_function["name"]
if not specified_function_name:
raise ValueError(
"Incorrectly formatted `tool_choice`. Should be like "
"`{\"type\": \"function\", "
"\"function\": {\"name\": \"my_function\"}}`")
for tool in data["tools"]:
if tool["function"]["name"] == specified_function_name:
valid_tool = True
break
if not valid_tool:
raise ValueError(
"The tool specified in `tool_choice` does not match any"
" of the specified `tools`")
return data
@model_validator(mode="before")
@classmethod
def check_generation_prompt(cls, data):
if data.get("continue_final_message") and data.get(
"add_generation_prompt"):
raise ValueError("Cannot set both `continue_final_message` and "
"`add_generation_prompt` to True.")
return data
class CompletionRequest(OpenAIBaseModel):
# Ordered by official OpenAI API documentation
# https://platform.openai.com/docs/api-reference/completions/create
model: str
prompt: Union[List[int], List[List[int]], str, List[str]]
best_of: Optional[int] = None
echo: Optional[bool] = False
frequency_penalty: Optional[float] = 0.0
logit_bias: Optional[Dict[str, float]] = None
logprobs: Optional[int] = None
max_tokens: Optional[int] = 16
n: int = 1
presence_penalty: Optional[float] = 0.0
seed: Optional[int] = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
stream: Optional[bool] = False
stream_options: Optional[StreamOptions] = None
suffix: Optional[str] = None
temperature: Optional[float] = 1.0
top_p: Optional[float] = 1.0
user: Optional[str] = None
# doc: begin-completion-sampling-params
use_beam_search: bool = False
top_k: int = -1
min_p: float = 0.0
repetition_penalty: float = 1.0
length_penalty: float = 1.0
stop_token_ids: Optional[List[int]] = Field(default_factory=list)
include_stop_str_in_output: bool = False
ignore_eos: bool = False
min_tokens: int = 0
skip_special_tokens: bool = True
spaces_between_special_tokens: bool = True
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
allowed_token_ids: Optional[List[int]] = None
prompt_logprobs: Optional[int] = None
# doc: end-completion-sampling-params
# doc: begin-completion-extra-params
add_special_tokens: bool = Field(
default=True,
description=(
"If true (the default), special tokens (e.g. BOS) will be added to "
"the prompt."),
)
response_format: Optional[ResponseFormat] = Field(
default=None,
description=
("Similar to chat completion, this parameter specifies the format of "
"output. Only {'type': 'json_object'} or {'type': 'text' } is "
"supported."),
)
guided_json: Optional[Union[str, dict, BaseModel]] = Field(
default=None,
description="If specified, the output will follow the JSON schema.",
)
guided_regex: Optional[str] = Field(
default=None,
description=(
"If specified, the output will follow the regex pattern."),
)
guided_choice: Optional[List[str]] = Field(
default=None,
description=(
"If specified, the output will be exactly one of the choices."),
)
guided_grammar: Optional[str] = Field(
default=None,
description=(
"If specified, the output will follow the context free grammar."),
)
guided_decoding_backend: Optional[str] = Field(
default=None,
description=(
"If specified, will override the default guided decoding backend "
"of the server for this specific request. If set, must be one of "
"'outlines' / 'lm-format-enforcer'"))
guided_whitespace_pattern: Optional[str] = Field(
default=None,
description=(
"If specified, will override the default whitespace pattern "
"for guided json decoding."))
priority: int = Field(
default=0,
description=(
"The priority of the request (lower means earlier handling; "
"default: 0). Any priority other than 0 will raise an error "
"if the served model does not use priority scheduling."))
# doc: end-completion-extra-params
def to_beam_search_params(self,
default_max_tokens: int) -> BeamSearchParams:
max_tokens = self.max_tokens
if max_tokens is None:
max_tokens = default_max_tokens
n = self.n if self.n is not None else 1
temperature = self.temperature if self.temperature is not None else 0.0
return BeamSearchParams(
beam_width=n,
max_tokens=max_tokens,
ignore_eos=self.ignore_eos,
temperature=temperature,
length_penalty=self.length_penalty,
)
def to_sampling_params(self, default_max_tokens: int) -> SamplingParams:
max_tokens = self.max_tokens
if max_tokens is None:
max_tokens = default_max_tokens
prompt_logprobs = self.prompt_logprobs
if prompt_logprobs is None and self.echo:
prompt_logprobs = self.logprobs
echo_without_generation = self.echo and self.max_tokens == 0
guided_json_object = None
if (self.response_format is not None
and self.response_format.type == "json_object"):
guided_json_object = True
guided_decoding = GuidedDecodingParams.from_optional(
json=self.guided_json,
regex=self.guided_regex,
choice=self.guided_choice,
grammar=self.guided_grammar,
json_object=guided_json_object,
backend=self.guided_decoding_backend,
whitespace_pattern=self.guided_whitespace_pattern)
return SamplingParams.from_optional(
n=self.n,
best_of=self.best_of,
presence_penalty=self.presence_penalty,
frequency_penalty=self.frequency_penalty,
repetition_penalty=self.repetition_penalty,
temperature=self.temperature,
top_p=self.top_p,
top_k=self.top_k,
min_p=self.min_p,
seed=self.seed,
stop=self.stop,
stop_token_ids=self.stop_token_ids,
logprobs=self.logprobs,
ignore_eos=self.ignore_eos,
max_tokens=max_tokens if not echo_without_generation else 1,
min_tokens=self.min_tokens,
prompt_logprobs=prompt_logprobs,
skip_special_tokens=self.skip_special_tokens,
spaces_between_special_tokens=self.spaces_between_special_tokens,
include_stop_str_in_output=self.include_stop_str_in_output,
truncate_prompt_tokens=self.truncate_prompt_tokens,
output_kind=RequestOutputKind.DELTA if self.stream \
else RequestOutputKind.FINAL_ONLY,
guided_decoding=guided_decoding,
logit_bias=self.logit_bias,
allowed_token_ids=self.allowed_token_ids)
@model_validator(mode="before")
@classmethod
def check_guided_decoding_count(cls, data):
guide_count = sum([
"guided_json" in data and data["guided_json"] is not None,
"guided_regex" in data and data["guided_regex"] is not None,
"guided_choice" in data and data["guided_choice"] is not None
])
if guide_count > 1:
raise ValueError(
"You can only use one kind of guided decoding "
"('guided_json', 'guided_regex' or 'guided_choice').")
return data
@model_validator(mode="before")
@classmethod
def check_logprobs(cls, data):
if (prompt_logprobs := data.get("prompt_logprobs")) is not None:
if data.get("stream") and prompt_logprobs > 0:
raise ValueError(
"`prompt_logprobs` are not available when `stream=True`.")
if prompt_logprobs < 0:
raise ValueError("`prompt_logprobs` must be a positive value.")
if (logprobs := data.get("logprobs")) is not None and logprobs < 0:
raise ValueError("`logprobs` must be a positive value.")
return data
@model_validator(mode="before")
@classmethod
def validate_stream_options(cls, data):
if data.get("stream_options") and not data.get("stream"):
raise ValueError(
"Stream options can only be defined when `stream=True`.")
return data
class EmbeddingRequest(OpenAIBaseModel):
# Ordered by official OpenAI API documentation
# https://platform.openai.com/docs/api-reference/embeddings
model: str
input: Union[List[int], List[List[int]], str, List[str]]
encoding_format: Literal["float", "base64"] = "float"
dimensions: Optional[int] = None
user: Optional[str] = None
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
# doc: begin-embedding-pooling-params
additional_data: Optional[Any] = None
# doc: end-embedding-pooling-params
# doc: begin-embedding-extra-params
priority: int = Field(
default=0,
description=(
"The priority of the request (lower means earlier handling; "
"default: 0). Any priority other than 0 will raise an error "
"if the served model does not use priority scheduling."))
# doc: end-embedding-extra-params
def to_pooling_params(self):
return PoolingParams(additional_data=self.additional_data)
class CompletionLogProbs(OpenAIBaseModel):
text_offset: List[int] = Field(default_factory=list)
token_logprobs: List[Optional[float]] = Field(default_factory=list)
tokens: List[str] = Field(default_factory=list)
top_logprobs: List[Optional[Dict[str,
float]]] = Field(default_factory=list)
class CompletionResponseChoice(OpenAIBaseModel):
index: int
text: str
logprobs: Optional[CompletionLogProbs] = None
finish_reason: Optional[str] = None
stop_reason: Optional[Union[int, str]] = Field(
default=None,
description=(
"The stop string or token id that caused the completion "
"to stop, None if the completion finished for some other reason "
"including encountering the EOS token"),
)
prompt_logprobs: Optional[List[Optional[Dict[int, Logprob]]]] = None
class CompletionResponse(OpenAIBaseModel):
id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
object: str = "text_completion"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: List[CompletionResponseChoice]
usage: UsageInfo
class CompletionResponseStreamChoice(OpenAIBaseModel):
index: int
text: str
logprobs: Optional[CompletionLogProbs] = None
finish_reason: Optional[str] = None
stop_reason: Optional[Union[int, str]] = Field(
default=None,
description=(
"The stop string or token id that caused the completion "
"to stop, None if the completion finished for some other reason "
"including encountering the EOS token"),
)
class CompletionStreamResponse(OpenAIBaseModel):
id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
object: str = "text_completion"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: List[CompletionResponseStreamChoice]
usage: Optional[UsageInfo] = Field(default=None)
class EmbeddingResponseData(OpenAIBaseModel):
index: int
object: str = "embedding"
embedding: Union[List[float], str]
class EmbeddingResponse(OpenAIBaseModel):
id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
object: str = "list"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
data: List[EmbeddingResponseData]
usage: UsageInfo
class FunctionCall(OpenAIBaseModel):
name: str
arguments: str
class ToolCall(OpenAIBaseModel):
id: str = Field(default_factory=lambda: f"chatcmpl-tool-{random_uuid()}")
type: Literal["function"] = "function"
function: FunctionCall
class DeltaFunctionCall(BaseModel):
name: Optional[str] = None
arguments: Optional[str] = None
# a tool call delta where everything is optional
class DeltaToolCall(OpenAIBaseModel):
id: str = Field(default_factory=lambda: f"chatcmpl-tool-{random_uuid()}")
type: Literal["function"] = "function"
index: int
function: Optional[DeltaFunctionCall] = None
class ExtractedToolCallInformation(BaseModel):
# indicate if tools were called
tools_called: bool
# extracted tool calls
tool_calls: List[ToolCall]
# content - per OpenAI spec, content AND tool calls can be returned rarely
# But some models will do this intentionally
content: Optional[str] = None
class ChatMessage(OpenAIBaseModel):
role: str
content: Optional[str] = None
tool_calls: List[ToolCall] = Field(default_factory=list)
class ChatCompletionLogProb(OpenAIBaseModel):
token: str
logprob: float = -9999.0
bytes: Optional[List[int]] = None
class ChatCompletionLogProbsContent(ChatCompletionLogProb):
top_logprobs: List[ChatCompletionLogProb] = Field(default_factory=list)
class ChatCompletionLogProbs(OpenAIBaseModel):
content: Optional[List[ChatCompletionLogProbsContent]] = None
class ChatCompletionResponseChoice(OpenAIBaseModel):
index: int
message: ChatMessage
logprobs: Optional[ChatCompletionLogProbs] = None
# per OpenAI spec this is the default
finish_reason: Optional[str] = "stop"
# not part of the OpenAI spec but included in vLLM for legacy reasons
stop_reason: Optional[Union[int, str]] = None
class ChatCompletionResponse(OpenAIBaseModel):
id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
object: Literal["chat.completion"] = "chat.completion"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: List[ChatCompletionResponseChoice]
usage: UsageInfo
prompt_logprobs: Optional[List[Optional[Dict[int, Logprob]]]] = None
class DeltaMessage(OpenAIBaseModel):
role: Optional[str] = None
content: Optional[str] = None
tool_calls: List[DeltaToolCall] = Field(default_factory=list)
class ChatCompletionResponseStreamChoice(OpenAIBaseModel):
index: int
delta: DeltaMessage
logprobs: Optional[ChatCompletionLogProbs] = None
finish_reason: Optional[str] = None
stop_reason: Optional[Union[int, str]] = None
class ChatCompletionStreamResponse(OpenAIBaseModel):
id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: List[ChatCompletionResponseStreamChoice]
usage: Optional[UsageInfo] = Field(default=None)
class BatchRequestInput(OpenAIBaseModel):
"""
The per-line object of the batch input file.
NOTE: Currently only the `/v1/chat/completions` endpoint is supported.
"""
# A developer-provided per-request id that will be used to match outputs to
# inputs. Must be unique for each request in a batch.
custom_id: str
# The HTTP method to be used for the request. Currently only POST is
# supported.
method: str
# The OpenAI API relative URL to be used for the request. Currently
# /v1/chat/completions is supported.
url: str
# The parameters of the request.
body: Union[ChatCompletionRequest, EmbeddingRequest]
class BatchResponseData(OpenAIBaseModel):
# HTTP status code of the response.
status_code: int = 200
# An unique identifier for the API request.
request_id: str
# The body of the response.
body: Optional[Union[ChatCompletionResponse, EmbeddingResponse]] = None
class BatchRequestOutput(OpenAIBaseModel):
"""
The per-line object of the batch output and error files
"""
id: str
# A developer-provided per-request id that will be used to match outputs to
# inputs.
custom_id: str
response: Optional[BatchResponseData]
# For requests that failed with a non-HTTP error, this will contain more
# information on the cause of the failure.
error: Optional[Any]
class TokenizeCompletionRequest(OpenAIBaseModel):
model: str
prompt: str
add_special_tokens: bool = Field(default=True)
class TokenizeChatRequest(OpenAIBaseModel):
model: str
messages: List[ChatCompletionMessageParam]
add_generation_prompt: bool = Field(default=True)
continue_final_message: bool = Field(default=False)
add_special_tokens: bool = Field(default=False)
@model_validator(mode="before")
@classmethod
def check_generation_prompt(cls, data):
if data.get("continue_final_message") and data.get(
"add_generation_prompt"):
raise ValueError("Cannot set both `continue_final_message` and "
"`add_generation_prompt` to True.")
return data
TokenizeRequest = Union[TokenizeCompletionRequest, TokenizeChatRequest]
class TokenizeResponse(OpenAIBaseModel):
count: int
max_model_len: int
tokens: List[int]
class DetokenizeRequest(OpenAIBaseModel):
model: str
tokens: List[int]
class DetokenizeResponse(OpenAIBaseModel):
prompt: str
class LoadLoraAdapterRequest(BaseModel):
lora_name: str
lora_path: str
class UnloadLoraAdapterRequest(BaseModel):
lora_name: str
lora_int_id: Optional[int] = Field(default=None)

View File

@@ -0,0 +1,285 @@
import asyncio
from http import HTTPStatus
from io import StringIO
from typing import Awaitable, Callable, List, Optional
import aiohttp
import torch
from prometheus_client import start_http_server
from tqdm import tqdm
from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.logger import RequestLogger, logger
# yapf: disable
from vllm.entrypoints.openai.protocol import (BatchRequestInput,
BatchRequestOutput,
BatchResponseData,
ChatCompletionResponse,
EmbeddingResponse, ErrorResponse)
# yapf: enable
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
from vllm.entrypoints.openai.serving_engine import BaseModelPath
from vllm.usage.usage_lib import UsageContext
from vllm.utils import FlexibleArgumentParser, random_uuid
from vllm.version import __version__ as VLLM_VERSION
def parse_args():
parser = FlexibleArgumentParser(
description="vLLM OpenAI-Compatible batch runner.")
parser.add_argument(
"-i",
"--input-file",
required=True,
type=str,
help=
"The path or url to a single input file. Currently supports local file "
"paths, or the http protocol (http or https). If a URL is specified, "
"the file should be available via HTTP GET.")
parser.add_argument(
"-o",
"--output-file",
required=True,
type=str,
help="The path or url to a single output file. Currently supports "
"local file paths, or web (http or https) urls. If a URL is specified,"
" the file should be available via HTTP PUT.")
parser.add_argument("--response-role",
type=nullable_str,
default="assistant",
help="The role name to return if "
"`request.add_generation_prompt=True`.")
parser = AsyncEngineArgs.add_cli_args(parser)
parser.add_argument('--max-log-len',
type=int,
default=None,
help='Max number of prompt characters or prompt '
'ID numbers being printed in log.'
'\n\nDefault: Unlimited')
parser.add_argument("--enable-metrics",
action="store_true",
help="Enable Prometheus metrics")
parser.add_argument(
"--url",
type=str,
default="0.0.0.0",
help="URL to the Prometheus metrics server "
"(only needed if enable-metrics is set).",
)
parser.add_argument(
"--port",
type=int,
default=8000,
help="Port number for the Prometheus metrics server "
"(only needed if enable-metrics is set).",
)
return parser.parse_args()
# explicitly use pure text format, with a newline at the end
# this makes it impossible to see the animation in the progress bar
# but will avoid messing up with ray or multiprocessing, which wraps
# each line of output with some prefix.
_BAR_FORMAT = "{desc}: {percentage:3.0f}% Completed | {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]\n" # noqa: E501
class BatchProgressTracker:
def __init__(self):
self._total = 0
self._pbar: Optional[tqdm] = None
def submitted(self):
self._total += 1
def completed(self):
if self._pbar:
self._pbar.update()
def pbar(self) -> tqdm:
enable_tqdm = not torch.distributed.is_initialized(
) or torch.distributed.get_rank() == 0
self._pbar = tqdm(total=self._total,
unit="req",
desc="Running batch",
mininterval=5,
disable=not enable_tqdm,
bar_format=_BAR_FORMAT)
return self._pbar
async def read_file(path_or_url: str) -> str:
if path_or_url.startswith("http://") or path_or_url.startswith("https://"):
async with aiohttp.ClientSession() as session, \
session.get(path_or_url) as resp:
return await resp.text()
else:
with open(path_or_url, "r", encoding="utf-8") as f:
return f.read()
async def write_file(path_or_url: str, data: str) -> None:
if path_or_url.startswith("http://") or path_or_url.startswith("https://"):
async with aiohttp.ClientSession() as session, \
session.put(path_or_url, data=data.encode("utf-8")):
pass
else:
# We should make this async, but as long as this is always run as a
# standalone program, blocking the event loop won't effect performance
# in this particular case.
with open(path_or_url, "w", encoding="utf-8") as f:
f.write(data)
def make_error_request_output(request: BatchRequestInput,
error_msg: str) -> BatchRequestOutput:
batch_output = BatchRequestOutput(
id=f"vllm-{random_uuid()}",
custom_id=request.custom_id,
response=BatchResponseData(
status_code=HTTPStatus.BAD_REQUEST,
request_id=f"vllm-batch-{random_uuid()}",
),
error=error_msg,
)
return batch_output
async def make_async_error_request_output(
request: BatchRequestInput, error_msg: str) -> BatchRequestOutput:
return make_error_request_output(request, error_msg)
async def run_request(serving_engine_func: Callable,
request: BatchRequestInput,
tracker: BatchProgressTracker) -> BatchRequestOutput:
response = await serving_engine_func(request.body)
if isinstance(response, (ChatCompletionResponse, EmbeddingResponse)):
batch_output = BatchRequestOutput(
id=f"vllm-{random_uuid()}",
custom_id=request.custom_id,
response=BatchResponseData(
body=response, request_id=f"vllm-batch-{random_uuid()}"),
error=None,
)
elif isinstance(response, ErrorResponse):
batch_output = BatchRequestOutput(
id=f"vllm-{random_uuid()}",
custom_id=request.custom_id,
response=BatchResponseData(
status_code=response.code,
request_id=f"vllm-batch-{random_uuid()}"),
error=response,
)
else:
batch_output = make_error_request_output(
request, error_msg="Request must not be sent in stream mode")
tracker.completed()
return batch_output
async def main(args):
if args.served_model_name is not None:
served_model_names = args.served_model_name
else:
served_model_names = [args.model]
engine_args = AsyncEngineArgs.from_cli_args(args)
engine = AsyncLLMEngine.from_engine_args(
engine_args, usage_context=UsageContext.OPENAI_BATCH_RUNNER)
model_config = await engine.get_model_config()
base_model_paths = [
BaseModelPath(name=name, model_path=args.model)
for name in served_model_names
]
if args.disable_log_requests:
request_logger = None
else:
request_logger = RequestLogger(max_log_len=args.max_log_len)
# Create the openai serving objects.
openai_serving_chat = OpenAIServingChat(
engine,
model_config,
base_model_paths,
args.response_role,
lora_modules=None,
prompt_adapters=None,
request_logger=request_logger,
chat_template=None,
)
openai_serving_embedding = OpenAIServingEmbedding(
engine,
model_config,
base_model_paths,
request_logger=request_logger,
)
tracker = BatchProgressTracker()
logger.info("Reading batch from %s...", args.input_file)
# Submit all requests in the file to the engine "concurrently".
response_futures: List[Awaitable[BatchRequestOutput]] = []
for request_json in (await read_file(args.input_file)).strip().split("\n"):
# Skip empty lines.
request_json = request_json.strip()
if not request_json:
continue
request = BatchRequestInput.model_validate_json(request_json)
# Determine the type of request and run it.
if request.url == "/v1/chat/completions":
response_futures.append(
run_request(openai_serving_chat.create_chat_completion,
request, tracker))
tracker.submitted()
elif request.url == "/v1/embeddings":
response_futures.append(
run_request(openai_serving_embedding.create_embedding, request,
tracker))
tracker.submitted()
else:
response_futures.append(
make_async_error_request_output(
request,
error_msg="Only /v1/chat/completions and "
"/v1/embeddings are supported in the batch endpoint.",
))
with tracker.pbar():
responses = await asyncio.gather(*response_futures)
output_buffer = StringIO()
for response in responses:
print(response.model_dump_json(), file=output_buffer)
output_buffer.seek(0)
await write_file(args.output_file, output_buffer.read().strip())
if __name__ == "__main__":
args = parse_args()
logger.info("vLLM batch processing API version %s", VLLM_VERSION)
logger.info("args: %s", args)
# Start the Prometheus metrics server. LLMEngine uses the Prometheus client
# to publish metrics at the /metrics endpoint.
if args.enable_metrics:
logger.info("Prometheus metrics enabled")
start_http_server(port=args.port, addr=args.url)
else:
logger.info("Prometheus metrics disabled")
asyncio.run(main(args))

View File

@@ -0,0 +1,891 @@
import asyncio
import json
import time
from typing import (AsyncGenerator, AsyncIterator, Callable, Dict, Final, List,
Optional)
from typing import Sequence as GenericSequence
from typing import Union
from fastapi import Request
from vllm.config import ModelConfig
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.engine.multiprocessing.client import MQLLMEngineClient
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import (ConversationMessage,
apply_hf_chat_template,
apply_mistral_chat_template,
load_chat_template,
parse_chat_messages_futures)
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import (
ChatCompletionLogProb, ChatCompletionLogProbs,
ChatCompletionLogProbsContent, ChatCompletionNamedToolChoiceParam,
ChatCompletionRequest, ChatCompletionResponse,
ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
ChatCompletionStreamResponse, ChatMessage, DeltaFunctionCall, DeltaMessage,
DeltaToolCall, ErrorResponse, FunctionCall, RequestResponseMetadata,
ToolCall, UsageInfo)
from vllm.entrypoints.openai.serving_engine import (BaseModelPath,
LoRAModulePath,
OpenAIServing,
PromptAdapterPath,
TextTokensPrompt)
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
from vllm.inputs import TokensPrompt
from vllm.logger import init_logger
from vllm.outputs import CompletionOutput, RequestOutput
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.sequence import Logprob
from vllm.tracing import (contains_trace_headers, extract_trace_headers,
log_tracing_disabled_warning)
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.utils import iterate_with_cancellation, random_uuid
logger = init_logger(__name__)
class OpenAIServingChat(OpenAIServing):
def __init__(self,
engine_client: EngineClient,
model_config: ModelConfig,
base_model_paths: List[BaseModelPath],
response_role: str,
*,
lora_modules: Optional[List[LoRAModulePath]],
prompt_adapters: Optional[List[PromptAdapterPath]],
request_logger: Optional[RequestLogger],
chat_template: Optional[str],
return_tokens_as_token_ids: bool = False,
enable_auto_tools: bool = False,
tool_parser: Optional[str] = None):
super().__init__(engine_client=engine_client,
model_config=model_config,
base_model_paths=base_model_paths,
lora_modules=lora_modules,
prompt_adapters=prompt_adapters,
request_logger=request_logger,
return_tokens_as_token_ids=return_tokens_as_token_ids)
self.response_role = response_role
self.use_tool_use_model_template = False
self.chat_template = load_chat_template(chat_template)
# set up tool use
self.enable_auto_tools: bool = enable_auto_tools
if self.enable_auto_tools:
logger.info(
"\"auto\" tool choice has been enabled please note that while"
" the parallel_tool_calls client option is preset for "
"compatibility reasons, it will be ignored.")
self.tool_parser: Optional[Callable[[AnyTokenizer], ToolParser]] = None
if self.enable_auto_tools:
try:
self.tool_parser = ToolParserManager.get_tool_parser(
tool_parser)
except Exception as e:
raise TypeError("Error: --enable-auto-tool-choice requires "
f"tool_parser:'{tool_parser}' which has not "
"been registered") from e
async def create_chat_completion(
self,
request: ChatCompletionRequest,
raw_request: Optional[Request] = None,
) -> Union[AsyncGenerator[str, None], ChatCompletionResponse,
ErrorResponse]:
"""Completion API similar to OpenAI's API.
See https://platform.openai.com/docs/api-reference/chat/create
for the API specification. This API mimics the OpenAI
ChatCompletion API.
"""
error_check_ret = await self._check_model(request)
if error_check_ret is not None:
logger.error("Error with model %s", error_check_ret)
return error_check_ret
# If the engine is dead, raise the engine's DEAD_ERROR.
# This is required for the streaming case, where we return a
# success status before we actually start generating text :).
if self.engine_client.errored:
raise self.engine_client.dead_error
try:
(
lora_request,
prompt_adapter_request,
) = self._maybe_get_adapters(request)
model_config = self.model_config
tokenizer = await self.engine_client.get_tokenizer(lora_request)
conversation, mm_data_future = parse_chat_messages_futures(
request.messages, model_config, tokenizer)
tool_dicts = None if request.tools is None else [
tool.model_dump() for tool in request.tools
]
prompt: Union[str, List[int]]
is_mistral_tokenizer = isinstance(tokenizer, MistralTokenizer)
if is_mistral_tokenizer:
prompt = apply_mistral_chat_template(
tokenizer,
messages=request.messages,
chat_template=request.chat_template or self.chat_template,
add_generation_prompt=request.add_generation_prompt,
continue_final_message=request.continue_final_message,
tools=tool_dicts,
documents=request.documents,
**(request.chat_template_kwargs or {}),
)
else:
prompt = apply_hf_chat_template(
tokenizer,
conversation=conversation,
chat_template=request.chat_template or self.chat_template,
add_generation_prompt=request.add_generation_prompt,
continue_final_message=request.continue_final_message,
tools=tool_dicts,
documents=request.documents,
**(request.chat_template_kwargs or {}),
)
except Exception as e:
logger.exception("Error in applying chat template from request")
return self.create_error_response(str(e))
try:
mm_data = await mm_data_future
except Exception as e:
logger.exception("Error in loading multi-modal data")
return self.create_error_response(str(e))
# validation for OpenAI tools
# tool_choice = "required" is not supported
if request.tool_choice == "required":
return self.create_error_response(
"tool_choice = \"required\" is not supported!")
if not is_mistral_tokenizer and request.tool_choice == "auto" and not (
self.enable_auto_tools and self.tool_parser is not None):
# for hf tokenizers, "auto" tools requires
# --enable-auto-tool-choice and --tool-call-parser
return self.create_error_response(
"\"auto\" tool choice requires "
"--enable-auto-tool-choice and --tool-call-parser to be set")
request_id = f"chat-{random_uuid()}"
request_metadata = RequestResponseMetadata(request_id=request_id)
if raw_request:
raw_request.state.request_metadata = request_metadata
try:
if self.enable_auto_tools and self.tool_parser:
request = self.tool_parser(tokenizer).adjust_request(
request=request)
if isinstance(prompt, str):
prompt_inputs = self._tokenize_prompt_input(
request,
tokenizer,
prompt,
truncate_prompt_tokens=request.truncate_prompt_tokens,
add_special_tokens=request.add_special_tokens,
)
else:
assert isinstance(prompt, list) and isinstance(
prompt[0], int
), "Prompt has to be either a string or a list of token ids"
prompt_inputs = TextTokensPrompt(
prompt=tokenizer.decode(prompt), prompt_token_ids=prompt)
assert prompt_inputs is not None
sampling_params: Union[SamplingParams, BeamSearchParams]
default_max_tokens = self.max_model_len - len(
prompt_inputs["prompt_token_ids"])
if request.use_beam_search:
sampling_params = request.to_beam_search_params(
default_max_tokens)
else:
sampling_params = request.to_sampling_params(
default_max_tokens)
self._log_inputs(request_id,
prompt_inputs,
params=sampling_params,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request)
engine_inputs = TokensPrompt(
prompt_token_ids=prompt_inputs["prompt_token_ids"])
if mm_data is not None:
engine_inputs["multi_modal_data"] = mm_data
is_tracing_enabled = (await
self.engine_client.is_tracing_enabled())
trace_headers = None
if is_tracing_enabled and raw_request:
trace_headers = extract_trace_headers(raw_request.headers)
if (not is_tracing_enabled and raw_request
and contains_trace_headers(raw_request.headers)):
log_tracing_disabled_warning()
if isinstance(sampling_params, BeamSearchParams):
assert isinstance(self.engine_client,
(AsyncLLMEngine,
MQLLMEngineClient)), \
"Beam search is only supported with" \
"AsyncLLMEngine and MQLLMEngineClient."
result_generator = self.engine_client.beam_search(
engine_inputs['prompt_token_ids'],
request_id,
sampling_params,
)
else:
result_generator = self.engine_client.generate(
engine_inputs,
sampling_params,
request_id,
lora_request=lora_request,
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
priority=request.priority,
)
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
if raw_request:
result_generator = iterate_with_cancellation(
result_generator, raw_request.is_disconnected)
# Streaming response
if request.stream:
return self.chat_completion_stream_generator(
request, result_generator, request_id, conversation, tokenizer,
request_metadata)
try:
return await self.chat_completion_full_generator(
request, result_generator, request_id, conversation, tokenizer,
request_metadata)
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
def get_chat_request_role(self, request: ChatCompletionRequest) -> str:
if request.add_generation_prompt:
return self.response_role
return request.messages[-1]["role"]
async def chat_completion_stream_generator(
self,
request: ChatCompletionRequest,
result_generator: AsyncIterator[RequestOutput],
request_id: str,
conversation: List[ConversationMessage],
tokenizer: AnyTokenizer,
request_metadata: RequestResponseMetadata,
) -> AsyncGenerator[str, None]:
model_name = self.base_model_paths[0].name
created_time = int(time.time())
chunk_object_type: Final = "chat.completion.chunk"
first_iteration = True
# Send response for each token for each request.n (index)
num_choices = 1 if request.n is None else request.n
previous_num_tokens = [0] * num_choices
finish_reason_sent = [False] * num_choices
num_prompt_tokens = 0
if isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam):
tool_choice_function_name = request.tool_choice.function.name
else:
tool_choice_function_name = None
# Determine whether tools are in use with "auto" tool choice
tool_choice_auto = (
not tool_choice_function_name
and self._should_stream_with_auto_tool_parsing(request))
all_previous_token_ids: Optional[List[List[int]]]
if tool_choice_auto:
# These are only required in "auto" tool choice case
previous_texts = [""] * num_choices
all_previous_token_ids = [[]] * num_choices
else:
previous_texts, all_previous_token_ids = None, None
# Prepare the tool parser if it's needed
try:
if tool_choice_auto and self.tool_parser:
tool_parsers: List[Optional[ToolParser]] = [
self.tool_parser(tokenizer)
] * num_choices
else:
tool_parsers = [None] * num_choices
except RuntimeError as e:
logger.error("Error in tool parser creation: %s", e)
data = self.create_streaming_error_response(str(e))
yield f"data: {data}\n\n"
yield "data: [DONE]\n\n"
return
try:
async for res in result_generator:
if res.prompt_token_ids is not None:
num_prompt_tokens = len(res.prompt_token_ids)
if res.encoder_prompt_token_ids is not None:
num_prompt_tokens += len(res.encoder_prompt_token_ids)
# We need to do it here, because if there are exceptions in
# the result_generator, it needs to be sent as the FIRST
# response (by the try...catch).
if first_iteration:
# Send first response for each request.n (index) with
# the role
role = self.get_chat_request_role(request)
# NOTE num_choices defaults to 1 so this usually executes
# once per request
for i in range(num_choices):
tool_parser = tool_parsers[i]
choice_data = ChatCompletionResponseStreamChoice(
index=i,
delta=DeltaMessage(
role=role,
content="",
),
logprobs=None,
finish_reason=None)
chunk = ChatCompletionStreamResponse(
id=request_id,
object=chunk_object_type,
created=created_time,
choices=[choice_data],
model=model_name)
# if usage should be included
if (request.stream_options
and request.stream_options.include_usage):
# if continuous usage stats are requested, add it
if request.stream_options.continuous_usage_stats:
usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=0,
total_tokens=num_prompt_tokens)
chunk.usage = usage
# otherwise don't
else:
chunk.usage = None
data = chunk.model_dump_json(exclude_unset=True)
yield f"data: {data}\n\n"
# Send response to echo the input portion of the
# last message
if request.echo or request.continue_final_message:
last_msg_content: str = ""
if conversation and "content" in conversation[
-1] and conversation[-1].get("role") == role:
last_msg_content = conversation[-1]["content"] or ""
if last_msg_content:
for i in range(num_choices):
choice_data = (
ChatCompletionResponseStreamChoice(
index=i,
delta=DeltaMessage(
content=last_msg_content),
logprobs=None,
finish_reason=None))
chunk = ChatCompletionStreamResponse(
id=request_id,
object=chunk_object_type,
created=created_time,
choices=[choice_data],
model=model_name)
if (request.stream_options and
request.stream_options.include_usage):
if (request.stream_options.
continuous_usage_stats):
usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=0,
total_tokens=num_prompt_tokens)
chunk.usage = usage
else:
chunk.usage = None
data = chunk.model_dump_json(
exclude_unset=True)
yield f"data: {data}\n\n"
first_iteration = False
for output in res.outputs:
i = output.index
tool_parser = tool_parsers[i]
if finish_reason_sent[i]:
continue
if request.logprobs and request.top_logprobs is not None:
assert output.logprobs is not None, (
"Did not output logprobs")
logprobs = self._create_chat_logprobs(
token_ids=output.token_ids,
top_logprobs=output.logprobs,
tokenizer=tokenizer,
num_output_top_logprobs=request.top_logprobs,
)
else:
logprobs = None
delta_text = output.text
delta_message: Optional[DeltaMessage]
# handle streaming deltas for tools with named tool_choice
if tool_choice_function_name:
delta_message = DeltaMessage(tool_calls=[
DeltaToolCall(function=DeltaFunctionCall(
name=tool_choice_function_name,
arguments=delta_text),
index=i)
])
# handle streaming deltas for tools with "auto" tool choice
elif tool_choice_auto:
assert previous_texts is not None
assert all_previous_token_ids is not None
assert tool_parser is not None
#TODO optimize manipulation of these lists
previous_text = previous_texts[i]
previous_token_ids = all_previous_token_ids[i]
current_text = previous_text + delta_text
current_token_ids = previous_token_ids + list(
output.token_ids)
delta_message = (
tool_parser.extract_tool_calls_streaming(
previous_text=previous_text,
current_text=current_text,
delta_text=delta_text,
previous_token_ids=previous_token_ids,
current_token_ids=current_token_ids,
delta_token_ids=output.token_ids,
request=request))
# update the previous values for the next iteration
previous_texts[i] = current_text
all_previous_token_ids[i] = current_token_ids
# handle streaming just a content delta
else:
delta_message = DeltaMessage(content=delta_text)
# set the previous values for the next iteration
previous_num_tokens[i] += len(output.token_ids)
# if the message delta is None (e.g. because it was a
# "control token" for tool calls or the parser otherwise
# wasn't ready to send a token, then
# get the next token without streaming a chunk
if delta_message is None:
continue
if output.finish_reason is None:
# Send token-by-token response for each request.n
choice_data = ChatCompletionResponseStreamChoice(
index=i,
delta=delta_message,
logprobs=logprobs,
finish_reason=None)
chunk = ChatCompletionStreamResponse(
id=request_id,
object=chunk_object_type,
created=created_time,
choices=[choice_data],
model=model_name)
# handle usage stats if requested & if continuous
if (request.stream_options
and request.stream_options.include_usage):
if request.stream_options.continuous_usage_stats:
completion_tokens = len(output.token_ids)
usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=num_prompt_tokens +
completion_tokens,
)
chunk.usage = usage
else:
chunk.usage = None
data = chunk.model_dump_json(exclude_unset=True)
yield f"data: {data}\n\n"
# if the model is finished generating
else:
# check to make sure we haven't "forgotten" to stream
# any tokens that were generated but previously
# matched by partial json parsing
# only happens if we are NOT using guided decoding
auto_tools_called = False
if tool_parser:
auto_tools_called = len(
tool_parser.prev_tool_call_arr) > 0
index = len(tool_parser.prev_tool_call_arr
) - 1 if auto_tools_called else 0
else:
index = 0
if self._should_check_for_unstreamed_tool_arg_tokens(
delta_message, output) and tool_parser:
# get the expected call based on partial JSON
# parsing which "autocompletes" the JSON
expected_call = json.dumps(
tool_parser.prev_tool_call_arr[index].get(
"arguments", {}))
# get what we've streamed so far for arguments
# for the current tool
actual_call = tool_parser.streamed_args_for_tool[
index]
# check to see if there's anything left to stream
remaining_call = expected_call.replace(
actual_call, "", 1)
# set that as a delta message
delta_message = DeltaMessage(tool_calls=[
DeltaToolCall(index=index,
function=DeltaFunctionCall(
arguments=remaining_call).
model_dump(exclude_none=True))
])
# Send the finish response for each request.n only once
choice_data = ChatCompletionResponseStreamChoice(
index=i,
delta=delta_message,
logprobs=logprobs,
finish_reason=output.finish_reason
if not auto_tools_called else "tool_calls",
stop_reason=output.stop_reason)
chunk = ChatCompletionStreamResponse(
id=request_id,
object=chunk_object_type,
created=created_time,
choices=[choice_data],
model=model_name)
if (request.stream_options
and request.stream_options.include_usage):
if request.stream_options.continuous_usage_stats:
completion_tokens = len(output.token_ids)
usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=num_prompt_tokens +
completion_tokens,
)
chunk.usage = usage
else:
chunk.usage = None
data = chunk.model_dump_json(exclude_unset=True)
yield f"data: {data}\n\n"
finish_reason_sent[i] = True
# once the final token is handled, if stream_options.include_usage
# is sent, send the usage
if (request.stream_options
and request.stream_options.include_usage):
completion_tokens = previous_num_tokens[i]
final_usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=num_prompt_tokens + completion_tokens,
)
final_usage_chunk = ChatCompletionStreamResponse(
id=request_id,
object=chunk_object_type,
created=created_time,
choices=[],
model=model_name,
usage=final_usage)
final_usage_data = (final_usage_chunk.model_dump_json(
exclude_unset=True, exclude_none=True))
yield f"data: {final_usage_data}\n\n"
# report to FastAPI middleware aggregate usage across all choices
num_completion_tokens = sum(previous_num_tokens)
request_metadata.final_usage_info = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=num_completion_tokens,
total_tokens=num_prompt_tokens + num_completion_tokens)
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
logger.error("error in chat completion stream generator: %s", e)
data = self.create_streaming_error_response(str(e))
yield f"data: {data}\n\n"
# Send the final done message after all response.n are finished
yield "data: [DONE]\n\n"
async def chat_completion_full_generator(
self,
request: ChatCompletionRequest,
result_generator: AsyncIterator[RequestOutput],
request_id: str,
conversation: List[ConversationMessage],
tokenizer: AnyTokenizer,
request_metadata: RequestResponseMetadata,
) -> Union[ErrorResponse, ChatCompletionResponse]:
model_name = self.base_model_paths[0].name
created_time = int(time.time())
final_res: Optional[RequestOutput] = None
try:
async for res in result_generator:
final_res = res
except asyncio.CancelledError:
return self.create_error_response("Client disconnected")
assert final_res is not None
choices: List[ChatCompletionResponseChoice] = []
role = self.get_chat_request_role(request)
for output in final_res.outputs:
token_ids = output.token_ids
out_logprobs = output.logprobs
if request.logprobs and request.top_logprobs is not None:
assert out_logprobs is not None, "Did not output logprobs"
logprobs = self._create_chat_logprobs(
token_ids=token_ids,
top_logprobs=out_logprobs,
num_output_top_logprobs=request.top_logprobs,
tokenizer=tokenizer,
)
else:
logprobs = None
# In the OpenAI API the finish_reason is "tools_called"
# if the tool choice is auto and the model produced a tool
# call. The same is not true for named function calls
auto_tools_called = False
# if auto tools are not enabled, and a named tool choice using
# outlines is not being used
if (not self.enable_auto_tools
or not self.tool_parser) and not isinstance(
request.tool_choice,
ChatCompletionNamedToolChoiceParam):
message = ChatMessage(role=role, content=output.text)
# if the request uses tools and specified a tool choice
elif request.tool_choice and type(
request.tool_choice) is ChatCompletionNamedToolChoiceParam:
message = ChatMessage(
role=role,
content="",
tool_calls=[
ToolCall(function=FunctionCall(
name=request.tool_choice.function.name,
arguments=output.text))
])
# if the request doesn't use tool choice
# OR specifies to not use a tool
elif not request.tool_choice or request.tool_choice == "none":
message = ChatMessage(role=role, content=output.text)
# handle when there are tools and tool choice is auto
elif request.tools and (
request.tool_choice == "auto"
or request.tool_choice is None) and self.enable_auto_tools \
and self.tool_parser:
try:
tool_parser = self.tool_parser(tokenizer)
except RuntimeError as e:
logger.error("Error in tool parser creation: %s", e)
return self.create_error_response(str(e))
tool_call_info = tool_parser.extract_tool_calls(
output.text, request=request)
# In the OpenAI API the finish_reason is "tools_called"
# if the tool choice is auto and the model produced a tool
# call. The same is not true for named function calls
auto_tools_called = tool_call_info.tools_called
if tool_call_info.tools_called:
message = ChatMessage(role=role,
content=tool_call_info.content,
tool_calls=tool_call_info.tool_calls)
else:
# FOR NOW make it a chat message; we will have to detect
# the type to make it later.
message = ChatMessage(role=role, content=output.text)
# undetermined case that is still important to handle
else:
logger.error(
"Error in chat_completion_full_generator - cannot determine"
" if tools should be extracted. Returning a standard chat "
"completion.")
message = ChatMessage(role=role, content=output.text)
choice_data = ChatCompletionResponseChoice(
index=output.index,
message=message,
logprobs=logprobs,
finish_reason="tool_calls" if auto_tools_called else
output.finish_reason if output.finish_reason else "stop",
stop_reason=output.stop_reason)
choices.append(choice_data)
if request.echo or request.continue_final_message:
last_msg_content = ""
if conversation and "content" in conversation[-1] and conversation[
-1].get("role") == role:
last_msg_content = conversation[-1]["content"] or ""
for choice in choices:
full_message = last_msg_content + (choice.message.content
or "")
choice.message.content = full_message
assert final_res.prompt_token_ids is not None
num_prompt_tokens = len(final_res.prompt_token_ids)
if final_res.encoder_prompt_token_ids is not None:
num_prompt_tokens += len(final_res.encoder_prompt_token_ids)
num_generated_tokens = sum(
len(output.token_ids) for output in final_res.outputs)
usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=num_generated_tokens,
total_tokens=num_prompt_tokens + num_generated_tokens,
)
request_metadata.final_usage_info = usage
response = ChatCompletionResponse(
id=request_id,
created=created_time,
model=model_name,
choices=choices,
usage=usage,
prompt_logprobs=final_res.prompt_logprobs,
)
return response
def _get_top_logprobs(
self, logprobs: Dict[int, Logprob], top_logprobs: Optional[int],
tokenizer: AnyTokenizer) -> List[ChatCompletionLogProb]:
return [
ChatCompletionLogProb(token=(token := self._get_decoded_token(
p[1],
p[0],
tokenizer,
return_as_token_id=self.return_tokens_as_token_ids)),
logprob=max(p[1].logprob, -9999.0),
bytes=list(
token.encode("utf-8", errors="replace")))
for i, p in enumerate(logprobs.items())
if top_logprobs and i < top_logprobs
]
def _create_chat_logprobs(
self,
token_ids: GenericSequence[int],
top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]],
tokenizer: AnyTokenizer,
num_output_top_logprobs: Optional[int] = None,
) -> ChatCompletionLogProbs:
"""Create OpenAI-style logprobs."""
logprobs_content: List[ChatCompletionLogProbsContent] = []
for i, token_id in enumerate(token_ids):
step_top_logprobs = top_logprobs[i]
if step_top_logprobs is None:
token = tokenizer.decode(token_id)
if self.return_tokens_as_token_ids:
token = f"token_id:{token_id}"
logprobs_content.append(
ChatCompletionLogProbsContent(
token=token,
bytes=list(token.encode("utf-8", errors="replace")),
))
else:
step_token = step_top_logprobs[token_id]
step_decoded = step_token.decoded_token
logprobs_content.append(
ChatCompletionLogProbsContent(
token=self._get_decoded_token(
step_token,
token_id,
tokenizer,
self.return_tokens_as_token_ids,
),
logprob=max(step_token.logprob, -9999.0),
bytes=None if step_decoded is None else list(
step_decoded.encode("utf-8", errors="replace")),
top_logprobs=self._get_top_logprobs(
step_top_logprobs,
num_output_top_logprobs,
tokenizer,
),
))
return ChatCompletionLogProbs(content=logprobs_content)
def _should_stream_with_auto_tool_parsing(self,
request: ChatCompletionRequest):
"""
Utility function to check if streamed tokens should go through the tool
call parser that was configured.
We only want to do this IF user-provided tools are set, a tool parser
is configured, "auto" tool choice is enabled, and the request's tool
choice field indicates that "auto" tool choice should be used.
"""
return (request.tools and self.tool_parser and self.enable_auto_tools
and request.tool_choice in ['auto', None])
def _should_check_for_unstreamed_tool_arg_tokens(
self,
delta_message: Optional[DeltaMessage],
output: CompletionOutput,
) -> bool:
"""
Check to see if we should check for unstreamed tool arguments tokens.
This is only applicable when auto tool parsing is enabled, the delta
is a tool call with arguments.
"""
# yapf: disable
return bool(
# if there is a delta message that includes tool calls which
# include a function that has arguments
output.finish_reason is not None
and self.enable_auto_tools and self.tool_parser and delta_message
and delta_message.tool_calls and delta_message.tool_calls[0]
and delta_message.tool_calls[0].function
and delta_message.tool_calls[0].function.arguments is not None
)

View File

@@ -0,0 +1,554 @@
import asyncio
import time
from typing import (AsyncGenerator, AsyncIterator, Callable, Dict, List,
Optional)
from typing import Sequence as GenericSequence
from typing import Tuple, Union, cast
from fastapi import Request
from vllm.config import ModelConfig
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.engine.multiprocessing.client import MQLLMEngineClient
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.logger import RequestLogger
# yapf conflicts with isort for this block
# yapf: disable
from vllm.entrypoints.openai.protocol import (CompletionLogProbs,
CompletionRequest,
CompletionResponse,
CompletionResponseChoice,
CompletionResponseStreamChoice,
CompletionStreamResponse,
ErrorResponse,
RequestResponseMetadata,
UsageInfo)
# yapf: enable
from vllm.entrypoints.openai.serving_engine import (BaseModelPath,
LoRAModulePath,
OpenAIServing,
PromptAdapterPath)
from vllm.logger import init_logger
from vllm.outputs import RequestOutput
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.sequence import Logprob
from vllm.tracing import (contains_trace_headers, extract_trace_headers,
log_tracing_disabled_warning)
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import merge_async_iterators, random_uuid
logger = init_logger(__name__)
TypeTokenIDs = List[int]
TypeTopLogProbs = List[Optional[Dict[int, float]]]
TypeCreateLogProbsFn = Callable[
[TypeTokenIDs, TypeTopLogProbs, Optional[int], int], CompletionLogProbs]
class OpenAIServingCompletion(OpenAIServing):
def __init__(
self,
engine_client: EngineClient,
model_config: ModelConfig,
base_model_paths: List[BaseModelPath],
*,
lora_modules: Optional[List[LoRAModulePath]],
prompt_adapters: Optional[List[PromptAdapterPath]],
request_logger: Optional[RequestLogger],
return_tokens_as_token_ids: bool = False,
):
super().__init__(engine_client=engine_client,
model_config=model_config,
base_model_paths=base_model_paths,
lora_modules=lora_modules,
prompt_adapters=prompt_adapters,
request_logger=request_logger,
return_tokens_as_token_ids=return_tokens_as_token_ids)
async def create_completion(
self,
request: CompletionRequest,
raw_request: Request,
) -> Union[AsyncGenerator[str, None], CompletionResponse, ErrorResponse]:
"""Completion API similar to OpenAI's API.
See https://platform.openai.com/docs/api-reference/completions/create
for the API specification. This API mimics the OpenAI Completion API.
NOTE: Currently we do not support the following feature:
- suffix (the language models we currently support do not support
suffix)
"""
error_check_ret = await self._check_model(request)
if error_check_ret is not None:
return error_check_ret
# If the engine is dead, raise the engine's DEAD_ERROR.
# This is required for the streaming case, where we return a
# success status before we actually start generating text :).
if self.engine_client.errored:
raise self.engine_client.dead_error
# Return error for unsupported features.
if request.suffix is not None:
return self.create_error_response(
"suffix is not currently supported")
model_name = self.base_model_paths[0].name
request_id = f"cmpl-{random_uuid()}"
created_time = int(time.time())
request_metadata = RequestResponseMetadata(request_id=request_id)
if raw_request:
raw_request.state.request_metadata = request_metadata
# Schedule the request and get the result generator.
generators: List[AsyncGenerator[RequestOutput, None]] = []
try:
(
lora_request,
prompt_adapter_request,
) = self._maybe_get_adapters(request)
tokenizer = await self.engine_client.get_tokenizer(lora_request)
prompts = list(
self._tokenize_prompt_input_or_inputs(
request,
tokenizer,
request.prompt,
truncate_prompt_tokens=request.truncate_prompt_tokens,
add_special_tokens=request.add_special_tokens,
))
for i, prompt_inputs in enumerate(prompts):
sampling_params: Union[SamplingParams, BeamSearchParams]
default_max_tokens = self.max_model_len - len(
prompt_inputs["prompt_token_ids"])
if request.use_beam_search:
sampling_params = request.to_beam_search_params(
default_max_tokens)
else:
sampling_params = request.to_sampling_params(
default_max_tokens)
request_id_item = f"{request_id}-{i}"
self._log_inputs(request_id_item,
prompt_inputs,
params=sampling_params,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request)
is_tracing_enabled = (await
self.engine_client.is_tracing_enabled())
trace_headers = None
if is_tracing_enabled:
trace_headers = extract_trace_headers(raw_request.headers)
if not is_tracing_enabled and contains_trace_headers(
raw_request.headers):
log_tracing_disabled_warning()
if isinstance(sampling_params, BeamSearchParams):
assert isinstance(self.engine_client,
(AsyncLLMEngine,
MQLLMEngineClient)), \
"Beam search is only supported with" \
"AsyncLLMEngine and MQLLMEngineClient."
generator = self.engine_client.beam_search(
prompt_inputs["prompt_token_ids"],
request_id_item,
sampling_params,
)
else:
generator = self.engine_client.generate(
{
"prompt_token_ids":
prompt_inputs["prompt_token_ids"]
},
sampling_params,
request_id_item,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
trace_headers=trace_headers,
priority=request.priority,
)
generators.append(generator)
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
result_generator = merge_async_iterators(
*generators, is_cancelled=raw_request.is_disconnected)
# Similar to the OpenAI API, when n != best_of, we do not stream the
# results. In addition, we do not stream the results when use
# beam search.
stream = (request.stream
and (request.best_of is None or request.n == request.best_of)
and not request.use_beam_search)
# Streaming response
if stream:
return self.completion_stream_generator(
request,
result_generator,
request_id,
created_time,
model_name,
num_prompts=len(prompts),
tokenizer=tokenizer,
request_metadata=request_metadata)
# Non-streaming response
final_res_batch: List[Optional[RequestOutput]] = [None] * len(prompts)
try:
async for i, res in result_generator:
final_res_batch[i] = res
for i, final_res in enumerate(final_res_batch):
assert final_res is not None
# The output should contain the input text
# We did not pass it into vLLM engine to avoid being redundant
# with the inputs token IDs
if final_res.prompt is None:
final_res.prompt = prompts[i]["prompt"]
final_res_batch_checked = cast(List[RequestOutput],
final_res_batch)
response = self.request_output_to_completion_response(
final_res_batch_checked,
request,
request_id,
created_time,
model_name,
tokenizer,
request_metadata,
)
except asyncio.CancelledError:
return self.create_error_response("Client disconnected")
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
# When user requests streaming but we don't stream, we still need to
# return a streaming response with a single event.
if request.stream:
response_json = response.model_dump_json()
async def fake_stream_generator() -> AsyncGenerator[str, None]:
yield f"data: {response_json}\n\n"
yield "data: [DONE]\n\n"
return fake_stream_generator()
return response
async def completion_stream_generator(
self,
request: CompletionRequest,
result_generator: AsyncIterator[Tuple[int, RequestOutput]],
request_id: str,
created_time: int,
model_name: str,
num_prompts: int,
tokenizer: AnyTokenizer,
request_metadata: RequestResponseMetadata,
) -> AsyncGenerator[str, None]:
num_choices = 1 if request.n is None else request.n
previous_text_lens = [0] * num_choices * num_prompts
previous_num_tokens = [0] * num_choices * num_prompts
has_echoed = [False] * num_choices * num_prompts
num_prompt_tokens = [0] * num_prompts
try:
async for prompt_idx, res in result_generator:
prompt_token_ids = res.prompt_token_ids
prompt_logprobs = res.prompt_logprobs
prompt_text = res.prompt
# Prompt details are excluded from later streamed outputs
if res.prompt_token_ids is not None:
num_prompt_tokens[prompt_idx] = len(res.prompt_token_ids)
delta_token_ids: GenericSequence[int]
out_logprobs: Optional[GenericSequence[Optional[Dict[
int, Logprob]]]]
for output in res.outputs:
i = output.index + prompt_idx * num_choices
# TODO(simon): optimize the performance by avoiding full
# text O(n^2) sending.
assert request.max_tokens is not None
if request.echo and request.max_tokens == 0:
assert prompt_token_ids is not None
assert prompt_text is not None
# only return the prompt
delta_text = prompt_text
delta_token_ids = prompt_token_ids
out_logprobs = prompt_logprobs
has_echoed[i] = True
elif (request.echo and request.max_tokens > 0
and not has_echoed[i]):
assert prompt_token_ids is not None
assert prompt_text is not None
assert prompt_logprobs is not None
# echo the prompt and first token
delta_text = prompt_text + output.text
delta_token_ids = [
*prompt_token_ids, *output.token_ids
]
out_logprobs = [
*prompt_logprobs,
*(output.logprobs or []),
]
has_echoed[i] = True
else:
# return just the delta
delta_text = output.text
delta_token_ids = output.token_ids
out_logprobs = output.logprobs
if request.logprobs is not None:
assert out_logprobs is not None, (
"Did not output logprobs")
logprobs = self._create_completion_logprobs(
token_ids=delta_token_ids,
top_logprobs=out_logprobs,
num_output_top_logprobs=request.logprobs,
tokenizer=tokenizer,
initial_text_offset=previous_text_lens[i],
)
else:
logprobs = None
previous_text_lens[i] += len(output.text)
previous_num_tokens[i] += len(output.token_ids)
finish_reason = output.finish_reason
stop_reason = output.stop_reason
chunk = CompletionStreamResponse(
id=request_id,
created=created_time,
model=model_name,
choices=[
CompletionResponseStreamChoice(
index=i,
text=delta_text,
logprobs=logprobs,
finish_reason=finish_reason,
stop_reason=stop_reason,
)
])
if (request.stream_options
and request.stream_options.include_usage):
if (request.stream_options.continuous_usage_stats
or output.finish_reason is not None):
prompt_tokens = num_prompt_tokens[prompt_idx]
completion_tokens = previous_num_tokens[i]
usage = UsageInfo(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
if request.stream_options.continuous_usage_stats:
chunk.usage = usage
else:
chunk.usage = None
response_json = chunk.model_dump_json(exclude_unset=False)
yield f"data: {response_json}\n\n"
if (request.stream_options
and request.stream_options.include_usage):
final_usage_chunk = CompletionStreamResponse(
id=request_id,
created=created_time,
model=model_name,
choices=[],
usage=usage,
)
final_usage_data = (final_usage_chunk.model_dump_json(
exclude_unset=False, exclude_none=True))
yield f"data: {final_usage_data}\n\n"
# report to FastAPI middleware aggregate usage across all choices
total_prompt_tokens = sum(num_prompt_tokens)
total_completion_tokens = sum(previous_num_tokens)
request_metadata.final_usage_info = UsageInfo(
prompt_tokens=total_prompt_tokens,
completion_tokens=total_completion_tokens,
total_tokens=total_prompt_tokens + total_completion_tokens)
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
data = self.create_streaming_error_response(str(e))
yield f"data: {data}\n\n"
yield "data: [DONE]\n\n"
def request_output_to_completion_response(
self,
final_res_batch: List[RequestOutput],
request: CompletionRequest,
request_id: str,
created_time: int,
model_name: str,
tokenizer: AnyTokenizer,
request_metadata: RequestResponseMetadata,
) -> CompletionResponse:
choices: List[CompletionResponseChoice] = []
num_prompt_tokens = 0
num_generated_tokens = 0
for final_res in final_res_batch:
prompt_token_ids = final_res.prompt_token_ids
assert prompt_token_ids is not None
prompt_logprobs = final_res.prompt_logprobs
prompt_text = final_res.prompt
token_ids: GenericSequence[int]
out_logprobs: Optional[GenericSequence[Optional[Dict[int,
Logprob]]]]
for output in final_res.outputs:
assert request.max_tokens is not None
if request.echo and request.max_tokens == 0:
assert prompt_text is not None
token_ids = prompt_token_ids
out_logprobs = prompt_logprobs
output_text = prompt_text
elif request.echo and request.max_tokens > 0:
assert prompt_text is not None
token_ids = [*prompt_token_ids, *output.token_ids]
if request.logprobs is None:
out_logprobs = None
else:
assert prompt_logprobs is not None
assert output.logprobs is not None
out_logprobs = [
*prompt_logprobs,
*output.logprobs,
]
output_text = prompt_text + output.text
else:
token_ids = output.token_ids
out_logprobs = output.logprobs
output_text = output.text
if request.logprobs is not None:
assert out_logprobs is not None, "Did not output logprobs"
logprobs = self._create_completion_logprobs(
token_ids=token_ids,
top_logprobs=out_logprobs,
tokenizer=tokenizer,
num_output_top_logprobs=request.logprobs,
)
else:
logprobs = None
choice_data = CompletionResponseChoice(
index=len(choices),
text=output_text,
logprobs=logprobs,
finish_reason=output.finish_reason,
stop_reason=output.stop_reason,
prompt_logprobs=final_res.prompt_logprobs,
)
choices.append(choice_data)
num_generated_tokens += len(output.token_ids)
num_prompt_tokens += len(prompt_token_ids)
usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=num_generated_tokens,
total_tokens=num_prompt_tokens + num_generated_tokens,
)
request_metadata.final_usage_info = usage
return CompletionResponse(
id=request_id,
created=created_time,
model=model_name,
choices=choices,
usage=usage,
)
def _create_completion_logprobs(
self,
token_ids: GenericSequence[int],
top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]],
num_output_top_logprobs: int,
tokenizer: AnyTokenizer,
initial_text_offset: int = 0,
) -> CompletionLogProbs:
"""Create logprobs for OpenAI Completion API."""
out_text_offset: List[int] = []
out_token_logprobs: List[Optional[float]] = []
out_tokens: List[str] = []
out_top_logprobs: List[Optional[Dict[str, float]]] = []
last_token_len = 0
for i, token_id in enumerate(token_ids):
step_top_logprobs = top_logprobs[i]
if step_top_logprobs is None:
token = tokenizer.decode(token_id)
if self.return_tokens_as_token_ids:
token = f"token_id:{token_id}"
out_tokens.append(token)
out_token_logprobs.append(None)
out_top_logprobs.append(None)
else:
step_token = step_top_logprobs[token_id]
token = self._get_decoded_token(
step_token,
token_id,
tokenizer,
return_as_token_id=self.return_tokens_as_token_ids,
)
token_logprob = max(step_token.logprob, -9999.0)
out_tokens.append(token)
out_token_logprobs.append(token_logprob)
# makes sure to add the top num_output_top_logprobs + 1
# logprobs, as defined in the openai API
# (cf. https://github.com/openai/openai-openapi/blob/
# 893ba52242dbd5387a97b96444ee1c742cfce9bd/openapi.yaml#L7153)
out_top_logprobs.append({
# Convert float("-inf") to the
# JSON-serializable float that OpenAI uses
self._get_decoded_token(
top_lp[1],
top_lp[0],
tokenizer,
return_as_token_id=self.return_tokens_as_token_ids):
max(top_lp[1].logprob, -9999.0)
for i, top_lp in enumerate(step_top_logprobs.items())
if num_output_top_logprobs >= i
})
if len(out_text_offset) == 0:
out_text_offset.append(initial_text_offset)
else:
out_text_offset.append(out_text_offset[-1] + last_token_len)
last_token_len = len(token)
return CompletionLogProbs(
text_offset=out_text_offset,
token_logprobs=out_token_logprobs,
tokens=out_tokens,
top_logprobs=out_top_logprobs,
)

View File

@@ -0,0 +1,203 @@
import asyncio
import base64
import time
from typing import AsyncGenerator, List, Literal, Optional, Union, cast
import numpy as np
from fastapi import Request
from typing_extensions import assert_never
from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import (EmbeddingRequest,
EmbeddingResponse,
EmbeddingResponseData,
ErrorResponse, UsageInfo)
from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing
from vllm.logger import init_logger
from vllm.outputs import EmbeddingOutput, EmbeddingRequestOutput
from vllm.utils import merge_async_iterators, random_uuid
logger = init_logger(__name__)
TypeTokenIDs = List[int]
def _get_embedding(
output: EmbeddingOutput,
encoding_format: Literal["float", "base64"],
) -> Union[List[float], str]:
if encoding_format == "float":
return output.embedding
elif encoding_format == "base64":
# Force to use float32 for base64 encoding
# to match the OpenAI python client behavior
embedding_bytes = np.array(output.embedding, dtype="float32").tobytes()
return base64.b64encode(embedding_bytes).decode("utf-8")
assert_never(encoding_format)
def request_output_to_embedding_response(
final_res_batch: List[EmbeddingRequestOutput], request_id: str,
created_time: int, model_name: str,
encoding_format: Literal["float", "base64"]) -> EmbeddingResponse:
data: List[EmbeddingResponseData] = []
num_prompt_tokens = 0
for idx, final_res in enumerate(final_res_batch):
prompt_token_ids = final_res.prompt_token_ids
embedding = _get_embedding(final_res.outputs, encoding_format)
embedding_data = EmbeddingResponseData(index=idx, embedding=embedding)
data.append(embedding_data)
num_prompt_tokens += len(prompt_token_ids)
usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
total_tokens=num_prompt_tokens,
)
return EmbeddingResponse(
id=request_id,
created=created_time,
model=model_name,
data=data,
usage=usage,
)
class OpenAIServingEmbedding(OpenAIServing):
def __init__(
self,
engine_client: EngineClient,
model_config: ModelConfig,
base_model_paths: List[BaseModelPath],
*,
request_logger: Optional[RequestLogger],
):
super().__init__(engine_client=engine_client,
model_config=model_config,
base_model_paths=base_model_paths,
lora_modules=None,
prompt_adapters=None,
request_logger=request_logger)
self._enabled = self._check_embedding_mode(model_config.embedding_mode)
async def create_embedding(
self,
request: EmbeddingRequest,
raw_request: Optional[Request] = None,
) -> Union[EmbeddingResponse, ErrorResponse]:
"""Completion API similar to OpenAI's API.
See https://platform.openai.com/docs/api-reference/embeddings/create
for the API specification. This API mimics the OpenAI Embedding API.
"""
if not self._enabled:
return self.create_error_response("Embedding API disabled")
error_check_ret = await self._check_model(request)
if error_check_ret is not None:
return error_check_ret
encoding_format = request.encoding_format
if request.dimensions is not None:
return self.create_error_response(
"dimensions is currently not supported")
model_name = request.model
request_id = f"embd-{random_uuid()}"
created_time = int(time.monotonic())
truncate_prompt_tokens = None
if request.truncate_prompt_tokens is not None:
if request.truncate_prompt_tokens <= self.max_model_len:
truncate_prompt_tokens = request.truncate_prompt_tokens
else:
return self.create_error_response(
"truncate_prompt_tokens value is "
"greater than max_model_len."
" Please, select a smaller truncation size.")
# Schedule the request and get the result generator.
generators: List[AsyncGenerator[EmbeddingRequestOutput, None]] = []
try:
(
lora_request,
prompt_adapter_request,
) = self._maybe_get_adapters(request)
tokenizer = await self.engine_client.get_tokenizer(lora_request)
pooling_params = request.to_pooling_params()
prompts = list(
self._tokenize_prompt_input_or_inputs(request, tokenizer,
request.input,
truncate_prompt_tokens))
for i, prompt_inputs in enumerate(prompts):
request_id_item = f"{request_id}-{i}"
self._log_inputs(request_id_item,
prompt_inputs,
params=pooling_params,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request)
if prompt_adapter_request is not None:
raise NotImplementedError(
"Prompt adapter is not supported "
"for embedding models")
generator = self.engine_client.encode(
{"prompt_token_ids": prompt_inputs["prompt_token_ids"]},
pooling_params,
request_id_item,
lora_request=lora_request,
priority=request.priority,
)
generators.append(generator)
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
result_generator = merge_async_iterators(
*generators,
is_cancelled=raw_request.is_disconnected if raw_request else None,
)
# Non-streaming response
final_res_batch: List[Optional[EmbeddingRequestOutput]]
final_res_batch = [None] * len(prompts)
try:
async for i, res in result_generator:
final_res_batch[i] = res
for final_res in final_res_batch:
assert final_res is not None
final_res_batch_checked = cast(List[EmbeddingRequestOutput],
final_res_batch)
response = request_output_to_embedding_response(
final_res_batch_checked, request_id, created_time, model_name,
encoding_format)
except asyncio.CancelledError:
return self.create_error_response("Client disconnected")
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
return response
def _check_embedding_mode(self, embedding_mode: bool) -> bool:
if not embedding_mode:
logger.warning(
"embedding_mode is False. Embedding API will not work.")
else:
logger.info("Activating the server engine with embedding enabled.")
return embedding_mode

View File

@@ -0,0 +1,487 @@
import json
import pathlib
from dataclasses import dataclass
from http import HTTPStatus
from typing import Iterable, Iterator, List, Optional, Tuple, TypedDict, Union
from pydantic import Field
from typing_extensions import Annotated
from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.logger import RequestLogger
# yapf conflicts with isort for this block
# yapf: disable
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
CompletionRequest,
DetokenizeRequest,
EmbeddingRequest, ErrorResponse,
LoadLoraAdapterRequest,
ModelCard, ModelList,
ModelPermission,
TokenizeChatRequest,
TokenizeCompletionRequest,
TokenizeRequest,
UnloadLoraAdapterRequest)
# yapf: enable
from vllm.inputs.parse import parse_and_batch_prompt
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.sequence import Logprob
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import AtomicCounter
logger = init_logger(__name__)
@dataclass
class BaseModelPath:
name: str
model_path: str
@dataclass
class PromptAdapterPath:
name: str
local_path: str
@dataclass
class LoRAModulePath:
name: str
path: str
base_model_name: Optional[str] = None
AnyRequest = Union[ChatCompletionRequest, CompletionRequest, DetokenizeRequest,
EmbeddingRequest, TokenizeRequest]
class TextTokensPrompt(TypedDict):
prompt: str
prompt_token_ids: List[int]
class OpenAIServing:
def __init__(
self,
engine_client: EngineClient,
model_config: ModelConfig,
base_model_paths: List[BaseModelPath],
*,
lora_modules: Optional[List[LoRAModulePath]],
prompt_adapters: Optional[List[PromptAdapterPath]],
request_logger: Optional[RequestLogger],
return_tokens_as_token_ids: bool = False,
):
super().__init__()
self.engine_client = engine_client
self.model_config = model_config
self.max_model_len = model_config.max_model_len
self.base_model_paths = base_model_paths
self.lora_id_counter = AtomicCounter(0)
self.lora_requests = []
if lora_modules is not None:
self.lora_requests = [
LoRARequest(lora_name=lora.name,
lora_int_id=i,
lora_path=lora.path,
base_model_name=lora.base_model_name
if lora.base_model_name
and self._is_model_supported(lora.base_model_name)
else self.base_model_paths[0].name)
for i, lora in enumerate(lora_modules, start=1)
]
self.prompt_adapter_requests = []
if prompt_adapters is not None:
for i, prompt_adapter in enumerate(prompt_adapters, start=1):
with pathlib.Path(prompt_adapter.local_path,
"adapter_config.json").open() as f:
adapter_config = json.load(f)
num_virtual_tokens = adapter_config["num_virtual_tokens"]
self.prompt_adapter_requests.append(
PromptAdapterRequest(
prompt_adapter_name=prompt_adapter.name,
prompt_adapter_id=i,
prompt_adapter_local_path=prompt_adapter.local_path,
prompt_adapter_num_virtual_tokens=num_virtual_tokens))
self.request_logger = request_logger
self.return_tokens_as_token_ids = return_tokens_as_token_ids
async def show_available_models(self) -> ModelList:
"""Show available models. Right now we only have one model."""
model_cards = [
ModelCard(id=base_model.name,
max_model_len=self.max_model_len,
root=base_model.model_path,
permission=[ModelPermission()])
for base_model in self.base_model_paths
]
lora_cards = [
ModelCard(id=lora.lora_name,
root=lora.local_path,
parent=lora.base_model_name if lora.base_model_name else
self.base_model_paths[0].name,
permission=[ModelPermission()])
for lora in self.lora_requests
]
prompt_adapter_cards = [
ModelCard(id=prompt_adapter.prompt_adapter_name,
root=self.base_model_paths[0].name,
permission=[ModelPermission()])
for prompt_adapter in self.prompt_adapter_requests
]
model_cards.extend(lora_cards)
model_cards.extend(prompt_adapter_cards)
return ModelList(data=model_cards)
def create_error_response(
self,
message: str,
err_type: str = "BadRequestError",
status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> ErrorResponse:
return ErrorResponse(message=message,
type=err_type,
code=status_code.value)
def create_streaming_error_response(
self,
message: str,
err_type: str = "BadRequestError",
status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> str:
json_str = json.dumps({
"error":
self.create_error_response(message=message,
err_type=err_type,
status_code=status_code).model_dump()
})
return json_str
async def _check_model(
self,
request: AnyRequest,
) -> Optional[ErrorResponse]:
if self._is_model_supported(request.model):
return None
if request.model in [lora.lora_name for lora in self.lora_requests]:
return None
if request.model in [
prompt_adapter.prompt_adapter_name
for prompt_adapter in self.prompt_adapter_requests
]:
return None
return self.create_error_response(
message=f"The model `{request.model}` does not exist.",
err_type="NotFoundError",
status_code=HTTPStatus.NOT_FOUND)
def _maybe_get_adapters(
self, request: AnyRequest
) -> Union[Tuple[None, None], Tuple[LoRARequest, None], Tuple[
None, PromptAdapterRequest]]:
if self._is_model_supported(request.model):
return None, None
for lora in self.lora_requests:
if request.model == lora.lora_name:
return lora, None
for prompt_adapter in self.prompt_adapter_requests:
if request.model == prompt_adapter.prompt_adapter_name:
return None, prompt_adapter
# if _check_model has been called earlier, this will be unreachable
raise ValueError(f"The model `{request.model}` does not exist.")
def _normalize_prompt_text_to_input(
self,
request: AnyRequest,
tokenizer: AnyTokenizer,
prompt: str,
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]],
add_special_tokens: bool,
) -> TextTokensPrompt:
if truncate_prompt_tokens is None:
encoded = tokenizer(prompt, add_special_tokens=add_special_tokens)
else:
encoded = tokenizer(prompt,
add_special_tokens=add_special_tokens,
truncation=True,
max_length=truncate_prompt_tokens)
input_ids = encoded.input_ids
input_text = prompt
return self._validate_input(request, input_ids, input_text)
def _normalize_prompt_tokens_to_input(
self,
request: AnyRequest,
tokenizer: AnyTokenizer,
prompt_ids: List[int],
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]],
) -> TextTokensPrompt:
if truncate_prompt_tokens is None:
input_ids = prompt_ids
else:
input_ids = prompt_ids[-truncate_prompt_tokens:]
input_text = tokenizer.decode(input_ids)
return self._validate_input(request, input_ids, input_text)
def _validate_input(
self,
request: AnyRequest,
input_ids: List[int],
input_text: str,
) -> TextTokensPrompt:
token_num = len(input_ids)
# Note: EmbeddingRequest doesn't have max_tokens
if isinstance(request, EmbeddingRequest):
if token_num > self.max_model_len:
raise ValueError(
f"This model's maximum context length is "
f"{self.max_model_len} tokens. However, you requested "
f"{token_num} tokens in the input for embedding "
f"generation. Please reduce the length of the input.")
return TextTokensPrompt(prompt=input_text,
prompt_token_ids=input_ids)
# Note: TokenizeRequest and DetokenizeRequest doesn't have max_tokens
# and does not require model context length validation
if isinstance(request, (TokenizeCompletionRequest, TokenizeChatRequest,
DetokenizeRequest)):
return TextTokensPrompt(prompt=input_text,
prompt_token_ids=input_ids)
if request.max_tokens is None:
if token_num >= self.max_model_len:
raise ValueError(
f"This model's maximum context length is "
f"{self.max_model_len} tokens. However, you requested "
f"{token_num} tokens in the messages, "
f"Please reduce the length of the messages.")
elif token_num + request.max_tokens > self.max_model_len:
raise ValueError(
f"This model's maximum context length is "
f"{self.max_model_len} tokens. However, you requested "
f"{request.max_tokens + token_num} tokens "
f"({token_num} in the messages, "
f"{request.max_tokens} in the completion). "
f"Please reduce the length of the messages or completion.")
return TextTokensPrompt(prompt=input_text, prompt_token_ids=input_ids)
def _tokenize_prompt_input(
self,
request: AnyRequest,
tokenizer: AnyTokenizer,
prompt_input: Union[str, List[int]],
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
add_special_tokens: bool = True,
) -> TextTokensPrompt:
"""
A simpler implementation of :meth:`_tokenize_prompt_input_or_inputs`
that assumes single input.
"""
return next(
self._tokenize_prompt_inputs(
request,
tokenizer,
[prompt_input],
truncate_prompt_tokens=truncate_prompt_tokens,
add_special_tokens=add_special_tokens,
))
def _tokenize_prompt_inputs(
self,
request: AnyRequest,
tokenizer: AnyTokenizer,
prompt_inputs: Iterable[Union[str, List[int]]],
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
add_special_tokens: bool = True,
) -> Iterator[TextTokensPrompt]:
"""
A simpler implementation of :meth:`_tokenize_prompt_input_or_inputs`
that assumes multiple inputs.
"""
for text in prompt_inputs:
if isinstance(text, str):
yield self._normalize_prompt_text_to_input(
request,
tokenizer,
prompt=text,
truncate_prompt_tokens=truncate_prompt_tokens,
add_special_tokens=add_special_tokens,
)
else:
yield self._normalize_prompt_tokens_to_input(
request,
tokenizer,
prompt_ids=text,
truncate_prompt_tokens=truncate_prompt_tokens,
)
def _tokenize_prompt_input_or_inputs(
self,
request: AnyRequest,
tokenizer: AnyTokenizer,
input_or_inputs: Union[str, List[str], List[int], List[List[int]]],
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
add_special_tokens: bool = True,
) -> Iterator[TextTokensPrompt]:
"""
Tokenize/detokenize depending on the input format.
According to `OpenAI API <https://platform.openai.com/docs/api-reference/embeddings/create>`_
, each input can be a string or array of tokens. Note that each request
can pass one or more inputs.
"""
for prompt_input in parse_and_batch_prompt(input_or_inputs):
# Although our type checking is based on mypy,
# VSCode Pyright extension should still work properly
# "is True" is required for Pyright to perform type narrowing
# See: https://github.com/microsoft/pyright/issues/7672
if prompt_input["is_tokens"] is False:
yield self._normalize_prompt_text_to_input(
request,
tokenizer,
prompt=prompt_input["content"],
truncate_prompt_tokens=truncate_prompt_tokens,
add_special_tokens=add_special_tokens,
)
else:
yield self._normalize_prompt_tokens_to_input(
request,
tokenizer,
prompt_ids=prompt_input["content"],
truncate_prompt_tokens=truncate_prompt_tokens,
)
def _log_inputs(
self,
request_id: str,
inputs: Union[str, List[int], TextTokensPrompt],
params: Optional[Union[SamplingParams, PoolingParams,
BeamSearchParams]],
lora_request: Optional[LoRARequest],
prompt_adapter_request: Optional[PromptAdapterRequest],
) -> None:
if self.request_logger is None:
return
if isinstance(inputs, str):
prompt = inputs
prompt_token_ids = None
elif isinstance(inputs, list):
prompt = None
prompt_token_ids = inputs
else:
prompt = inputs["prompt"]
prompt_token_ids = inputs["prompt_token_ids"]
self.request_logger.log_inputs(
request_id,
prompt,
prompt_token_ids,
params=params,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
)
@staticmethod
def _get_decoded_token(logprob: Logprob,
token_id: int,
tokenizer: AnyTokenizer,
return_as_token_id: bool = False) -> str:
if return_as_token_id:
return f"token_id:{token_id}"
if logprob.decoded_token is not None:
return logprob.decoded_token
return tokenizer.decode(token_id)
async def _check_load_lora_adapter_request(
self, request: LoadLoraAdapterRequest) -> Optional[ErrorResponse]:
# Check if both 'lora_name' and 'lora_path' are provided
if not request.lora_name or not request.lora_path:
return self.create_error_response(
message="Both 'lora_name' and 'lora_path' must be provided.",
err_type="InvalidUserInput",
status_code=HTTPStatus.BAD_REQUEST)
# Check if the lora adapter with the given name already exists
if any(lora_request.lora_name == request.lora_name
for lora_request in self.lora_requests):
return self.create_error_response(
message=
f"The lora adapter '{request.lora_name}' has already been"
"loaded.",
err_type="InvalidUserInput",
status_code=HTTPStatus.BAD_REQUEST)
return None
async def _check_unload_lora_adapter_request(
self,
request: UnloadLoraAdapterRequest) -> Optional[ErrorResponse]:
# Check if either 'lora_name' or 'lora_int_id' is provided
if not request.lora_name and not request.lora_int_id:
return self.create_error_response(
message=
"either 'lora_name' and 'lora_int_id' needs to be provided.",
err_type="InvalidUserInput",
status_code=HTTPStatus.BAD_REQUEST)
# Check if the lora adapter with the given name exists
if not any(lora_request.lora_name == request.lora_name
for lora_request in self.lora_requests):
return self.create_error_response(
message=
f"The lora adapter '{request.lora_name}' cannot be found.",
err_type="InvalidUserInput",
status_code=HTTPStatus.BAD_REQUEST)
return None
async def load_lora_adapter(
self,
request: LoadLoraAdapterRequest) -> Union[ErrorResponse, str]:
error_check_ret = await self._check_load_lora_adapter_request(request)
if error_check_ret is not None:
return error_check_ret
lora_name, lora_path = request.lora_name, request.lora_path
unique_id = self.lora_id_counter.inc(1)
self.lora_requests.append(
LoRARequest(lora_name=lora_name,
lora_int_id=unique_id,
lora_path=lora_path))
return f"Success: LoRA adapter '{lora_name}' added successfully."
async def unload_lora_adapter(
self,
request: UnloadLoraAdapterRequest) -> Union[ErrorResponse, str]:
error_check_ret = await self._check_unload_lora_adapter_request(request
)
if error_check_ret is not None:
return error_check_ret
lora_name = request.lora_name
self.lora_requests = [
lora_request for lora_request in self.lora_requests
if lora_request.lora_name != lora_name
]
return f"Success: LoRA adapter '{lora_name}' removed successfully."
def _is_model_supported(self, model_name):
return any(model.name == model_name for model in self.base_model_paths)

View File

@@ -0,0 +1,157 @@
from typing import List, Optional, Union
from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import (apply_hf_chat_template,
apply_mistral_chat_template,
load_chat_template,
parse_chat_messages_futures)
from vllm.entrypoints.logger import RequestLogger
# yapf conflicts with isort for this block
# yapf: disable
from vllm.entrypoints.openai.protocol import (DetokenizeRequest,
DetokenizeResponse,
ErrorResponse,
TokenizeChatRequest,
TokenizeRequest,
TokenizeResponse)
# yapf: enable
from vllm.entrypoints.openai.serving_engine import (BaseModelPath,
LoRAModulePath,
OpenAIServing)
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import MistralTokenizer
from vllm.utils import random_uuid
logger = init_logger(__name__)
class OpenAIServingTokenization(OpenAIServing):
def __init__(
self,
engine_client: EngineClient,
model_config: ModelConfig,
base_model_paths: List[BaseModelPath],
*,
lora_modules: Optional[List[LoRAModulePath]],
request_logger: Optional[RequestLogger],
chat_template: Optional[str],
):
super().__init__(engine_client=engine_client,
model_config=model_config,
base_model_paths=base_model_paths,
lora_modules=lora_modules,
prompt_adapters=None,
request_logger=request_logger)
# If this is None we use the tokenizer's default chat template
# the list of commonly-used chat template names for HF named templates
hf_chat_templates: List[str] = ['default', 'tool_use']
self.chat_template = chat_template \
if chat_template in hf_chat_templates \
else load_chat_template(chat_template)
async def create_tokenize(
self,
request: TokenizeRequest,
) -> Union[TokenizeResponse, ErrorResponse]:
error_check_ret = await self._check_model(request)
if error_check_ret is not None:
return error_check_ret
request_id = f"tokn-{random_uuid()}"
(
lora_request,
prompt_adapter_request,
) = self._maybe_get_adapters(request)
tokenizer = await self.engine_client.get_tokenizer(lora_request)
prompt: Union[str, List[int]]
if isinstance(request, TokenizeChatRequest):
model_config = self.model_config
conversation, mm_data_future = parse_chat_messages_futures(
request.messages, model_config, tokenizer)
mm_data = await mm_data_future
if mm_data:
logger.warning(
"Multi-modal inputs are ignored during tokenization")
if isinstance(tokenizer, MistralTokenizer):
prompt = apply_mistral_chat_template(
tokenizer,
messages=request.messages,
chat_template=self.chat_template,
add_generation_prompt=request.add_generation_prompt,
continue_final_message=request.continue_final_message,
)
else:
prompt = apply_hf_chat_template(
tokenizer,
conversation=conversation,
chat_template=self.chat_template,
add_generation_prompt=request.add_generation_prompt,
continue_final_message=request.continue_final_message,
)
else:
prompt = request.prompt
self._log_inputs(request_id,
prompt,
params=None,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request)
# Silently ignore prompt adapter since it does not affect tokenization
prompt_input = self._tokenize_prompt_input(
request,
tokenizer,
prompt,
add_special_tokens=request.add_special_tokens,
)
input_ids = prompt_input["prompt_token_ids"]
return TokenizeResponse(tokens=input_ids,
count=len(input_ids),
max_model_len=self.max_model_len)
async def create_detokenize(
self,
request: DetokenizeRequest,
) -> Union[DetokenizeResponse, ErrorResponse]:
error_check_ret = await self._check_model(request)
if error_check_ret is not None:
return error_check_ret
request_id = f"tokn-{random_uuid()}"
(
lora_request,
prompt_adapter_request,
) = self._maybe_get_adapters(request)
tokenizer = await self.engine_client.get_tokenizer(lora_request)
self._log_inputs(request_id,
request.tokens,
params=None,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request)
if prompt_adapter_request is not None:
raise NotImplementedError("Prompt adapter is not supported "
"for tokenization")
prompt_input = self._tokenize_prompt_input(
request,
tokenizer,
request.tokens,
)
input_text = prompt_input["prompt"]
return DetokenizeResponse(prompt=input_text)

View File

@@ -0,0 +1,10 @@
from .abstract_tool_parser import ToolParser, ToolParserManager
from .hermes_tool_parser import Hermes2ProToolParser
from .internlm2_tool_parser import Internlm2ToolParser
from .llama_tool_parser import Llama3JsonToolParser
from .mistral_tool_parser import MistralToolParser
__all__ = [
"ToolParser", "ToolParserManager", "Hermes2ProToolParser",
"MistralToolParser", "Internlm2ToolParser", "Llama3JsonToolParser"
]

View File

@@ -0,0 +1,161 @@
import importlib
import importlib.util
import os
from functools import cached_property
from typing import Callable, Dict, List, Optional, Sequence, Type, Union
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaMessage,
ExtractedToolCallInformation)
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import is_list_of
logger = init_logger(__name__)
class ToolParser:
"""
Abstract ToolParser class that should not be used directly. Provided
properties and methods should be used in
derived classes.
"""
def __init__(self, tokenizer: AnyTokenizer):
self.prev_tool_call_arr: List[Dict] = []
# the index of the tool call that is currently being parsed
self.current_tool_id: int = -1
self.current_tool_name_sent: bool = False
self.streamed_args_for_tool: List[str] = []
self.model_tokenizer = tokenizer
@cached_property
def vocab(self) -> Dict[str, int]:
# NOTE: Only PreTrainedTokenizerFast is guaranteed to have .vocab
# whereas all tokenizers have .get_vocab()
return self.model_tokenizer.get_vocab()
def adjust_request(
self, request: ChatCompletionRequest) -> ChatCompletionRequest:
"""
Static method that used to adjust the request parameters.
"""
return request
def extract_tool_calls(
self, model_output: str,
request: ChatCompletionRequest) -> ExtractedToolCallInformation:
"""
Static method that should be implemented for extracting tool calls from
a complete model-generated string.
Used for non-streaming responses where we have the entire model response
available before sending to the client.
Static because it's stateless.
"""
raise NotImplementedError(
"AbstractToolParser.extract_tool_calls has not been implemented!")
def extract_tool_calls_streaming(
self,
previous_text: str,
current_text: str,
delta_text: str,
previous_token_ids: Sequence[int],
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
request: ChatCompletionRequest,
) -> Union[DeltaMessage, None]:
"""
Instance method that should be implemented for extracting tool calls
from an incomplete response; for use when handling tool calls and
streaming. Has to be an instance method because it requires state -
the current tokens/diffs, but also the information about what has
previously been parsed and extracted (see constructor)
"""
raise NotImplementedError(
"AbstractToolParser.extract_tool_calls_streaming has not been "
"implemented!")
class ToolParserManager:
tool_parsers: Dict[str, Type] = {}
@classmethod
def get_tool_parser(cls, name) -> Type:
"""
Get tool parser by name which is registered by `register_module`.
Raise a KeyError exception if the name is not registered.
"""
if name in cls.tool_parsers:
return cls.tool_parsers[name]
raise KeyError(f"tool helper: '{name}' not found in tool_parsers")
@classmethod
def _register_module(cls,
module: Type,
module_name: Optional[Union[str, List[str]]] = None,
force: bool = True) -> None:
if not issubclass(module, ToolParser):
raise TypeError(
f'module must be subclass of ToolParser, but got {type(module)}'
)
if module_name is None:
module_name = module.__name__
if isinstance(module_name, str):
module_name = [module_name]
for name in module_name:
if not force and name in cls.tool_parsers:
existed_module = cls.tool_parsers[name]
raise KeyError(f'{name} is already registered '
f'at {existed_module.__module__}')
cls.tool_parsers[name] = module
@classmethod
def register_module(
cls,
name: Optional[Union[str, List[str]]] = None,
force: bool = True,
module: Union[Type, None] = None) -> Union[type, Callable]:
"""
Register module with the given name or name list. it can be used as a
decoder(with module as None) or normal function(with module as not
None).
"""
if not isinstance(force, bool):
raise TypeError(f'force must be a boolean, but got {type(force)}')
# raise the error ahead of time
if not (name is None or isinstance(name, str)
or is_list_of(name, str)):
raise TypeError(
'name must be None, an instance of str, or a sequence of str, '
f'but got {type(name)}')
# use it as a normal method: x.register_module(module=SomeClass)
if module is not None:
cls._register_module(module=module, module_name=name, force=force)
return module
# use it as a decorator: @x.register_module()
def _register(module):
cls._register_module(module=module, module_name=name, force=force)
return module
return _register
@classmethod
def import_tool_parser(cls, plugin_path: str) -> None:
"""
Import a user defined tool parser by the path of the tool parser define
file.
"""
module_name = os.path.splitext(os.path.basename(plugin_path))[0]
spec = importlib.util.spec_from_file_location(module_name, plugin_path)
if spec is None or spec.loader is None:
logger.error("load %s from %s failed.", module_name, plugin_path)
return
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)

View File

@@ -0,0 +1,338 @@
import json
import re
from typing import Dict, List, Sequence, Union
import partial_json_parser
from partial_json_parser.core.options import Allow
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaFunctionCall, DeltaMessage,
DeltaToolCall,
ExtractedToolCallInformation,
FunctionCall, ToolCall)
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser, ToolParserManager)
from vllm.entrypoints.openai.tool_parsers.utils import (
extract_intermediate_diff)
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.utils import random_uuid
logger = init_logger(__name__)
@ToolParserManager.register_module("hermes")
class Hermes2ProToolParser(ToolParser):
def __init__(self, tokenizer: AnyTokenizer):
super().__init__(tokenizer)
if isinstance(self.model_tokenizer, MistralTokenizer):
logger.error(
"Detected Mistral tokenizer when using a Hermes model")
self.model_tokenizer = self.model_tokenizer.tokenizer
self.current_tool_name_sent: bool = False
self.prev_tool_call_arr: List[Dict] = []
self.current_tool_id: int = -1
self.streamed_args_for_tool: List[str] = [
] # map what has been streamed for each tool so far to a list
self.tool_call_start_token: str = "<tool_call>"
self.tool_call_end_token: str = "</tool_call>"
self.tool_call_regex = re.compile(
r"<tool_call>(.*?)</tool_call>|<tool_call>(.*)", re.DOTALL)
self.scratch_pad_regex = re.compile(
r"<scratch_pad>(.*?)</scratch_pad>", re.DOTALL)
if not self.model_tokenizer:
raise ValueError(
"The model tokenizer must be passed to the ToolParser "
"constructor during construction.")
self.tool_call_start_token_id = self.vocab.get(
self.tool_call_start_token)
self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)
if not self.tool_call_start_token_id or not self.tool_call_end_token_id:
raise RuntimeError(
"Hermes 2 Pro Tool parser could not locate tool call start/end "
"tokens in the tokenizer!")
def extract_tool_calls(
self,
model_output: str,
request: ChatCompletionRequest,
) -> ExtractedToolCallInformation:
# sanity check; avoid unnecessary processing
if self.tool_call_start_token not in model_output:
return ExtractedToolCallInformation(tools_called=False,
tool_calls=[],
content=model_output)
else:
try:
# there are two possible captures - between tags, or between a
# tag and end-of-string so the result of
# findall is an array of tuples where one is a function call and
# the other is None
function_call_tuples = (
self.tool_call_regex.findall(model_output))
# load the JSON, and then use it to build the Function and
# Tool Call
raw_function_calls = [
json.loads(match[0] if match[0] else match[1])
for match in function_call_tuples
]
tool_calls = [
ToolCall(
type="function",
function=FunctionCall(
name=function_call["name"],
# function call args are JSON but as a string
arguments=json.dumps(function_call["arguments"])))
for function_call in raw_function_calls
]
content = model_output[:model_output.
find(self.tool_call_start_token)]
return ExtractedToolCallInformation(
tools_called=True,
tool_calls=tool_calls,
content=content if content else None)
except Exception as e:
logger.error("Error in extracting tool call from response %s",
e)
return ExtractedToolCallInformation(tools_called=False,
tool_calls=[],
content=model_output)
def extract_tool_calls_streaming(
self,
previous_text: str,
current_text: str,
delta_text: str,
previous_token_ids: Sequence[int],
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
request: ChatCompletionRequest,
) -> Union[DeltaMessage, None]:
logger.debug("delta_text: %s", delta_text)
logger.debug("delta_token_ids: %s", delta_token_ids)
# check to see if we should be streaming a tool call - is there a
if self.tool_call_start_token_id not in current_token_ids:
logger.debug("No tool call tokens found!")
return DeltaMessage(content=delta_text)
try:
# figure out where we are in the parsing by counting tool call
# start & end tags
prev_tool_start_count = previous_token_ids.count(
self.tool_call_start_token_id)
prev_tool_end_count = previous_token_ids.count(
self.tool_call_end_token_id)
cur_tool_start_count = current_token_ids.count(
self.tool_call_start_token_id)
cur_tool_end_count = current_token_ids.count(
self.tool_call_end_token_id)
# case: if we're generating text, OR rounding out a tool call
if (cur_tool_start_count == cur_tool_end_count
and prev_tool_end_count == cur_tool_end_count):
logger.debug("Generating text content! skipping tool parsing.")
if delta_text != self.tool_call_end_token:
return DeltaMessage(content=delta_text)
# case: if tool open & close tag counts don't match, we're doing
# imaginary "else" block here
# something with tools with this diff.
# flags for partial JSON parting. exported constants from
# "Allow" are handled via BIT MASK
flags = Allow.ALL if self.current_tool_name_sent \
else Allow.ALL & ~Allow.STR
# case -- we're starting a new tool call
if (cur_tool_start_count > cur_tool_end_count
and cur_tool_start_count > prev_tool_start_count):
if len(delta_token_ids) > 1:
tool_call_portion = current_text.split(
self.tool_call_start_token)[-1]
else:
tool_call_portion = None
delta = None
text_portion = None
# set cursors and state appropriately
self.current_tool_id += 1
self.current_tool_name_sent = False
self.streamed_args_for_tool.append("")
logger.debug("Starting on a new tool %s", self.current_tool_id)
# case -- we're updating an existing tool call
elif (cur_tool_start_count > cur_tool_end_count
and cur_tool_start_count == prev_tool_start_count):
# get the portion of the text that's the tool call
tool_call_portion = current_text.split(
self.tool_call_start_token)[-1]
text_portion = None
# case -- the current tool call is being closed.
elif (cur_tool_start_count == cur_tool_end_count
and cur_tool_end_count > prev_tool_end_count):
diff = self.prev_tool_call_arr[self.current_tool_id].get(
"arguments")
if diff:
diff = json.dumps(diff).replace(
self.streamed_args_for_tool[self.current_tool_id], "")
logger.debug(
"Finishing tool and found diff that had not "
"been streamed yet: %s", diff)
self.streamed_args_for_tool[self.current_tool_id] \
+= diff
return DeltaMessage(tool_calls=[
DeltaToolCall(index=self.current_tool_id,
function=DeltaFunctionCall(
arguments=diff).model_dump(
exclude_none=True))
])
# case -- otherwise we're just generating text
else:
text = delta_text.replace(self.tool_call_start_token, "")
text = text.replace(self.tool_call_end_token, "")
delta = DeltaMessage(tool_calls=[], content=text)
return delta
try:
current_tool_call = partial_json_parser.loads(
tool_call_portion or "{}",
flags) if tool_call_portion else None
logger.debug("Parsed tool call %s", current_tool_call)
except partial_json_parser.core.exceptions.MalformedJSON:
logger.debug('not enough tokens to parse into JSON yet')
return None
# case - we haven't sent the tool name yet. If it's available, send
# it. otherwise, wait until it's available.
if not self.current_tool_name_sent:
function_name: Union[str, None] = current_tool_call.get("name")
if function_name:
self.current_tool_name_sent = True
return DeltaMessage(tool_calls=[
DeltaToolCall(index=self.current_tool_id,
type="function",
id=f"chatcmpl-tool-{random_uuid()}",
function=DeltaFunctionCall(
name=function_name).model_dump(
exclude_none=True))
])
else:
return None
# case -- otherwise, send the tool call delta
# if the tool call portion is None, send the delta as text
if tool_call_portion is None:
# if there's text but not tool calls, send that -
# otherwise None to skip chunk
delta = DeltaMessage(content=delta_text) \
if text_portion is not None else None
return delta
# now, the nitty-gritty of tool calls
# now we have the portion to parse as tool call.
logger.debug("Trying to parse current tool call with ID %s",
self.current_tool_id)
# if we're starting a new tool call, push an empty object in as
# a placeholder for the arguments
if len(self.prev_tool_call_arr) <= self.current_tool_id:
self.prev_tool_call_arr.append({})
# main logic for tool parsing here - compare prev. partially-parsed
# JSON to the current partially-parsed JSON
prev_arguments = (
self.prev_tool_call_arr[self.current_tool_id].get("arguments"))
cur_arguments = current_tool_call.get("arguments")
logger.debug("diffing old arguments: %s", prev_arguments)
logger.debug("against new ones: %s", cur_arguments)
# case -- no arguments have been created yet. skip sending a delta.
if not cur_arguments and not prev_arguments:
logger.debug("Skipping text %s - no arguments", delta_text)
delta = None
# case -- prev arguments are defined, but non are now.
# probably impossible, but not a fatal error - just keep going
elif not cur_arguments and prev_arguments:
logger.error("should be impossible to have arguments reset "
"mid-call. skipping streaming anything.")
delta = None
# case -- we now have the first info about arguments available from
# autocompleting the JSON
elif cur_arguments and not prev_arguments:
cur_arguments_json = json.dumps(cur_arguments)
logger.debug("finding %s in %s", delta_text,
cur_arguments_json)
# get the location where previous args differ from current
args_delta_start_loc = cur_arguments_json.index(delta_text) \
+ len(delta_text)
# use that to find the actual delta
arguments_delta = cur_arguments_json[:args_delta_start_loc]
logger.debug("First tokens in arguments received: %s",
arguments_delta)
delta = DeltaMessage(tool_calls=[
DeltaToolCall(index=self.current_tool_id,
function=DeltaFunctionCall(
arguments=arguments_delta).model_dump(
exclude_none=True))
])
self.streamed_args_for_tool[self.current_tool_id] \
+= arguments_delta
# last case -- we have an update to existing arguments.
elif cur_arguments and prev_arguments:
cur_args_json = json.dumps(cur_arguments)
prev_args_json = json.dumps(prev_arguments)
logger.debug("Searching for diff between\n%s", cur_args_json)
logger.debug("and\n%s", prev_args_json)
argument_diff = extract_intermediate_diff(
cur_args_json, prev_args_json)
logger.debug("got argument diff %s", argument_diff)
delta = DeltaMessage(tool_calls=[
DeltaToolCall(index=self.current_tool_id,
function=DeltaFunctionCall(
arguments=argument_diff).model_dump(
exclude_none=True))
])
self.streamed_args_for_tool[self.current_tool_id] \
+= argument_diff
# handle saving the state for the current tool into
# the "prev" list for use in diffing for the next iteration
if self.current_tool_id == len(self.prev_tool_call_arr) - 1:
self.prev_tool_call_arr[self.current_tool_id] = \
current_tool_call
else:
self.prev_tool_call_arr.append(current_tool_call)
return delta
except Exception as e:
logger.error("Error trying to handle streaming tool call: %s", e)
return None # do not stream a delta. skip this token ID.

View File

@@ -0,0 +1,208 @@
import json
from typing import Dict, Sequence, Union
import partial_json_parser
from partial_json_parser.core.options import Allow
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaFunctionCall, DeltaMessage,
DeltaToolCall,
ExtractedToolCallInformation,
FunctionCall, ToolCall)
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser, ToolParserManager)
from vllm.entrypoints.openai.tool_parsers.utils import (
extract_intermediate_diff)
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import random_uuid
logger = init_logger(__name__)
@ToolParserManager.register_module(["internlm"])
class Internlm2ToolParser(ToolParser):
def __init__(self, tokenizer: AnyTokenizer):
super().__init__(tokenizer)
self.position = 0
def adjust_request(
self, request: ChatCompletionRequest) -> ChatCompletionRequest:
if request.tools and request.tool_choice != 'none':
# do not skip special tokens because internlm use the special
# tokens to indicated the start and end of the tool calls
# information.
request.skip_special_tokens = False
return request
def get_argments(self, obj):
if "parameters" in obj:
return obj.get("parameters")
elif "arguments" in obj:
return obj.get("arguments")
return None
def extract_tool_calls_streaming(
self,
previous_text: str,
current_text: str,
delta_text: str,
previous_token_ids: Sequence[int],
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
request: ChatCompletionRequest,
) -> Union[DeltaMessage, None]:
if '<|action_start|>' not in current_text:
self.position = len(current_text)
return DeltaMessage(content=delta_text)
# if the tool call is sended, return a empty delta message
# to make sure the finish_reason will be send correctly.
if self.current_tool_id > 0:
return DeltaMessage(content='')
last_pos = self.position
if '<|action_start|><|plugin|>' not in current_text[last_pos:]:
return None
new_delta = current_text[last_pos:]
text, action = new_delta.split('<|action_start|><|plugin|>')
if len(text) > 0:
self.position = self.position + len(text)
return DeltaMessage(content=text)
action = action.strip()
action = action.split('<|action_end|>'.strip())[0]
# bit mask flags for partial JSON parsing. If the name hasn't been
# sent yet, don't allow sending
# an incomplete string since OpenAI only ever (as far as I have
# seen) allows sending the entire tool/ function name at once.
flags = Allow.ALL if self.current_tool_name_sent \
else Allow.ALL & ~Allow.STR
try:
parsable_arr = action
# tool calls are generated in an object in inernlm2
# it's not support parallel tool calls
try:
tool_call_arr: Dict = partial_json_parser.loads(
parsable_arr, flags)
except partial_json_parser.core.exceptions.MalformedJSON:
logger.debug('not enough tokens to parse into JSON yet')
return None
# if the current tool name hasn't been sent, send if available
# - otherwise send nothing
if not self.current_tool_name_sent:
function_name = tool_call_arr.get("name")
if function_name:
self.current_tool_id = self.current_tool_id + 1
delta = DeltaMessage(tool_calls=[
DeltaToolCall(index=self.current_tool_id,
type="function",
id=f"chatcmpl-tool-{random_uuid()}",
function=DeltaFunctionCall(
name=function_name).model_dump(
exclude_none=True))
])
self.current_tool_name_sent = True
self.streamed_args_for_tool.append("")
else:
delta = None
# now we know we're on the same tool call and we're streaming
# arguments
else:
prev_arguments = self.get_argments(
self.prev_tool_call_arr[self.current_tool_id])
cur_arguments = self.get_argments(tool_call_arr)
# not arguments generated
if not cur_arguments and not prev_arguments:
delta = None
# will never happen
elif not cur_arguments and prev_arguments:
logger.error(
"INVARIANT - impossible to have arguments reset "
"mid-arguments")
delta = None
# first time to get parameters
elif cur_arguments and not prev_arguments:
cur_arguments_json = json.dumps(cur_arguments)
arguments_delta = cur_arguments_json[:cur_arguments_json.
index(delta_text) +
len(delta_text)]
delta = DeltaMessage(tool_calls=[
DeltaToolCall(index=self.current_tool_id,
function=DeltaFunctionCall(
arguments=arguments_delta).
model_dump(exclude_none=True))
])
self.streamed_args_for_tool[
self.current_tool_id] += arguments_delta
# both prev and cur parameters, send the increase parameters
elif cur_arguments and prev_arguments:
cur_args_json = json.dumps(cur_arguments)
prev_args_json = json.dumps(prev_arguments)
argument_diff = extract_intermediate_diff(
cur_args_json, prev_args_json)
delta = DeltaMessage(tool_calls=[
DeltaToolCall(index=self.current_tool_id,
function=DeltaFunctionCall(
arguments=argument_diff).model_dump(
exclude_none=True))
])
self.streamed_args_for_tool[
self.current_tool_id] += argument_diff
# check to see if the name is defined and has been sent. if so,
# stream the name - otherwise keep waiting
# finish by setting old and returning None as base case
tool_call_arr["arguments"] = self.get_argments(tool_call_arr)
self.prev_tool_call_arr = [tool_call_arr]
return delta
except Exception as e:
logger.error("Error trying to handle streaming tool call: %s", e)
logger.debug(
"Skipping chunk as a result of tool streaming extraction "
"error")
return None
def extract_tool_calls(
self,
model_output: str,
request: ChatCompletionRequest,
) -> ExtractedToolCallInformation:
text = model_output
tools = request.tools
if '<|action_start|><|plugin|>' in text:
text, action = text.split('<|action_start|><|plugin|>')
action = action.split('<|action_end|>'.strip())[0]
action = action[action.find('{'):]
action_dict = json.loads(action)
name, parameters = action_dict['name'], json.dumps(
action_dict.get('parameters', action_dict.get('arguments',
{})))
if not tools or name not in [t.function.name for t in tools]:
ExtractedToolCallInformation(tools_called=False,
tool_calls=[],
content=text)
tool_calls = [
ToolCall(
function=FunctionCall(name=name, arguments=parameters))
]
return ExtractedToolCallInformation(
tools_called=True,
tool_calls=tool_calls,
content=text if len(text) > 0 else None)
return ExtractedToolCallInformation(tools_called=False,
tool_calls=[],
content=text)

View File

@@ -0,0 +1,277 @@
import json
import re
from json import JSONDecodeError, JSONDecoder
from typing import Dict, List, Sequence, Union
import partial_json_parser
from partial_json_parser.core.options import Allow
from transformers import PreTrainedTokenizerBase
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaFunctionCall, DeltaMessage,
DeltaToolCall,
ExtractedToolCallInformation,
FunctionCall, ToolCall)
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser, ToolParserManager)
from vllm.entrypoints.openai.tool_parsers.utils import find_common_prefix
from vllm.logger import init_logger
from vllm.utils import random_uuid
logger = init_logger(__name__)
# partial_json_parser doesn't support extra data and
# JSONDecorder.raw_decode doesn't support partial JSON
def partial_json_loads(input_str, flags):
try:
return (partial_json_parser.loads(input_str, flags), len(input_str))
except JSONDecodeError as e:
if "Extra data" in e.msg:
dec = JSONDecoder()
return dec.raw_decode(input_str)
else:
raise
def is_complete_json(input_str):
try:
json.loads(input_str)
return True
except JSONDecodeError:
return False
@ToolParserManager.register_module("llama3_json")
class Llama3JsonToolParser(ToolParser):
"""
Tool call parser for Llama 3.1 models intended for use with the
examples/tool_chat_template_llama.jinja template.
Used when --enable-auto-tool-choice --tool-call-parser mistral are all set
"""
def __init__(self, tokenizer: PreTrainedTokenizerBase):
super().__init__(tokenizer)
# initialize properties used for state when parsing tool calls in
# streaming mode
self.prev_tool_call_arr: List[Dict] = []
self.current_tool_id: int = -1
self.current_tool_name_sent: bool = False
self.streamed_args_for_tool: List[str] = [
] # map what has been streamed for each tool so far to a list
self.bot_token = "<|python_tag|>"
self.bot_token_id = tokenizer.encode(self.bot_token,
add_special_tokens=False)[0]
self.tool_call_regex = re.compile(r"\[{.*?}\]", re.DOTALL)
def extract_tool_calls(
self, model_output: str,
request: ChatCompletionRequest) -> ExtractedToolCallInformation:
"""
Extract the tool calls from a complete model response.
"""
# case -- if a tool call token is not present, return a text response
if not (model_output.startswith(self.bot_token)
or model_output.startswith('{')):
return ExtractedToolCallInformation(tools_called=False,
tool_calls=[],
content=model_output)
try:
# load the JSON, and then use it to build the Function and
# Tool Call
dec = JSONDecoder()
function_call_arr = []
# depending on the prompt format the Llama model may or may not
# prefix the output with the <|python_tag|> token
start_idx = len(self.bot_token) if model_output.startswith(
self.bot_token) else 0
while start_idx < len(model_output):
(obj, end_idx) = dec.raw_decode(model_output[start_idx:])
start_idx += end_idx + len('; ')
function_call_arr.append(obj)
tool_calls: List[ToolCall] = [
ToolCall(
type="function",
function=FunctionCall(
name=raw_function_call["name"],
# function call args are JSON but as a string
arguments=json.dumps(raw_function_call["arguments"] \
if "arguments" in raw_function_call \
else raw_function_call["parameters"])))
for raw_function_call in function_call_arr
]
# get any content before the tool call
ret = ExtractedToolCallInformation(tools_called=True,
tool_calls=tool_calls,
content=None)
return ret
except Exception as e:
logger.error("Error in extracting tool call from response: %s", e)
print("ERROR", e)
# return information to just treat the tool call as regular JSON
return ExtractedToolCallInformation(tools_called=False,
tool_calls=[],
content=model_output)
def extract_tool_calls_streaming(
self,
previous_text: str,
current_text: str,
delta_text: str,
previous_token_ids: Sequence[int],
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
request: ChatCompletionRequest,
) -> Union[DeltaMessage, None]:
if not (current_text.startswith(self.bot_token)
or current_text.startswith('{')):
return DeltaMessage(content=delta_text)
# bit mask flags for partial JSON parsing. If the name hasn't been
# sent yet, don't allow sending
# an incomplete string since OpenAI only ever (as far as I have
# seen) allows sending the entire tool/ function name at once.
flags = Allow.ALL if self.current_tool_name_sent \
else Allow.ALL & ~Allow.STR
try:
tool_call_arr = []
is_complete = []
try:
# depending on the prompt format the Llama model may or may not
# prefix the output with the <|python_tag|> token
start_idx = len(self.bot_token) if current_text.startswith(
self.bot_token) else 0
while start_idx < len(current_text):
(obj,
end_idx) = partial_json_loads(current_text[start_idx:],
flags)
is_complete.append(
is_complete_json(current_text[start_idx:start_idx +
end_idx]))
start_idx += end_idx + len('; ')
# depending on the prompt Llama can use
# either arguments or parameters
if "parameters" in obj:
assert "arguments" not in obj, \
"model generated both parameters and arguments"
obj["arguments"] = obj["parameters"]
tool_call_arr.append(obj)
except partial_json_parser.core.exceptions.MalformedJSON:
logger.debug('not enough tokens to parse into JSON yet')
return None
# select as the current tool call the one we're on the state at
current_tool_call: Dict = tool_call_arr[self.current_tool_id] \
if len(tool_call_arr) > 0 else {}
# case -- if no tokens have been streamed for the tool, e.g.
# only the array brackets, stream nothing
if len(tool_call_arr) == 0:
return None
# case: we are starting a new tool in the array
# -> array has > 0 length AND length has moved past cursor
elif (len(tool_call_arr) > 0
and len(tool_call_arr) > self.current_tool_id + 1):
# if we're moving on to a new call, first make sure we
# haven't missed anything in the previous one that was
# auto-generated due to JSON completions, but wasn't
# streamed to the client yet.
if self.current_tool_id >= 0:
cur_arguments = current_tool_call.get("arguments")
if cur_arguments:
cur_args_json = json.dumps(cur_arguments)
sent = len(
self.streamed_args_for_tool[self.current_tool_id])
argument_diff = cur_args_json[sent:]
logger.debug("got arguments diff: %s", argument_diff)
delta = DeltaMessage(tool_calls=[
DeltaToolCall(index=self.current_tool_id,
function=DeltaFunctionCall(
arguments=argument_diff).
model_dump(exclude_none=True))
])
self.streamed_args_for_tool[
self.current_tool_id] += argument_diff
else:
delta = None
else:
delta = None
# re-set stuff pertaining to progress in the current tool
self.current_tool_id = len(tool_call_arr) - 1
self.current_tool_name_sent = False
self.streamed_args_for_tool.append("")
logger.debug("starting on new tool %d", self.current_tool_id)
return delta
# if the current tool name hasn't been sent, send if available
# - otherwise send nothing
elif not self.current_tool_name_sent:
function_name = current_tool_call.get("name")
if function_name:
delta = DeltaMessage(tool_calls=[
DeltaToolCall(index=self.current_tool_id,
type="function",
id=f"chatcmpl-tool-{random_uuid()}",
function=DeltaFunctionCall(
name=function_name).model_dump(
exclude_none=True))
])
self.current_tool_name_sent = True
else:
delta = None
# now we know we're on the same tool call and we're streaming
# arguments
else:
cur_arguments = current_tool_call.get("arguments")
delta = None
if cur_arguments:
sent = len(
self.streamed_args_for_tool[self.current_tool_id])
cur_args_json = json.dumps(cur_arguments)
prev_arguments = self.prev_tool_call_arr[
self.current_tool_id].get("arguments")
argument_diff = None
if is_complete[self.current_tool_id]:
argument_diff = cur_args_json[sent:]
elif prev_arguments:
prev_args_json = json.dumps(prev_arguments)
if cur_args_json != prev_args_json:
prefix = find_common_prefix(
prev_args_json, cur_args_json)
argument_diff = prefix[sent:]
if argument_diff is not None:
delta = DeltaMessage(tool_calls=[
DeltaToolCall(index=self.current_tool_id,
function=DeltaFunctionCall(
arguments=argument_diff).
model_dump(exclude_none=True))
])
self.streamed_args_for_tool[
self.current_tool_id] += argument_diff
self.prev_tool_call_arr = tool_call_arr
return delta
except Exception as e:
logger.error("Error trying to handle streaming tool call: %s", e)
logger.debug(
"Skipping chunk as a result of tool streaming extraction "
"error")
return None

View File

@@ -0,0 +1,306 @@
import json
import re
from random import choices
from string import ascii_letters, digits
from typing import Dict, List, Sequence, Union
import partial_json_parser
from partial_json_parser.core.options import Allow
from pydantic import Field
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaFunctionCall, DeltaMessage,
DeltaToolCall,
ExtractedToolCallInformation,
FunctionCall, ToolCall)
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser, ToolParserManager)
from vllm.entrypoints.openai.tool_parsers.utils import (
extract_intermediate_diff)
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.utils import random_uuid
logger = init_logger(__name__)
ALPHANUMERIC = ascii_letters + digits
class MistralToolCall(ToolCall):
id: str = Field(
default_factory=lambda: MistralToolCall.generate_random_id())
@staticmethod
def generate_random_id():
# Mistral Tool Call Ids must be alphanumeric with a maximum length of 9.
# https://github.com/mistralai/mistral-common/blob/21ee9f6cee3441e9bb1e6ed2d10173f90bd9b94b/src/mistral_common/protocol/instruct/validator.py#L299
return "".join(choices(ALPHANUMERIC, k=9))
@ToolParserManager.register_module("mistral")
class MistralToolParser(ToolParser):
"""
Tool call parser for Mistral 7B Instruct v0.3, intended for use with the
examples/tool_chat_template_mistral.jinja template.
Used when --enable-auto-tool-choice --tool-call-parser mistral are all set
"""
def __init__(self, tokenizer: AnyTokenizer):
super().__init__(tokenizer)
if not isinstance(self.model_tokenizer, MistralTokenizer):
logger.info("Non-Mistral tokenizer detected when using a Mistral "
"model...")
# initialize properties used for state when parsing tool calls in
# streaming mode
self.prev_tool_call_arr: List[Dict] = []
self.current_tool_id: int = -1
self.current_tool_name_sent: bool = False
self.streamed_args_for_tool: List[str] = [
] # map what has been streamed for each tool so far to a list
self.bot_token = "[TOOL_CALLS]"
self.bot_token_id = self.vocab.get(self.bot_token)
self.tool_call_regex = re.compile(r"\[{.*?}\]", re.DOTALL)
if not self.bot_token_id:
raise RuntimeError(
"Mistral Tool Parser could not locate the tool call token in "
"the tokenizer!")
def extract_tool_calls(
self,
model_output: str,
request: ChatCompletionRequest,
) -> ExtractedToolCallInformation:
"""
Extract the tool calls from a complete model response. Requires
find-and-replacing single quotes with double quotes for JSON parsing,
make sure your tool call arguments don't ever include quotes!
"""
# case -- if a tool call token is not present, return a text response
if self.bot_token not in model_output:
return ExtractedToolCallInformation(tools_called=False,
tool_calls=[],
content=model_output)
try:
# use a regex to find the tool call. remove the BOT token
# and make sure to replace single quotes with double quotes
raw_tool_call = self.tool_call_regex.findall(
model_output.replace(self.bot_token, ""))[0]
# load the JSON, and then use it to build the Function and
# Tool Call
function_call_arr = json.loads(raw_tool_call)
tool_calls: List[MistralToolCall] = [
MistralToolCall(
type="function",
function=FunctionCall(
name=raw_function_call["name"],
# function call args are JSON but as a string
arguments=json.dumps(raw_function_call["arguments"])))
for raw_function_call in function_call_arr
]
# get any content before the tool call
content = model_output.split(self.bot_token)[0]
return ExtractedToolCallInformation(
tools_called=True,
tool_calls=tool_calls,
content=content if len(content) > 0 else None)
except Exception as e:
logger.error("Error in extracting tool call from response: %s", e)
# return information to just treat the tool call as regular JSON
return ExtractedToolCallInformation(tools_called=False,
tool_calls=[],
content=model_output)
def extract_tool_calls_streaming(
self,
previous_text: str,
current_text: str,
delta_text: str,
previous_token_ids: Sequence[int],
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
request: ChatCompletionRequest,
) -> Union[DeltaMessage, None]:
# if the tool call token is not in the tokens generated so far, append
# output to contents since it's not a tool
if self.bot_token not in current_text:
return DeltaMessage(content=delta_text)
# if the tool call token ID IS in the tokens generated so far, that
# means we're parsing as tool calls now
# handle if we detected the BOT token which means the start of tool
# calling
if (self.bot_token_id in delta_token_ids
and len(delta_token_ids) == 1):
# if it's the only token, return None, so we don't send a chat
# completion any don't send a control token
return None
# bit mask flags for partial JSON parsing. If the name hasn't been
# sent yet, don't allow sending
# an incomplete string since OpenAI only ever (as far as I have
# seen) allows sending the entire tool/ function name at once.
flags = Allow.ALL if self.current_tool_name_sent \
else Allow.ALL & ~Allow.STR
try:
# replace BOT token with empty string, and convert single quotes
# to double to allow parsing as JSON since mistral uses single
# quotes instead of double for tool calls
parsable_arr = current_text.split(self.bot_token)[-1]
# tool calls are generated in an array, so do partial JSON
# parsing on the entire array
try:
tool_call_arr: List[Dict] = partial_json_parser.loads(
parsable_arr, flags)
except partial_json_parser.core.exceptions.MalformedJSON:
logger.debug('not enough tokens to parse into JSON yet')
return None
# select as the current tool call the one we're on the state at
current_tool_call: Dict = tool_call_arr[self.current_tool_id] \
if len(tool_call_arr) > 0 else {}
# case -- if no tokens have been streamed for the tool, e.g.
# only the array brackets, stream nothing
if len(tool_call_arr) == 0:
return None
# case: we are starting a new tool in the array
# -> array has > 0 length AND length has moved past cursor
elif (len(tool_call_arr) > 0
and len(tool_call_arr) > self.current_tool_id + 1):
# if we're moving on to a new call, first make sure we
# haven't missed anything in the previous one that was
# auto-generated due to JSON completions, but wasn't
# streamed to the client yet.
if self.current_tool_id >= 0:
diff: Union[str, None] = current_tool_call.get("arguments")
if diff:
diff = json.dumps(diff).replace(
self.streamed_args_for_tool[self.current_tool_id],
"")
delta = DeltaMessage(tool_calls=[
DeltaToolCall(index=self.current_tool_id,
function=DeltaFunctionCall(
arguments=diff).model_dump(
exclude_none=True))
])
self.streamed_args_for_tool[
self.current_tool_id] += diff
else:
delta = None
else:
delta = None
# re-set stuff pertaining to progress in the current tool
self.current_tool_id = len(tool_call_arr) - 1
self.current_tool_name_sent = False
self.streamed_args_for_tool.append("")
logger.debug("starting on new tool %d", self.current_tool_id)
return delta
# case: update an existing tool - this is handled below
# if the current tool name hasn't been sent, send if available
# - otherwise send nothing
if not self.current_tool_name_sent:
function_name = current_tool_call.get("name")
if function_name:
delta = DeltaMessage(tool_calls=[
DeltaToolCall(index=self.current_tool_id,
type="function",
id=f"chatcmpl-tool-{random_uuid()}",
function=DeltaFunctionCall(
name=function_name).model_dump(
exclude_none=True))
])
self.current_tool_name_sent = True
else:
delta = None
# now we know we're on the same tool call and we're streaming
# arguments
else:
prev_arguments = self.prev_tool_call_arr[
self.current_tool_id].get("arguments")
cur_arguments = current_tool_call.get("arguments")
new_text = delta_text.replace("\'", "\"")
if not cur_arguments and not prev_arguments:
delta = None
elif not cur_arguments and prev_arguments:
logger.error(
"INVARIANT - impossible to have arguments reset "
"mid-arguments")
delta = None
elif cur_arguments and not prev_arguments:
cur_arguments_json = json.dumps(cur_arguments)
logger.debug("finding %s in %s", new_text,
cur_arguments_json)
arguments_delta = cur_arguments_json[:cur_arguments_json.
index(new_text) +
len(new_text)]
logger.debug("First tokens in arguments received: %s",
arguments_delta)
delta = DeltaMessage(tool_calls=[
DeltaToolCall(index=self.current_tool_id,
function=DeltaFunctionCall(
arguments=arguments_delta).
model_dump(exclude_none=True))
])
self.streamed_args_for_tool[
self.current_tool_id] += arguments_delta
elif cur_arguments and prev_arguments:
cur_args_json = json.dumps(cur_arguments)
prev_args_json = json.dumps(prev_arguments)
logger.debug("Searching for diff between \n%s\n%s",
cur_args_json, prev_args_json)
argument_diff = extract_intermediate_diff(
cur_args_json, prev_args_json)
logger.debug("got arguments diff: %s", argument_diff)
delta = DeltaMessage(tool_calls=[
DeltaToolCall(index=self.current_tool_id,
function=DeltaFunctionCall(
arguments=argument_diff).model_dump(
exclude_none=True))
])
self.streamed_args_for_tool[
self.current_tool_id] += argument_diff
else:
# try parsing it with regular JSON - if it works we're
# at the end, and we need to send the difference between
# tokens streamed so far and the valid JSON
delta = None
# check to see if the name is defined and has been sent. if so,
# stream the name - otherwise keep waiting
# finish by setting old and returning None as base case
self.prev_tool_call_arr = tool_call_arr
return delta
except Exception as e:
logger.error("Error trying to handle streaming tool call: %s", e)
logger.debug(
"Skipping chunk as a result of tool streaming extraction "
"error")
return None

View File

@@ -0,0 +1,87 @@
def find_common_prefix(s1: str, s2: str) -> str:
"""
Finds a common prefix that is shared between two strings, if there is one.
Order of arguments is NOT important.
This function is provided as a UTILITY for extracting information from JSON
generated by partial_json_parser, to help in ensuring that the right tokens
are returned in streaming, so that close-quotes, close-brackets and
close-braces are not returned prematurely.
e.g. find_common_prefix('{"fruit": "ap"}', '{"fruit": "apple"}') ->
'{"fruit": "ap'
"""
prefix = ''
min_length = min(len(s1), len(s2))
for i in range(0, min_length):
if s1[i] == s2[i]:
prefix += s1[i]
else:
break
return prefix
def find_common_suffix(s1: str, s2: str) -> str:
"""
Finds a common suffix shared between two strings, if there is one. Order of
arguments is NOT important.
Stops when the suffix ends OR it hits an alphanumeric character
e.g. find_common_suffix('{"fruit": "ap"}', '{"fruit": "apple"}') -> '"}'
"""
suffix = ''
min_length = min(len(s1), len(s2))
for i in range(1, min_length + 1):
if s1[-i] == s2[-i] and not s1[-i].isalnum():
suffix = s1[-i] + suffix
else:
break
return suffix
def extract_intermediate_diff(curr: str, old: str) -> str:
"""
Given two strings, extract the difference in the middle between two strings
that are known to have a common prefix and/or suffix.
This function is provided as a UTILITY for extracting information from JSON
generated by partial_json_parser, to help in ensuring that the right tokens
are returned in streaming, so that close-quotes, close-brackets and
close-braces are not returned prematurely. The order of arguments IS
important - the new version of the partially-parsed JSON must be the first
argument, and the secnod argument must be from the previous generation.
What it returns, is tokens that should be streamed to the client.
e.g. extract_intermediate_diff('{"fruit": "apple"}', '{"fruit": "ap"}')
-> 'ple'
"""
suffix = find_common_suffix(curr, old)
old = old[::-1].replace(suffix[::-1], '', 1)[::-1]
prefix = find_common_prefix(curr, old)
diff = curr
if len(suffix):
diff = diff[::-1].replace(suffix[::-1], '', 1)[::-1]
if len(prefix):
# replace the prefix only once in case it's mirrored
diff = diff.replace(prefix, '', 1)
return diff
def find_all_indices(string, substring):
"""
Find all (starting) indices of a substring in a given string. Useful for
tool call extraction
"""
indices = []
index = -1
while True:
index = string.find(substring, index + 1)
if index == -1:
break
indices.append(index)
return indices